Skip to content

Commit 1cde144

Browse files
authored
ARROW-176 Nested extension objects are not handled in auto schema (#166)
1 parent bbdfef0 commit 1cde144

File tree

2 files changed

+50
-5
lines changed

2 files changed

+50
-5
lines changed

bindings/python/pymongoarrow/lib.pyx

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,6 @@ cdef class ObjectIdBuilder(_ArrayBuilderBase):
490490
cdef shared_ptr[CFixedSizeBinaryBuilder] unwrap(self):
491491
return self.builder
492492

493-
494493
cdef class Int32Builder(_ArrayBuilderBase):
495494
cdef:
496495
shared_ptr[CInt32Builder] builder
@@ -722,6 +721,8 @@ cdef object get_field_builder(object field, object tzinfo):
722721
field_builder = Decimal128Builder()
723722
elif getattr(field_type, '_type_marker') == _BsonArrowTypes.binary:
724723
field_builder = BinaryBuilder(field_type.subtype)
724+
elif getattr(field_type, '_type_marker') == _BsonArrowTypes.code:
725+
field_builder = CodeBuilder()
725726
else:
726727
field_builder = StringBuilder()
727728
return field_builder
@@ -732,6 +733,7 @@ cdef class DocumentBuilder(_ArrayBuilderBase):
732733
shared_ptr[CStructBuilder] builder
733734
object dtype
734735
object context
736+
object builder_map
735737

736738
def __cinit__(self, StructType dtype, tzinfo=None, MemoryPool memory_pool=None):
737739
cdef StringBuilder field_builder
@@ -744,11 +746,11 @@ cdef class DocumentBuilder(_ArrayBuilderBase):
744746

745747
self.context = context = PyMongoArrowContext(None, {})
746748
context.tzinfo = tzinfo
747-
builder_map = context.builder_map
749+
self.builder_map = context.builder_map
748750

749751
for field in dtype:
750752
field_builder = <StringBuilder>get_field_builder(field, tzinfo)
751-
builder_map[field.name.encode('utf-8')] = field_builder
753+
self.builder_map[field.name.encode('utf-8')] = field_builder
752754
c_field_builders.push_back(<shared_ptr[CArrayBuilder]>field_builder.builder)
753755

754756
self.builder.reset(new CStructBuilder(pyarrow_unwrap_data_type(dtype), pool, c_field_builders))
@@ -781,7 +783,30 @@ cdef class DocumentBuilder(_ArrayBuilderBase):
781783
cdef shared_ptr[CArray] out
782784
with nogil:
783785
self.builder.get().Finish(&out)
784-
return pyarrow_wrap_array(out)
786+
787+
struct_array = pyarrow_wrap_array(out)
788+
for struct_def in struct_array:
789+
new_types = []
790+
new_names = list(struct_def.keys())
791+
for fname, ftype in struct_def.items():
792+
builder_instance = self.builder_map[fname.encode('utf-8')]
793+
if isinstance(builder_instance, ObjectIdBuilder): # ObjectIdType
794+
new_ftype = ObjectIdType()
795+
new_types.append(new_ftype)
796+
elif isinstance(builder_instance, Decimal128Builder): # Decimal128Type
797+
new_ftype = Decimal128Type_()
798+
new_types.append(new_ftype)
799+
elif isinstance(builder_instance, BinaryBuilder): # BinaryType
800+
new_ftype = BinaryType(self.dtype.field(fname).type.subtype)
801+
new_types.append(new_ftype)
802+
elif isinstance(builder_instance, CodeBuilder): # CodeType
803+
new_ftype = CodeType()
804+
new_types.append(new_ftype)
805+
else:
806+
new_types.append(ftype.type)
807+
808+
new_dtype = struct(zip(new_names, new_types))
809+
return struct_array.cast(new_dtype)
785810

786811
cdef shared_ptr[CStructBuilder] unwrap(self):
787812
return self.builder

bindings/python/test/test_arrow.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
import pyarrow
2323
import pymongo
24-
from bson import Binary, CodecOptions, Decimal128, ObjectId
24+
from bson import Binary, Code, CodecOptions, Decimal128, ObjectId
2525
from pyarrow import Table, bool_, csv, decimal256, field, int32, int64, list_
2626
from pyarrow import schema as ArrowSchema
2727
from pyarrow import string, struct, timestamp
@@ -650,6 +650,26 @@ def test_nested_contradicting_unused_schema(self):
650650
out = func(self.coll, {} if func == find_arrow_all else [], schema=schema)
651651
self.assertEqual(out["obj"].to_pylist(), [{"a": 1}, {"a": 2}])
652652

653+
def test_nested_bson_extension_types(self):
654+
data = {
655+
"obj": {
656+
"obj_id": ObjectId(),
657+
"dec_128": Decimal128("0.0005"),
658+
"binary": Binary(b"123"),
659+
"code": Code(""),
660+
}
661+
}
662+
663+
self.coll.drop()
664+
self.coll.insert_one(data)
665+
out = find_arrow_all(self.coll, {})
666+
obj_schema_type = out.field("obj").type
667+
668+
self.assertIsInstance(obj_schema_type.field("obj_id").type, ObjectIdType)
669+
self.assertIsInstance(obj_schema_type.field("dec_128").type, Decimal128Type)
670+
self.assertIsInstance(obj_schema_type.field("binary").type, BinaryType)
671+
self.assertIsInstance(obj_schema_type.field("code").type, CodeType)
672+
653673

654674
class TestArrowExplicitApi(ArrowApiTestMixin, unittest.TestCase):
655675
def run_find(self, *args, **kwargs):

0 commit comments

Comments
 (0)