Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions bindings/python/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,15 @@

---

# Changes in Version 1.9.0 (2025/XX/XX)

- Providing a schema now enforces strict type adherence for data.
If a result contains a field whose value does not match the schema's type for that field, a TypeError will be raised.
Note that ``NaN`` is a valid type for all fields.
To suppress these errors and instead silently convert such mismatches to ``NaN``, pass the ``allow_invalid=True`` argument to your ``pymongoarrow`` API call.
For example, a result with a field of type ``int`` but with a string value will now raise a TypeError,
unless ``allow_invalid=True`` is passed, in which case the result's field will have a value of ``NaN``.

# Changes in Version 1.8.0 (2025/05/12)

- Add support for PyArrow 20.0.
Expand Down
70 changes: 54 additions & 16 deletions bindings/python/pymongoarrow/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
_MAX_WRITE_BATCH_SIZE = max(100000, MAX_WRITE_BATCH_SIZE)


def find_arrow_all(collection, query, *, schema=None, **kwargs):
def find_arrow_all(collection, query, *, schema=None, allow_invalid=False, **kwargs):
"""Method that returns the results of a find query as a
:class:`pyarrow.Table` instance.

Expand All @@ -83,14 +83,18 @@ def find_arrow_all(collection, query, *, schema=None, **kwargs):
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
If the schema is not given, it will be inferred using the data in the
result set.
- `allow_invalid` (optional): If set to ``True``,
results will have all fields that do not conform to the schema silently converted to NaN.

Additional keyword-arguments passed to this method will be passed
directly to the underlying ``find`` operation.

:Returns:
An instance of class:`pyarrow.Table`.
"""
context = PyMongoArrowContext(schema, codec_options=collection.codec_options)
context = PyMongoArrowContext(
schema, codec_options=collection.codec_options, allow_invalid=allow_invalid
)

for opt in ("cursor_type",):
if kwargs.pop(opt, None):
Expand All @@ -110,7 +114,7 @@ def find_arrow_all(collection, query, *, schema=None, **kwargs):
return context.finish()


def aggregate_arrow_all(collection, pipeline, *, schema=None, **kwargs):
def aggregate_arrow_all(collection, pipeline, *, schema=None, allow_invalid=False, **kwargs):
"""Method that returns the results of an aggregation pipeline as a
:class:`pyarrow.Table` instance.

Expand All @@ -121,14 +125,18 @@ def aggregate_arrow_all(collection, pipeline, *, schema=None, **kwargs):
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
If the schema is not given, it will be inferred using the data in the
result set.
- `allow_invalid` (optional): If set to ``True``,
results will have all fields that do not conform to the schema silently converted to NaN.

Additional keyword-arguments passed to this method will be passed
directly to the underlying ``aggregate`` operation.

:Returns:
An instance of class:`pyarrow.Table`.
"""
context = PyMongoArrowContext(schema, codec_options=collection.codec_options)
context = PyMongoArrowContext(
schema, codec_options=collection.codec_options, allow_invalid=allow_invalid
)

if pipeline and ("$out" in pipeline[-1] or "$merge" in pipeline[-1]):
msg = (
Expand Down Expand Up @@ -165,7 +173,7 @@ def _arrow_to_pandas(arrow_table):
return arrow_table.to_pandas(split_blocks=True, self_destruct=True)


def find_pandas_all(collection, query, *, schema=None, **kwargs):
def find_pandas_all(collection, query, *, schema=None, allow_invalid=False, **kwargs):
"""Method that returns the results of a find query as a
:class:`pandas.DataFrame` instance.

Expand All @@ -176,17 +184,21 @@ def find_pandas_all(collection, query, *, schema=None, **kwargs):
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
If the schema is not given, it will be inferred using the data in the
result set.
- `allow_invalid` (optional): If set to ``True``,
results will have all fields that do not conform to the schema silently converted to NaN.

Additional keyword-arguments passed to this method will be passed
directly to the underlying ``find`` operation.

:Returns:
An instance of class:`pandas.DataFrame`.
"""
return _arrow_to_pandas(find_arrow_all(collection, query, schema=schema, **kwargs))
return _arrow_to_pandas(
find_arrow_all(collection, query, schema=schema, allow_invalid=allow_invalid, **kwargs)
)


def aggregate_pandas_all(collection, pipeline, *, schema=None, **kwargs):
def aggregate_pandas_all(collection, pipeline, *, schema=None, allow_invalid=False, **kwargs):
"""Method that returns the results of an aggregation pipeline as a
:class:`pandas.DataFrame` instance.

Expand All @@ -197,14 +209,20 @@ def aggregate_pandas_all(collection, pipeline, *, schema=None, **kwargs):
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
If the schema is not given, it will be inferred using the data in the
result set.
- `allow_invalid` (optional): If set to ``True``,
results will have all fields that do not conform to the schema silently converted to NaN.

Additional keyword-arguments passed to this method will be passed
directly to the underlying ``aggregate`` operation.

:Returns:
An instance of class:`pandas.DataFrame`.
"""
return _arrow_to_pandas(aggregate_arrow_all(collection, pipeline, schema=schema, **kwargs))
return _arrow_to_pandas(
aggregate_arrow_all(
collection, pipeline, schema=schema, allow_invalid=allow_invalid, **kwargs
)
)


def _arrow_to_numpy(arrow_table, schema=None):
Expand All @@ -227,7 +245,7 @@ def _arrow_to_numpy(arrow_table, schema=None):
return container


def find_numpy_all(collection, query, *, schema=None, **kwargs):
def find_numpy_all(collection, query, *, schema=None, allow_invalid=False, **kwargs):
"""Method that returns the results of a find query as a
:class:`dict` instance whose keys are field names and values are
:class:`~numpy.ndarray` instances bearing the appropriate dtype.
Expand All @@ -239,6 +257,8 @@ def find_numpy_all(collection, query, *, schema=None, **kwargs):
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
If the schema is not given, it will be inferred using the data in the
result set.
- `allow_invalid` (optional): If set to ``True``,
results will have all fields that do not conform to the schema silently converted to NaN.

Additional keyword-arguments passed to this method will be passed
directly to the underlying ``find`` operation.
Expand All @@ -255,10 +275,13 @@ def find_numpy_all(collection, query, *, schema=None, **kwargs):
:Returns:
An instance of :class:`dict`.
"""
return _arrow_to_numpy(find_arrow_all(collection, query, schema=schema, **kwargs), schema)
return _arrow_to_numpy(
find_arrow_all(collection, query, schema=schema, allow_invalid=allow_invalid, **kwargs),
schema,
)


def aggregate_numpy_all(collection, pipeline, *, schema=None, **kwargs):
def aggregate_numpy_all(collection, pipeline, *, schema=None, allow_invalid=False, **kwargs):
"""Method that returns the results of an aggregation pipeline as a
:class:`dict` instance whose keys are field names and values are
:class:`~numpy.ndarray` instances bearing the appropriate dtype.
Expand All @@ -270,6 +293,8 @@ def aggregate_numpy_all(collection, pipeline, *, schema=None, **kwargs):
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
If the schema is not given, it will be inferred using the data in the
result set.
- `allow_invalid` (optional): If set to ``True``,
results will have all fields that do not conform to the schema silently converted to NaN.

Additional keyword-arguments passed to this method will be passed
directly to the underlying ``aggregate`` operation.
Expand All @@ -287,7 +312,10 @@ def aggregate_numpy_all(collection, pipeline, *, schema=None, **kwargs):
An instance of :class:`dict`.
"""
return _arrow_to_numpy(
aggregate_arrow_all(collection, pipeline, schema=schema, **kwargs), schema
aggregate_arrow_all(
collection, pipeline, schema=schema, allow_invalid=allow_invalid, **kwargs
),
schema,
)


Expand Down Expand Up @@ -326,7 +354,7 @@ def _arrow_to_polars(arrow_table: pa.Table):
return pl.from_arrow(arrow_table_without_extensions)


def find_polars_all(collection, query, *, schema=None, **kwargs):
def find_polars_all(collection, query, *, schema=None, allow_invalid=False, **kwargs):
"""Method that returns the results of a find query as a
:class:`polars.DataFrame` instance.

Expand All @@ -337,6 +365,8 @@ def find_polars_all(collection, query, *, schema=None, **kwargs):
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
If the schema is not given, it will be inferred using the data in the
result set.
- `allow_invalid` (optional): If set to ``True``,
results will have all fields that do not conform to the schema silently converted to NaN.

Additional keyword-arguments passed to this method will be passed
directly to the underlying ``find`` operation.
Expand All @@ -346,10 +376,12 @@ def find_polars_all(collection, query, *, schema=None, **kwargs):

.. versionadded:: 1.3
"""
return _arrow_to_polars(find_arrow_all(collection, query, schema=schema, **kwargs))
return _arrow_to_polars(
find_arrow_all(collection, query, schema=schema, allow_invalid=allow_invalid, **kwargs)
)


def aggregate_polars_all(collection, pipeline, *, schema=None, **kwargs):
def aggregate_polars_all(collection, pipeline, *, schema=None, allow_invalid=False, **kwargs):
"""Method that returns the results of an aggregation pipeline as a
:class:`polars.DataFrame` instance.

Expand All @@ -360,14 +392,20 @@ def aggregate_polars_all(collection, pipeline, *, schema=None, **kwargs):
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
If the schema is not given, it will be inferred using the data in the
result set.
- `allow_invalid` (optional): If set to ``True``,
results will have all fields that do not conform to the schema silently converted to NaN.

Additional keyword-arguments passed to this method will be passed
directly to the underlying ``aggregate`` operation.

:Returns:
An instance of class:`polars.DataFrame`.
"""
return _arrow_to_polars(aggregate_arrow_all(collection, pipeline, schema=schema, **kwargs))
return _arrow_to_polars(
aggregate_arrow_all(
collection, pipeline, schema=schema, allow_invalid=allow_invalid, **kwargs
)
)


def _transform_bwe(bwe, offset):
Expand Down
8 changes: 6 additions & 2 deletions bindings/python/pymongoarrow/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@
class PyMongoArrowContext:
"""A context for converting BSON-formatted data to an Arrow Table."""

def __init__(self, schema, codec_options=None):
def __init__(self, schema, codec_options=None, allow_invalid=False):
"""Initialize the context.

:Parameters:
- `schema`: Instance of :class:`~pymongoarrow.schema.Schema`.
- `builder_map`: Mapping of utf-8-encoded field names to
:class:`~pymongoarrow.builders._BuilderBase` instances.
- `allow_invalid`: If set to ``True``,
results will have all fields that do not conform to the schema silently converted to NaN.
"""
self.schema = schema
if self.schema is None and codec_options is not None:
Expand All @@ -40,7 +42,9 @@ def __init__(self, schema, codec_options=None):
# Delayed import to prevent import errors for unbuilt library.
from pymongoarrow.lib import BuilderManager

self.manager = BuilderManager(schema_map, self.schema is not None, self.tzinfo)
self.manager = BuilderManager(
schema_map, self.schema is not None, self.tzinfo, allow_invalid=allow_invalid
)
self.schema_map = schema_map

def process_bson_stream(self, stream):
Expand Down
Loading
Loading