Skip to content

Commit a5b6291

Browse files
authored
ARROW-86 Basic auto-discovery of schemas (#88)
1 parent 225a9e4 commit a5b6291

File tree

4 files changed

+110
-16
lines changed

4 files changed

+110
-16
lines changed

bindings/python/pymongoarrow/api.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,15 @@
5858
_MAX_WRITE_BATCH_SIZE = max(100000, MAX_WRITE_BATCH_SIZE)
5959

6060

61-
def find_arrow_all(collection, query, *, schema, **kwargs):
61+
def find_arrow_all(collection, query, *, schema=None, **kwargs):
6262
"""Method that returns the results of a find query as a
6363
:class:`pyarrow.Table` instance.
6464
6565
:Parameters:
6666
- `collection`: Instance of :class:`~pymongo.collection.Collection`.
6767
against which to run the ``find`` operation.
6868
- `query`: A mapping containing the query to use for the find operation.
69-
- `schema`: Instance of :class:`~pymongoarrow.schema.Schema`.
69+
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
7070
7171
Additional keyword-arguments passed to this method will be passed
7272
directly to the underlying ``find`` operation.
@@ -84,23 +84,25 @@ def find_arrow_all(collection, query, *, schema, **kwargs):
8484
stacklevel=2,
8585
)
8686

87-
kwargs.setdefault("projection", schema._get_projection())
87+
if schema:
88+
kwargs.setdefault("projection", schema._get_projection())
89+
8890
raw_batch_cursor = collection.find_raw_batches(query, **kwargs)
8991
for batch in raw_batch_cursor:
9092
process_bson_stream(batch, context)
9193

9294
return context.finish()
9395

9496

95-
def aggregate_arrow_all(collection, pipeline, *, schema, **kwargs):
97+
def aggregate_arrow_all(collection, pipeline, *, schema=None, **kwargs):
9698
"""Method that returns the results of an aggregation pipeline as a
9799
:class:`pyarrow.Table` instance.
98100
99101
:Parameters:
100102
- `collection`: Instance of :class:`~pymongo.collection.Collection`.
101103
against which to run the ``aggregate`` operation.
102104
- `pipeline`: A list of aggregation pipeline stages.
103-
- `schema`: Instance of :class:`~pymongoarrow.schema.Schema`.
105+
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
104106
105107
Additional keyword-arguments passed to this method will be passed
106108
directly to the underlying ``aggregate`` operation.
@@ -143,15 +145,15 @@ def _arrow_to_pandas(arrow_table):
143145
return arrow_table.to_pandas(split_blocks=True, self_destruct=True)
144146

145147

146-
def find_pandas_all(collection, query, *, schema, **kwargs):
148+
def find_pandas_all(collection, query, *, schema=None, **kwargs):
147149
"""Method that returns the results of a find query as a
148150
:class:`pandas.DataFrame` instance.
149151
150152
:Parameters:
151153
- `collection`: Instance of :class:`~pymongo.collection.Collection`.
152154
against which to run the ``find`` operation.
153155
- `query`: A mapping containing the query to use for the find operation.
154-
- `schema`: Instance of :class:`~pymongoarrow.schema.Schema`.
156+
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
155157
156158
Additional keyword-arguments passed to this method will be passed
157159
directly to the underlying ``find`` operation.
@@ -162,15 +164,15 @@ def find_pandas_all(collection, query, *, schema, **kwargs):
162164
return _arrow_to_pandas(find_arrow_all(collection, query, schema=schema, **kwargs))
163165

164166

165-
def aggregate_pandas_all(collection, pipeline, *, schema, **kwargs):
167+
def aggregate_pandas_all(collection, pipeline, *, schema=None, **kwargs):
166168
"""Method that returns the results of an aggregation pipeline as a
167169
:class:`pandas.DataFrame` instance.
168170
169171
:Parameters:
170172
- `collection`: Instance of :class:`~pymongo.collection.Collection`.
171173
against which to run the ``find`` operation.
172174
- `pipeline`: A list of aggregation pipeline stages.
173-
- `schema`: Instance of :class:`~pymongoarrow.schema.Schema`.
175+
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
174176
175177
Additional keyword-arguments passed to this method will be passed
176178
directly to the underlying ``aggregate`` operation.
@@ -181,7 +183,7 @@ def aggregate_pandas_all(collection, pipeline, *, schema, **kwargs):
181183
return _arrow_to_pandas(aggregate_arrow_all(collection, pipeline, schema=schema, **kwargs))
182184

183185

184-
def _arrow_to_numpy(arrow_table, schema):
186+
def _arrow_to_numpy(arrow_table, schema=None):
185187
"""Helper function that converts an Arrow Table to a dictionary
186188
containing NumPy arrays. The memory buffers backing the given Arrow Table
187189
may be destroyed after conversion if the resulting Numpy array(s) is not a
@@ -190,6 +192,9 @@ def _arrow_to_numpy(arrow_table, schema):
190192
See https://arrow.apache.org/docs/python/numpy.html for details.
191193
"""
192194
container = {}
195+
if not schema:
196+
schema = arrow_table.schema
197+
193198
for fname in schema:
194199
dtype = get_numpy_type(schema.typemap[fname])
195200
if dtype == np.str_:
@@ -199,7 +204,7 @@ def _arrow_to_numpy(arrow_table, schema):
199204
return container
200205

201206

202-
def find_numpy_all(collection, query, *, schema, **kwargs):
207+
def find_numpy_all(collection, query, *, schema=None, **kwargs):
203208
"""Method that returns the results of a find query as a
204209
:class:`dict` instance whose keys are field names and values are
205210
:class:`~numpy.ndarray` instances bearing the appropriate dtype.
@@ -208,7 +213,7 @@ def find_numpy_all(collection, query, *, schema, **kwargs):
208213
- `collection`: Instance of :class:`~pymongo.collection.Collection`.
209214
against which to run the ``find`` operation.
210215
- `query`: A mapping containing the query to use for the find operation.
211-
- `schema`: Instance of :class:`~pymongoarrow.schema.Schema`.
216+
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
212217
213218
Additional keyword-arguments passed to this method will be passed
214219
directly to the underlying ``find`` operation.
@@ -228,7 +233,7 @@ def find_numpy_all(collection, query, *, schema, **kwargs):
228233
return _arrow_to_numpy(find_arrow_all(collection, query, schema=schema, **kwargs), schema)
229234

230235

231-
def aggregate_numpy_all(collection, pipeline, *, schema, **kwargs):
236+
def aggregate_numpy_all(collection, pipeline, *, schema=None, **kwargs):
232237
"""Method that returns the results of an aggregation pipeline as a
233238
:class:`dict` instance whose keys are field names and values are
234239
:class:`~numpy.ndarray` instances bearing the appropriate dtype.
@@ -237,7 +242,7 @@ def aggregate_numpy_all(collection, pipeline, *, schema, **kwargs):
237242
- `collection`: Instance of :class:`~pymongo.collection.Collection`.
238243
against which to run the ``find`` operation.
239244
- `query`: A mapping containing the query to use for the find operation.
240-
- `schema`: Instance of :class:`~pymongoarrow.schema.Schema`.
245+
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
241246
242247
Additional keyword-arguments passed to this method will be passed
243248
directly to the underlying ``aggregate`` operation.

bindings/python/pymongoarrow/context.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
class PyMongoArrowContext:
4040
"""A context for converting BSON-formatted data to an Arrow Table."""
4141

42-
def __init__(self, schema, builder_map):
42+
def __init__(self, schema, builder_map, codec_options=None):
4343
"""Initialize the context.
4444
4545
:Parameters:
@@ -49,6 +49,10 @@ def __init__(self, schema, builder_map):
4949
"""
5050
self.schema = schema
5151
self.builder_map = builder_map
52+
if self.schema is None and codec_options is not None:
53+
self.tzinfo = codec_options.tzinfo
54+
else:
55+
self.tzinfo = None
5256

5357
@classmethod
5458
def from_schema(cls, schema, codec_options=DEFAULT_CODEC_OPTIONS):
@@ -60,6 +64,9 @@ def from_schema(cls, schema, codec_options=DEFAULT_CODEC_OPTIONS):
6064
- `codec_options` (optional): An instance of
6165
:class:`~bson.codec_options.CodecOptions`.
6266
"""
67+
if schema is None:
68+
return cls(schema, {})
69+
6370
builder_map = {}
6471
str_type_map = _get_internal_typemap(schema.typemap)
6572
for fname, ftype in str_type_map.items():

bindings/python/pymongoarrow/lib.pyx

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,15 @@ cdef const bson_t* bson_reader_read_safe(bson_reader_t* stream_reader) except? N
5353
raise InvalidBSON("Could not read BSON document stream")
5454
return doc
5555

56+
_builder_type_map = {
57+
0x10: Int32Builder,
58+
0x12: Int64Builder,
59+
0x01: DoubleBuilder,
60+
0x09: DatetimeBuilder,
61+
0x07: ObjectIdBuilder,
62+
0x02: StringBuilder,
63+
0x08: BoolBuilder
64+
}
5665

5766
def process_bson_stream(bson_stream, context):
5867
cdef const uint8_t* docstream = <const uint8_t *>bson_stream
@@ -94,6 +103,21 @@ def process_bson_stream(bson_stream, context):
94103
while bson_iter_next(&doc_iter):
95104
key = bson_iter_key(&doc_iter)
96105
builder = builder_map.get(key)
106+
if builder is None and context.schema is None:
107+
# Only run if there is no schema.
108+
ftype = bson_iter_type(&doc_iter)
109+
if ftype not in _builder_type_map:
110+
continue
111+
112+
builder_type = _builder_type_map[ftype]
113+
if builder_type == DatetimeBuilder and context.tzinfo is not None:
114+
arrow_type = timestamp(arrow_type.unit, tz=context.tzinfo)
115+
builder_map[key] = builder_type(dtype=arrow_type)
116+
else:
117+
builder_map[key] = builder_type()
118+
builder = builder_map[key]
119+
for _ in range(count):
120+
builder.append_null()
97121
if builder is not None:
98122
ftype = builder.type_marker
99123
value_t = bson_iter_type(&doc_iter)

bindings/python/test/test_arrow.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@
1414
import os
1515
import unittest
1616
import unittest.mock as mock
17+
from datetime import datetime
1718
from test import client_context
1819
from test.utils import AllowListEventListener, TestNullsBase
1920

2021
import pyarrow
2122
import pymongo
22-
from bson import Decimal128, ObjectId
23+
from bson import CodecOptions, Decimal128, ObjectId
2324
from pyarrow import Table, binary, bool_, decimal256, float64, int32, int64
2425
from pyarrow import schema as ArrowSchema
2526
from pyarrow import string, timestamp
@@ -34,6 +35,7 @@
3435
Decimal128StringType,
3536
ObjectIdType,
3637
)
38+
from pytz import timezone
3739

3840

3941
class TestArrowApiMixin:
@@ -326,6 +328,62 @@ def test_string_bool(self):
326328
),
327329
)
328330

331+
def test_auto_schema(self):
332+
# Create table with random data of various types.
333+
data = Table.from_pydict(
334+
{
335+
"string": [None] + [str(i) for i in range(2)],
336+
"bool": [True for _ in range(3)],
337+
"dt": [datetime(1970 + i, 1, 1) for i in range(3)],
338+
},
339+
ArrowSchema(
340+
{
341+
"bool": bool_(),
342+
"dt": timestamp("ms"),
343+
"string": string(),
344+
}
345+
),
346+
)
347+
348+
self.coll.drop()
349+
res = write(self.coll, data)
350+
self.assertEqual(len(data), res.raw_result["insertedCount"])
351+
out = find_arrow_all(self.coll, {}).drop(["_id"])
352+
self.assertEqual(data, out)
353+
354+
def test_auto_schema_heterogeneous(self):
355+
vals = [1, "2", True, 4]
356+
data = [{"a": v} for v in vals]
357+
358+
self.coll.drop()
359+
self.coll.insert_many(data)
360+
out = find_arrow_all(self.coll, {}).drop(["_id"])
361+
self.assertEqual(out["a"].to_pylist(), [1, None, None, 4])
362+
363+
def test_auto_schema_tz(self):
364+
# Create table with random data of various types.
365+
data = Table.from_pydict(
366+
{
367+
"bool": [True for _ in range(3)],
368+
"dt": [datetime(1970 + i, 1, 1, tzinfo=timezone("US/Eastern")) for i in range(3)],
369+
"string": [None] + [str(i) for i in range(2)],
370+
},
371+
ArrowSchema(
372+
{
373+
"bool": bool_(),
374+
"dt": timestamp("ms"),
375+
"string": string(),
376+
}
377+
),
378+
)
379+
380+
self.coll.drop()
381+
codec_options = CodecOptions(tzinfo=timezone("US/Eastern"), tz_aware=True)
382+
res = write(self.coll.with_options(codec_options=codec_options), data)
383+
self.assertEqual(len(data), res.raw_result["insertedCount"])
384+
out = find_arrow_all(self.coll.with_options(codec_options=codec_options), {}).drop(["_id"])
385+
self.assertEqual(data, out)
386+
329387

330388
class TestArrowExplicitApi(TestArrowApiMixin, unittest.TestCase):
331389
def run_find(self, *args, **kwargs):

0 commit comments

Comments
 (0)