@@ -490,7 +490,6 @@ cdef class ObjectIdBuilder(_ArrayBuilderBase):
490
490
cdef shared_ptr[CFixedSizeBinaryBuilder] unwrap(self ):
491
491
return self .builder
492
492
493
-
494
493
cdef class Int32Builder(_ArrayBuilderBase):
495
494
cdef:
496
495
shared_ptr[CInt32Builder] builder
@@ -722,6 +721,8 @@ cdef object get_field_builder(object field, object tzinfo):
722
721
field_builder = Decimal128Builder()
723
722
elif getattr (field_type, ' _type_marker' ) == _BsonArrowTypes.binary:
724
723
field_builder = BinaryBuilder(field_type.subtype)
724
+ elif getattr (field_type, ' _type_marker' ) == _BsonArrowTypes.code:
725
+ field_builder = CodeBuilder()
725
726
else :
726
727
field_builder = StringBuilder()
727
728
return field_builder
@@ -732,6 +733,7 @@ cdef class DocumentBuilder(_ArrayBuilderBase):
732
733
shared_ptr[CStructBuilder] builder
733
734
object dtype
734
735
object context
736
+ object builder_map
735
737
736
738
def __cinit__ (self , StructType dtype , tzinfo = None , MemoryPool memory_pool = None ):
737
739
cdef StringBuilder field_builder
@@ -744,11 +746,11 @@ cdef class DocumentBuilder(_ArrayBuilderBase):
744
746
745
747
self .context = context = PyMongoArrowContext(None , {})
746
748
context.tzinfo = tzinfo
747
- builder_map = context.builder_map
749
+ self . builder_map = context.builder_map
748
750
749
751
for field in dtype:
750
752
field_builder = < StringBuilder> get_field_builder(field, tzinfo)
751
- builder_map[field.name.encode(' utf-8' )] = field_builder
753
+ self . builder_map[field.name.encode(' utf-8' )] = field_builder
752
754
c_field_builders.push_back(< shared_ptr[CArrayBuilder]> field_builder.builder)
753
755
754
756
self .builder.reset(new CStructBuilder(pyarrow_unwrap_data_type(dtype), pool, c_field_builders))
@@ -781,7 +783,30 @@ cdef class DocumentBuilder(_ArrayBuilderBase):
781
783
cdef shared_ptr[CArray] out
782
784
with nogil:
783
785
self .builder.get().Finish(& out)
784
- return pyarrow_wrap_array(out)
786
+
787
+ struct_array = pyarrow_wrap_array(out)
788
+ for struct_def in struct_array:
789
+ new_types = []
790
+ new_names = list (struct_def.keys())
791
+ for fname, ftype in struct_def.items():
792
+ builder_instance = self .builder_map[fname.encode(' utf-8' )]
793
+ if isinstance (builder_instance, ObjectIdBuilder): # ObjectIdType
794
+ new_ftype = ObjectIdType()
795
+ new_types.append(new_ftype)
796
+ elif isinstance (builder_instance, Decimal128Builder): # Decimal128Type
797
+ new_ftype = Decimal128Type_()
798
+ new_types.append(new_ftype)
799
+ elif isinstance (builder_instance, BinaryBuilder): # BinaryType
800
+ new_ftype = BinaryType(self .dtype.field(fname).type.subtype)
801
+ new_types.append(new_ftype)
802
+ elif isinstance (builder_instance, CodeBuilder): # CodeType
803
+ new_ftype = CodeType()
804
+ new_types.append(new_ftype)
805
+ else :
806
+ new_types.append(ftype.type)
807
+
808
+ new_dtype = struct (zip (new_names, new_types))
809
+ return struct_array.cast(new_dtype)
785
810
786
811
cdef shared_ptr[CStructBuilder] unwrap(self ):
787
812
return self .builder
0 commit comments