Skip to content

Commit ea92e15

Browse files
authored
PYTHON-2922 Add support for custom projections for find_arrow_all (#40)
1 parent a5fc0ba commit ea92e15

File tree

4 files changed

+32
-8
lines changed

4 files changed

+32
-8
lines changed

bindings/python/pymongoarrow/api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,13 +58,13 @@ def find_arrow_all(collection, query, *, schema, **kwargs):
5858
context = PyMongoArrowContext.from_schema(
5959
schema, codec_options=collection.codec_options)
6060

61-
for opt in ('session', 'cursor_type', 'projection'):
61+
for opt in ('session', 'cursor_type'):
6262
if kwargs.pop(opt, None):
6363
warnings.warn(
6464
f'Ignoring option {opt!r} as it is not supported by '
6565
'PyMongoArrow', UserWarning, stacklevel=2)
6666

67-
kwargs['projection'] = schema._get_projection()
67+
kwargs.setdefault('projection', schema._get_projection())
6868
raw_batch_cursor = collection.find_raw_batches(
6969
query, **kwargs)
7070
for batch in raw_batch_cursor:

bindings/python/test/test_arrow.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,21 @@ def test_find_simple(self):
7272
self.assertEqual(find_cmd.command['projection'],
7373
{'_id': True, 'data': True})
7474

75+
def test_find_projection(self):
76+
expected = Table.from_pydict(
77+
{'_id': [4, 3], 'data': [None, 60]},
78+
ArrowSchema([('_id', int32()), ('data', int64())]))
79+
projection = {'_id': True, 'data': {'$multiply': [2, '$data']}}
80+
table = self.run_find({'_id': {'$gt': 2}},
81+
schema=self.schema,
82+
sort=[('_id', DESCENDING)],
83+
projection=projection)
84+
self.assertEqual(table, expected)
85+
86+
find_cmd = self.cmd_listener.results['started'][-1]
87+
self.assertEqual(find_cmd.command_name, 'find')
88+
self.assertEqual(find_cmd.command['projection'], projection)
89+
7590
def test_find_multiple_batches(self):
7691
orig_method = self.coll.find_raw_batches
7792

@@ -107,13 +122,17 @@ def test_aggregate_simple(self):
107122
expected = Table.from_pydict(
108123
{'_id': [1, 2, 3, 4], 'data': [20, 40, 60, None]},
109124
ArrowSchema([('_id', int32()), ('data', int64())]))
125+
projection = {'_id': True, 'data': {'$multiply': [2, '$data']}}
110126
table = self.run_aggregate(
111-
[{'$project': {'_id': True, 'data': {'$multiply': [2, '$data']}}}],
127+
[{'$project': projection }],
112128
schema=self.schema)
113129
self.assertEqual(table, expected)
114130

115131
agg_cmd = self.cmd_listener.results['started'][-1]
116132
self.assertEqual(agg_cmd.command_name, 'aggregate')
133+
assert len(agg_cmd.command['pipeline']) == 2
134+
self.assertEqual(agg_cmd.command['pipeline'][0]['$project'],
135+
projection)
117136
self.assertEqual(agg_cmd.command['pipeline'][1]['$project'],
118137
{'_id': True, 'data': True})
119138

bindings/python/test/test_numpy.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,14 +81,16 @@ def test_aggregate_simple(self):
8181
expected = {
8282
'_id': np.array([1, 2, 3, 4], dtype=np.int32),
8383
'data': np.array([20, 40, 60, np.nan], dtype=np.float64)}
84-
84+
projection = {'_id': True, 'data': {'$multiply': [2, '$data']}}
8585
actual = aggregate_numpy_all(
86-
self.coll, [{'$project': {
87-
'_id': True, 'data': {'$multiply': [2, '$data']}}}],
86+
self.coll, [{'$project': projection}],
8887
schema=self.schema)
8988
self.assert_numpy_equal(actual, expected)
9089

9190
agg_cmd = self.cmd_listener.results['started'][-1]
9291
self.assertEqual(agg_cmd.command_name, 'aggregate')
92+
assert len(agg_cmd.command['pipeline']) == 2
93+
self.assertEqual(agg_cmd.command['pipeline'][0]['$project'],
94+
projection)
9395
self.assertEqual(agg_cmd.command['pipeline'][1]['$project'],
9496
{'_id': True, 'data': True})

bindings/python/test/test_pandas.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,16 @@ def test_aggregate_simple(self):
7171
expected = pd.DataFrame(
7272
data={'_id': [1, 2, 3, 4], 'data': [20, 40, 60, None]}).astype(
7373
{'_id': 'int32'})
74+
projection = {'_id': True, 'data': {'$multiply': [2, '$data']}}
7475
table = aggregate_pandas_all(
75-
self.coll, [{'$project': {
76-
'_id': True, 'data': {'$multiply': [2, '$data']}}}],
76+
self.coll, [{'$project': projection}],
7777
schema=self.schema)
7878
self.assertTrue(table.equals(expected))
7979

8080
agg_cmd = self.cmd_listener.results['started'][-1]
8181
self.assertEqual(agg_cmd.command_name, 'aggregate')
82+
assert len(agg_cmd.command['pipeline']) == 2
83+
self.assertEqual(agg_cmd.command['pipeline'][0]['$project'],
84+
projection)
8285
self.assertEqual(agg_cmd.command['pipeline'][1]['$project'],
8386
{'_id': True, 'data': True})

0 commit comments

Comments
 (0)