Skip to content

Commit cbe4196

Browse files
committed
Add optional `auto_convert` parameter to `write`
1 parent 6c46b99 commit cbe4196

File tree

2 files changed

+47
-14
lines changed

2 files changed

+47
-14
lines changed

bindings/python/pymongoarrow/api.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,20 @@
3535
from numpy import ndarray
3636
from pyarrow import Schema as ArrowSchema
3737
from pyarrow import Table, timestamp
38-
from pyarrow.types import is_date32, is_date64
38+
from pyarrow.types import (
39+
is_date32,
40+
is_date64,
41+
is_duration,
42+
is_float16,
43+
is_float32,
44+
is_int8,
45+
is_int16,
46+
is_list,
47+
is_uint8,
48+
is_uint16,
49+
is_uint32,
50+
is_uint64,
51+
)
3952
from pymongo.common import MAX_WRITE_BATCH_SIZE
4053

4154
from pymongoarrow.context import PyMongoArrowContext
@@ -475,14 +488,15 @@ def transform_python(self, value):
475488
return Decimal128(value)
476489

477490

478-
def write(collection, tabular, *, exclude_none: bool = False):
491+
def write(collection, tabular, *, exclude_none: bool = False, auto_convert: bool = True):
479492
"""Write data from `tabular` into the given MongoDB `collection`.
480493
481494
:Parameters:
482495
- `collection`: Instance of :class:`~pymongo.collection.Collection`.
483496
against which to run the operation.
484497
- `tabular`: A tabular data store to use for the write operation.
485498
- `exclude_none`: Whether to skip writing `null` fields in documents.
499+
- `auto_convert` (optional): Whether to attempt a best-effort conversion of unsupported types.
486500
487501
:Returns:
488502
An instance of :class:`result.ArrowWriteResult`.
@@ -500,9 +514,24 @@ def write(collection, tabular, *, exclude_none: bool = False):
500514
if is_date32(dtype) or is_date64(dtype):
501515
changed = True
502516
dtype = timestamp("ms") # noqa: PLW2901
517+
elif auto_convert:
518+
if is_uint8(dtype) or is_uint16(dtype) or is_int8(dtype) or is_int16(dtype):
519+
changed = True
520+
dtype = pa.int32() # noqa: PLW2901
521+
elif is_uint32(dtype) or is_uint64(dtype) or is_duration(dtype):
522+
changed = True
523+
dtype = pa.int64() # noqa: PLW2901
524+
elif is_float16(dtype) or is_float32(dtype):
525+
changed = True
526+
dtype = pa.float64() # noqa: PLW2901
503527
new_types.append(dtype)
504528
if changed:
505-
cols = [tabular.column(i).cast(new_types[i]) for i in range(tabular.num_columns)]
529+
cols = [
530+
tabular.column(i).cast(new_types[i])
531+
if not is_list(new_types[i])
532+
else tabular.column(i)
533+
for i in range(tabular.num_columns)
534+
]
506535
tabular = Table.from_arrays(cols, names=tabular.column_names)
507536
_validate_schema(tabular.schema.types)
508537
elif pd is not None and isinstance(tabular, pd.DataFrame):

bindings/python/test/test_arrow.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import threading
2020
import unittest
2121
import unittest.mock as mock
22-
from datetime import date, datetime
22+
from datetime import date, datetime, timedelta
2323
from pathlib import Path
2424
from test import client_context
2525
from test.utils import AllowListEventListener, NullsTestMixin
@@ -1095,22 +1095,22 @@ def alltypes_sample(self, size=10000, seed=0, categorical=False):
10951095

10961096
np.random.seed(seed)
10971097
arrays = {
1098-
"uint8": pa.array(np.arange(size, dtype=np.uint8), type=pa.int32()),
1099-
"uint16": pa.array(np.arange(size, dtype=np.uint16), type=pa.int32()),
1100-
"uint32": pa.array(np.arange(size, dtype=np.uint32), type=pa.int64()),
1101-
"uint64": pa.array(np.arange(size, dtype=np.uint64), type=pa.int64()),
1102-
"int8": pa.array(np.arange(size, dtype=np.int8), type=pa.int32()),
1103-
"int16": pa.array(np.arange(size, dtype=np.int16), type=pa.int32()),
1098+
"uint8": np.arange(size, dtype=np.uint8),
1099+
"uint16": np.arange(size, dtype=np.uint16),
1100+
"uint32": np.arange(size, dtype=np.uint32),
1101+
"uint64": np.arange(size, dtype=np.uint64),
1102+
"int8": np.arange(size, dtype=np.int8),
1103+
"int16": np.arange(size, dtype=np.int16),
11041104
"int32": np.arange(size, dtype=np.int32),
11051105
"int64": np.arange(size, dtype=np.int64),
1106-
"float16": pa.array(np.arange(size, dtype=np.float16), type=pa.float64()),
1107-
"float32": pa.array(np.arange(size, dtype=np.float32), type=pa.float64()),
1108-
"float64": pa.array(np.arange(size, dtype=np.float64), type=pa.float64()),
1106+
"float16": np.arange(size, dtype=np.float16),
1107+
"float32": np.arange(size, dtype=np.float32),
1108+
"float64": np.arange(size, dtype=np.float64),
11091109
"bool": np.random.randn(size) > 0,
11101110
"datetime_ms": np.arange("2016-01-01T00:00:00.001", size, dtype="datetime64[ms]"),
11111111
"datetime_us": np.arange("2016-01-01T00:00:00.000001", size, dtype="datetime64[us]"),
11121112
"datetime_ns": np.arange("2016-01-01T00:00:00.000000001", size, dtype="datetime64[ns]"),
1113-
"timedelta": pa.array(np.arange(size, dtype="timedelta64[s]"), type=pa.int64()),
1113+
"timedelta": np.arange(size, dtype="timedelta64[s]"),
11141114
"str": pd.Series([str(x) for x in range(size)]),
11151115
"empty_str": [""] * size,
11161116
"str_with_nulls": [None] + [str(x) for x in range(size - 2)] + [None],
@@ -1181,6 +1181,10 @@ def compare_arrow_mongodb_data(self, arrow_table, mongo_data):
11811181
assert (
11821182
arrow_value == mongo_value
11831183
), f"List mismatch in field '{column_name}' for row {row_idx}. Expected {arrow_value}, got {mongo_value}."
1184+
elif isinstance(arrow_value, timedelta):
1185+
assert (
1186+
arrow_value == timedelta(seconds=mongo_value)
1187+
), f"Timedelta mismatch in field '{column_name}' for row {row_idx}. Expected {arrow_value}, got {mongo_value}."
11841188
else:
11851189
assert (
11861190
arrow_value == mongo_value

0 commit comments

Comments
 (0)