Skip to content

Commit bbdfef0

Browse files
authored
ARROW-175 fix Incorrect projection on nested fields (#165)
1 parent 3b8c388 commit bbdfef0

File tree

3 files changed

+45
-4
lines changed

3 files changed

+45
-4
lines changed

bindings/python/pymongoarrow/schema.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
import collections.abc as abc
1515

16+
from pyarrow import ListType, StructType
1617
from pymongoarrow.types import _normalize_typeid
1718

1819

@@ -62,10 +63,21 @@ def _normalize_mapping(mapping):
6263

6364
def _get_projection(self):
6465
projection = {"_id": False}
65-
for fname, _ in self.typemap.items():
66-
projection[fname] = True
66+
for fname, ftype in self.typemap.items():
67+
projection[fname] = self._get_field_projection_value(ftype)
6768
return projection
6869

70+
def _get_field_projection_value(self, ftype):
    """Return the projection value for a field whose Arrow type is ``ftype``.

    * ``ListType``: lists are transparent, so the result is the projection
      value of the list's element type (applied recursively for lists of
      lists).
    * ``StructType``: a sub-document keyed by child field name. Each child
      is resolved recursively, so a struct nested inside another struct is
      also narrowed to the fields the schema names instead of being
      requested whole.
    * Any other type: ``True``.
    """
    if isinstance(ftype, ListType):
        return self._get_field_projection_value(ftype.value_field.type)
    if isinstance(ftype, StructType):
        # Iterating a StructType yields its child fields; recurse on each
        # child's type so deeper structs/lists get their own sub-projection.
        return {
            child.name: self._get_field_projection_value(child.type)
            for child in ftype
        }
    return True
80+
6981
def __eq__(self, other):
7082
if isinstance(other, type(self)):
7183
return self.typemap == other.typemap

bindings/python/test/test_arrow.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,16 @@ def test_mixed_types_int64(self):
640640
with self.assertRaises(OverflowError):
641641
find_arrow_all(self.coll, {}, projection={"_id": 0}, schema=Schema({"a": int32()}))
642642

643+
def test_nested_contradicting_unused_schema(self):
    # "obj.b" holds an int in one document and a float in the other, and is
    # not named by the schema; only "obj.a" is expected in the output.
    schema = Schema({"obj": {"a": int32()}})
    documents = [
        {"obj": {"a": 1, "b": 1000000000}},
        {"obj": {"a": 2, "b": 1.0e50}},
    ]

    self.coll.drop()
    self.coll.insert_many(documents)

    # find takes a filter document, aggregate takes a pipeline list.
    cases = ((find_arrow_all, {}), (aggregate_arrow_all, []))
    for func, query in cases:
        out = func(self.coll, query, schema=schema)
        self.assertEqual(out["obj"].to_pylist(), [{"a": 1}, {"a": 2}])
652+
643653

644654
class TestArrowExplicitApi(ArrowApiTestMixin, unittest.TestCase):
645655
def run_find(self, *args, **kwargs):

bindings/python/test/test_schema.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
from unittest import TestCase
1616

1717
from bson import Binary, Code, Decimal128, Int64, ObjectId
18-
from pyarrow import Table, float64, int64
18+
from pyarrow import Table, field, float64, int64, list_
1919
from pyarrow import schema as ArrowSchema
20-
from pyarrow import timestamp
20+
from pyarrow import struct, timestamp
2121
from pymongoarrow.schema import Schema
2222
from pymongoarrow.types import _TYPE_NORMALIZER_FACTORY
2323

@@ -70,3 +70,22 @@ def test_from_bson_units(self):
7070
def test_from_arrow_units(self):
    # The typemap of a schema built from Arrow types should equal the
    # mapping of those same types.
    expected = {"field1": int64(), "field2": timestamp("s")}
    built = Schema({"field1": int64(), "field2": timestamp("s")})
    self.assertEqual(built.typemap, expected)
73+
74+
def test_nested_projection(self):
    # A nested mapping in the schema projects each of its fields individually.
    nested = {"a": int64(), "b": int64()}
    schema = Schema({"_id": int64(), "obj": nested})
    expected = {"_id": True, "obj": {"a": True, "b": True}}
    self.assertEqual(schema._get_projection(), expected)
77+
78+
def test_list_projection(self):
    # A list of structs projects the struct's fields directly under the
    # list field's name.
    element = struct([field("a", int64()), field("b", int64())])
    schema = Schema({"_id": int64(), "list": list_(element)})
    expected = {"_id": True, "list": {"a": True, "b": True}}
    self.assertEqual(schema._get_projection(), expected)
83+
84+
def test_list_of_list_projection(self):
    # Nested lists are looked through: the projection matches that of the
    # innermost struct.
    element = struct([field("a", int64()), field("b", int64())])
    schema = Schema({"_id": int64(), "list": list_(list_(element))})
    expected = {"_id": True, "list": {"a": True, "b": True}}
    self.assertEqual(schema._get_projection(), expected)

0 commit comments

Comments
 (0)