Skip to content

Commit 9ec5ca8

Browse files
authored
ARROW-117 "process_bson_stream" Does Not Construct Projection With Auto-Discovered Schema (#92)
1 parent d1bb5eb commit 9ec5ca8

File tree

4 files changed

+141
-11
lines changed

4 files changed

+141
-11
lines changed

bindings/python/pymongoarrow/api.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,9 @@ def aggregate_arrow_all(collection, pipeline, *, schema=None, **kwargs):
129129
UserWarning,
130130
stacklevel=2,
131131
)
132+
if schema:
133+
pipeline.append({"$project": schema._get_projection()})
132134

133-
pipeline.append({"$project": schema._get_projection()})
134135
raw_batch_cursor = collection.aggregate_raw_batches(pipeline, **kwargs)
135136
for batch in raw_batch_cursor:
136137
process_bson_stream(batch, context)
@@ -201,10 +202,12 @@ def _arrow_to_numpy(arrow_table, schema=None):
201202
"""
202203
container = {}
203204
if not schema:
204-
schema = arrow_table.schema
205+
schema = {i.name: i.type for i in arrow_table.schema}
206+
else:
207+
schema = schema.typemap
205208

206209
for fname in schema:
207-
dtype = get_numpy_type(schema.typemap[fname])
210+
dtype = get_numpy_type(schema[fname])
208211
if dtype == np.str_:
209212
container[fname] = arrow_table[fname].to_pandas().to_numpy(dtype=dtype)
210213
else:

bindings/python/test/test_arrow.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -348,17 +348,19 @@ def test_auto_schema(self):
348348
self.coll.drop()
349349
res = write(self.coll, data)
350350
self.assertEqual(len(data), res.raw_result["insertedCount"])
351-
out = find_arrow_all(self.coll, {}).drop(["_id"])
352-
self.assertEqual(data, out)
351+
for func in [find_arrow_all, aggregate_arrow_all]:
352+
out = func(self.coll, {} if func == find_arrow_all else []).drop(["_id"])
353+
self.assertEqual(data, out)
353354

354355
def test_auto_schema_heterogeneous(self):
355356
vals = [1, "2", True, 4]
356357
data = [{"a": v} for v in vals]
357358

358359
self.coll.drop()
359360
self.coll.insert_many(data)
360-
out = find_arrow_all(self.coll, {}).drop(["_id"])
361-
self.assertEqual(out["a"].to_pylist(), [1, None, None, 4])
361+
for func in [find_arrow_all, aggregate_arrow_all]:
362+
out = func(self.coll, {} if func == find_arrow_all else []).drop(["_id"])
363+
self.assertEqual(out["a"].to_pylist(), [1, None, None, 4])
362364

363365
def test_auto_schema_tz(self):
364366
# Create table with random data of various types.
@@ -381,8 +383,12 @@ def test_auto_schema_tz(self):
381383
codec_options = CodecOptions(tzinfo=timezone("US/Eastern"), tz_aware=True)
382384
res = write(self.coll.with_options(codec_options=codec_options), data)
383385
self.assertEqual(len(data), res.raw_result["insertedCount"])
384-
out = find_arrow_all(self.coll.with_options(codec_options=codec_options), {}).drop(["_id"])
385-
self.assertEqual(data, out)
386+
for func in [find_arrow_all, aggregate_arrow_all]:
387+
out = func(
388+
self.coll.with_options(codec_options=codec_options),
389+
{} if func == find_arrow_all else [],
390+
).drop(["_id"])
391+
self.assertEqual(data, out)
386392

387393

388394
class TestArrowExplicitApi(TestArrowApiMixin, unittest.TestCase):

bindings/python/test/test_numpy.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from unittest import mock
2020

2121
import numpy as np
22-
from bson import Decimal128, ObjectId
22+
from bson import CodecOptions, Decimal128, ObjectId
2323
from pyarrow import int32, int64
2424
from pymongo import DESCENDING, WriteConcern
2525
from pymongo.collection import Collection
@@ -30,6 +30,7 @@
3030
Decimal128StringType,
3131
ObjectIdType,
3232
)
33+
from pytz import timezone
3334

3435

3536
class NumpyTestBase(unittest.TestCase):
@@ -210,6 +211,67 @@ def test_string_bool(self):
210211
),
211212
)
212213

214+
def test_auto_schema(self):
215+
schema = {
216+
"bool": "bool",
217+
"dt": "datetime64[ms]",
218+
"string": "str",
219+
}
220+
data = {
221+
"string": [None] + [str(i) for i in range(2)],
222+
"bool": [True for _ in range(3)],
223+
"dt": [datetime.datetime(1970 + i, 1, 1) for i in range(3)],
224+
}
225+
data = self.schemafied_ndarray_dict(data, schema)
226+
227+
self.coll.drop()
228+
res = write(self.coll, data)
229+
self.assertEqual(len(data), res.raw_result["insertedCount"])
230+
for func in [find_numpy_all, aggregate_numpy_all]:
231+
with self.subTest(func.__name__):
232+
out = func(self.coll, {} if func == find_numpy_all else [])
233+
del out["_id"]
234+
self.assert_numpy_equal(data, out)
235+
236+
def test_auto_schema_heterogeneous(self):
237+
vals = [1, "2", True, 4]
238+
data = [{"a": v} for v in vals]
239+
240+
self.coll.drop()
241+
self.coll.insert_many(data)
242+
for func in [find_numpy_all, aggregate_numpy_all]:
243+
with self.subTest(func.__name__):
244+
out = func(self.coll, {} if func == find_numpy_all else [])
245+
del out["_id"]
246+
np.equal(out, [1.0, np.nan, np.nan, 4.0])
247+
248+
def test_auto_schema_tz(self):
249+
schema = {
250+
"bool": "bool",
251+
"dt": "datetime64[ms]",
252+
"string": "str",
253+
}
254+
data = {
255+
"string": [str(i) for i in range(3)],
256+
"bool": [True for _ in range(3)],
257+
"dt": [
258+
datetime.datetime(1970 + i, 1, 1, tzinfo=timezone("US/Eastern")) for i in range(3)
259+
],
260+
}
261+
data = self.schemafied_ndarray_dict(data, schema)
262+
self.coll.drop()
263+
codec_options = CodecOptions(tzinfo=timezone("US/Eastern"), tz_aware=True)
264+
res = write(self.coll.with_options(codec_options=codec_options), data)
265+
self.assertEqual(len(data), res.raw_result["insertedCount"])
266+
for func in [find_numpy_all, aggregate_numpy_all]:
267+
with self.subTest(func.__name__):
268+
out = func(
269+
self.coll.with_options(codec_options=codec_options),
270+
{} if func == find_numpy_all else [],
271+
)
272+
del out["_id"]
273+
self.assert_numpy_equal(data, out)
274+
213275

214276
class TestBSONTypes(NumpyTestBase):
215277
@classmethod

bindings/python/test/test_pandas.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import pandas as pd
2323
import pandas.testing
2424
import pyarrow
25-
from bson import Decimal128, ObjectId
25+
from bson import CodecOptions, Decimal128, ObjectId
2626
from pyarrow import decimal256, int32, int64
2727
from pymongo import DESCENDING, WriteConcern
2828
from pymongo.collection import Collection
@@ -33,6 +33,7 @@
3333
Decimal128StringType,
3434
ObjectIdType,
3535
)
36+
from pytz import timezone
3637

3738

3839
class PandasTestBase(unittest.TestCase):
@@ -188,6 +189,64 @@ def test_string_bool(self):
188189
),
189190
)
190191

192+
def test_auto_schema(self):
193+
schema = {
194+
"bool": "bool",
195+
"dt": "datetime64[ns]",
196+
"string": "str",
197+
}
198+
data = pd.DataFrame(
199+
data={
200+
"string": [None] + [str(i) for i in range(2)],
201+
"bool": [True for _ in range(3)],
202+
"dt": [datetime.datetime(1970 + i, 1, 1) for i in range(3)],
203+
},
204+
).astype(schema)
205+
206+
self.coll.drop()
207+
res = write(self.coll, data)
208+
self.assertEqual(len(data), res.raw_result["insertedCount"])
209+
for func in [find_pandas_all, aggregate_pandas_all]:
210+
out = func(self.coll, {} if func == find_pandas_all else []).drop(columns=["_id"])
211+
pd.testing.assert_frame_equal(data, out)
212+
213+
def test_auto_schema_heterogeneous(self):
214+
vals = [1, "2", True, 4]
215+
data = [{"a": v} for v in vals]
216+
217+
self.coll.drop()
218+
self.coll.insert_many(data)
219+
for func in [find_pandas_all, aggregate_pandas_all]:
220+
out = func(self.coll, {} if func == find_pandas_all else []).drop(columns=["_id"])
221+
np.equal(out["a"], [1.0, np.nan, np.nan, 4.0])
222+
223+
def test_auto_schema_tz(self):
224+
schema = {
225+
"bool": "bool",
226+
"dt": "datetime64[ns]",
227+
"string": "str",
228+
}
229+
data = pd.DataFrame(
230+
data={
231+
"string": [None] + [str(i) for i in range(2)],
232+
"bool": [True for _ in range(3)],
233+
"dt": [
234+
datetime.datetime(1970 + i, 1, 1, tzinfo=timezone("US/Eastern"))
235+
for i in range(3)
236+
],
237+
},
238+
).astype(schema)
239+
self.coll.drop()
240+
codec_options = CodecOptions(tzinfo=timezone("US/Eastern"), tz_aware=True)
241+
res = write(self.coll.with_options(codec_options=codec_options), data)
242+
self.assertEqual(len(data), res.raw_result["insertedCount"])
243+
for func in [find_pandas_all, aggregate_pandas_all]:
244+
out = func(
245+
self.coll.with_options(codec_options=codec_options),
246+
{} if func == find_pandas_all else [],
247+
).drop(columns=["_id"])
248+
pd.testing.assert_frame_equal(data, out)
249+
191250

192251
class TestBSONTypes(PandasTestBase):
193252
@classmethod

0 commit comments

Comments
 (0)