Skip to content

Commit ed3a161

Browse files
committed
further progress on buildermanager and context
1 parent 210dd6c commit ed3a161

File tree

5 files changed

+130
-99
lines changed

5 files changed

+130
-99
lines changed

bindings/python/pymongoarrow/context.py

Lines changed: 27 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from pyarrow import ListArray, StructArray, Table
15-
from pyarrow.types import is_struct
1615

1716
from pymongoarrow.types import _BsonArrowTypes, _get_internal_typemap
1817

@@ -47,45 +46,45 @@ def process_bson_stream(self, stream):
4746
self.manager.process_bson_stream(stream, len(stream))
4847

4948
def finish(self):
50-
array_map = _parse_array_map(self.manager.finish())
49+
array_map = _parse_builder_map(self.manager.finish())
5150
arrays = list(array_map.values())
5251
if self.schema is not None:
5352
return Table.from_arrays(arrays=arrays, schema=self.schema.to_arrow())
5453
return Table.from_arrays(arrays=arrays, names=list(array_map.keys()))
5554

5655

57-
def _parse_array_map(array_map):
56+
def _parse_builder_map(builder_map):
5857
# Handle nested builders.
5958
to_remove = []
6059
# Traverse the builder map right to left.
61-
for key, value in reversed(array_map.items()):
62-
field = key.decode("utf-8")
63-
if value.type_marker == _BsonArrowTypes.document:
64-
full_names = [f"{field}.{name.decode('utf-8')}" for name in value]
65-
arrs = [array_map[c.encode("utf-8")] for c in full_names]
66-
array_map[field] = StructArray.from_arrays(arrs, names=value)
60+
for key, value in reversed(builder_map.items()):
61+
if value.type_marker == _BsonArrowTypes.document.value:
62+
names = value.finish()
63+
full_names = [f"{key}.{name}" for name in names]
64+
arrs = [builder_map[c] for c in full_names]
65+
builder_map[key] = StructArray.from_arrays(arrs, names=names)
6766
to_remove.extend(full_names)
68-
elif value.type_marker == _BsonArrowTypes.array:
69-
child_name = field + "[]"
67+
elif value.type_marker == _BsonArrowTypes.array.value:
68+
child_name = key + "[]"
7069
to_remove.append(child_name)
71-
child = array_map[child_name.encode("utf-8")]
72-
array_map[key] = ListArray.from_arrays(value, child)
70+
child = builder_map[child_name]
71+
builder_map[key] = ListArray.from_arrays(value.finish(), child)
72+
else:
73+
builder_map[key] = value.finish()
7374

74-
for field in to_remove:
75-
key = field.encode("utf-8")
76-
if key in array_map:
77-
del array_map[key]
75+
for key in to_remove:
76+
if key in builder_map:
77+
del builder_map[key]
7878

79-
return array_map
79+
return builder_map
8080

8181

8282
def _parse_types(str_type_map, schema_map, tzinfo):
8383
for fname, (ftype, arrow_type) in str_type_map.items():
84-
encoded_fname = fname.encode("utf-8")
85-
schema_map[encoded_fname] = ftype, arrow_type
84+
schema_map[fname] = ftype, arrow_type
8685

8786
# special-case nested builders
88-
if ftype == _BsonArrowTypes.document:
87+
if ftype == _BsonArrowTypes.document.value:
8988
# construct a sub type map here
9089
sub_type_map = {}
9190
for i in range(arrow_type.num_fields):
@@ -94,13 +93,10 @@ def _parse_types(str_type_map, schema_map, tzinfo):
9493
sub_type_map[sub_name] = field.type
9594
sub_type_map = _get_internal_typemap(sub_type_map)
9695
_parse_types(sub_type_map, schema_map, tzinfo)
97-
elif ftype == _BsonArrowTypes.array:
98-
if is_struct(arrow_type.value_type):
99-
# construct a sub type map here
100-
sub_type_map = {}
101-
for i in range(arrow_type.value_type.num_fields):
102-
field = arrow_type.value_type[i]
103-
sub_name = f"{fname}[].{field.name}"
104-
sub_type_map[sub_name] = field.type
105-
sub_type_map = _get_internal_typemap(sub_type_map)
106-
_parse_types(sub_type_map, schema_map, tzinfo)
96+
elif ftype == _BsonArrowTypes.array.value:
97+
sub_type_map = {}
98+
sub_name = f"{fname}[]"
99+
sub_value_type = arrow_type.value_type
100+
sub_type_map[sub_name] = sub_value_type
101+
sub_type_map = _get_internal_typemap(sub_type_map)
102+
_parse_types(sub_type_map, schema_map, tzinfo)

bindings/python/pymongoarrow/lib.pyx

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -67,22 +67,22 @@ cdef class BuilderManager:
6767
self.has_schema = has_schema
6868
self.tzinfo = tzinfo
6969
self.count = 0
70-
self.builder_map = builder_map = {}
70+
self.builder_map = {}
7171
# Unpack the schema map.
7272
for fname, (ftype, arrow_type) in schema_map.items():
73-
encoded_fname = fname.encode("utf-8")
73+
name = fname.encode('utf-8')
7474
# special-case initializing builders for parameterized types
7575
if ftype == BSON_TYPE_DATE_TIME:
7676
if tzinfo is not None and arrow_type.tz is None:
7777
arrow_type = timestamp(arrow_type.unit, tz=tzinfo) # noqa: PLW2901
78-
builder_map[encoded_fname] = DatetimeBuilder(dtype=arrow_type)
78+
self.builder_map[name] = DatetimeBuilder(dtype=arrow_type)
7979
elif ftype == BSON_TYPE_BINARY:
80-
subtype = arrow_type.subtype
81-
builder_map[encoded_fname] = BinaryBuilder(subtype)
80+
self.builder_map[name] = BinaryBuilder(arrow_type.subtype)
8281
else:
83-
self.get_builder(encoded_fname, ftype, <bson_iter_t *>nullptr)
82+
# We only use the doc_iter for binary arrays, which are handled already.
83+
self.get_builder(name, ftype, <bson_iter_t *>nullptr)
8484

85-
cdef _ArrayBuilderBase get_builder(self, cstring key, bson_type_t value_t, bson_iter_t * doc_iter):
85+
cdef _ArrayBuilderBase get_builder(self, cstring key, bson_type_t value_t, bson_iter_t * doc_iter) except *:
8686
cdef _ArrayBuilderBase builder = None
8787
cdef bson_subtype_t subtype
8888
cdef const uint8_t *val_buf = NULL
@@ -108,6 +108,8 @@ cdef class BuilderManager:
108108
elif value_t == BSON_TYPE_ARRAY:
109109
builder = ListBuilder()
110110
elif value_t == BSON_TYPE_BINARY:
111+
if doc_iter == NULL:
112+
raise ValueError('Did not pass a doc_iter!')
111113
bson_iter_binary (doc_iter, &subtype,
112114
&val_buf_len, &val_buf)
113115
builder = BinaryBuilder(subtype)
@@ -147,7 +149,6 @@ cdef class BuilderManager:
147149
# Get the key and and value.
148150
key = bson_iter_key(doc_iter)
149151
value_t = bson_iter_type(doc_iter)
150-
print('handling', key, value_t)
151152

152153
# Get the appropriate full key.
153154
if parent_type == BSON_TYPE_ARRAY:
@@ -163,6 +164,8 @@ cdef class BuilderManager:
163164
else:
164165
full_key = key
165166

167+
print('handling', full_key, value_t)
168+
166169
# Get the builder.
167170
builder = <_ArrayBuilderBase>self.builder_map.get(full_key, None)
168171
if builder is None and not self.has_schema:
@@ -193,9 +196,9 @@ cdef class BuilderManager:
193196
if parent_type == BSON_TYPE_ARRAY:
194197
(<ListBuilder>self.builder_map[base_key]).append_count()
195198

196-
# Update our count.
197-
if builder.length() > self.count:
198-
self.count = builder.length()
199+
# Update our count for top level documents.
200+
if parent_type == 0:
201+
self.count += 1
199202

200203
cpdef void process_bson_stream(self, const uint8_t* bson_stream, size_t length):
201204
"""Process a bson byte stream."""
@@ -214,35 +217,35 @@ cdef class BuilderManager:
214217
bson_reader_destroy(stream_reader)
215218

216219
cpdef finish(self):
217-
"""Finish building the arrays."""
218-
cdef dict builder_map = self.builder_map
219-
cdef dict array_map = {}
220+
"""Finish appending to the builders."""
221+
cdef dict return_map = {}
220222
cdef bytes key
221223
cdef str field
222224
cdef _ArrayBuilderBase value
223225

226+
# Move the builders to a new dict with string keys.
227+
for key, value in self.builder_map.items():
228+
return_map[key.decode('utf-8')] = value
229+
224230
# Insert null fields.
225-
for key in list(builder_map):
226-
if builder_map[key] is None:
227-
builder_map[key] = NullBuilder(self.count)
231+
for field in list(return_map):
232+
if return_map[field] is None:
233+
return_map[field] = NullBuilder(self.count)
228234

229235
# Pad fields as needed.
230-
for key, value in builder_map.items():
231-
field = key.decode("utf-8")
232-
236+
for field, value in return_map.items():
233237
# If it isn't a list item, append nulls as needed.
234238
# For lists, the nulls are stored in the parent.
235239
if not field.endswith('[]'):
236240
if value.length() < self.count:
237241
value.append_nulls(self.count - value.length())
238242

239-
array_map[field] = value.finish()
240-
return array_map
243+
return return_map
241244

242245

243246
cdef class _ArrayBuilderBase:
244247
cdef:
245-
uint8_t type_marker
248+
public uint8_t type_marker
246249

247250
def append_values(self, values):
248251
for value in values:
@@ -656,7 +659,8 @@ cdef class DocumentBuilder(_ArrayBuilderBase):
656659
self.field_map[field_name] = 1
657660

658661
def finish(self):
659-
return set((f.decode('utf-8') for f in self.field_map))
662+
# Fields must be in order if we were given a schema.
663+
return list(f.decode('utf-8') for f in self.field_map)
660664

661665

662666
cdef class ListBuilder(_ArrayBuilderBase):

bindings/python/pymongoarrow/types.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -260,23 +260,23 @@ def get_numpy_type(type):
260260

261261

262262
_TYPE_CHECKER_TO_INTERNAL_TYPE = {
263-
_atypes.is_int32: _BsonArrowTypes.int32,
264-
_atypes.is_int64: _BsonArrowTypes.int64,
265-
_atypes.is_float64: _BsonArrowTypes.double,
266-
_atypes.is_timestamp: _BsonArrowTypes.datetime,
267-
_atypes.is_null: _BsonArrowTypes.null,
268-
_is_objectid: _BsonArrowTypes.objectid,
269-
_is_decimal128: _BsonArrowTypes.decimal128,
270-
_is_binary: _BsonArrowTypes.binary,
271-
_is_code: _BsonArrowTypes.code,
272-
_atypes.is_string: _BsonArrowTypes.string,
273-
_atypes.is_boolean: _BsonArrowTypes.bool,
274-
_atypes.is_struct: _BsonArrowTypes.document,
275-
_atypes.is_list: _BsonArrowTypes.array,
276-
_atypes.is_date32: _BsonArrowTypes.date32,
277-
_atypes.is_date64: _BsonArrowTypes.date64,
278-
_atypes.is_large_string: _BsonArrowTypes.string,
279-
_atypes.is_large_list: _BsonArrowTypes.array,
263+
_atypes.is_int32: _BsonArrowTypes.int32.value,
264+
_atypes.is_int64: _BsonArrowTypes.int64.value,
265+
_atypes.is_float64: _BsonArrowTypes.double.value,
266+
_atypes.is_timestamp: _BsonArrowTypes.datetime.value,
267+
_atypes.is_null: _BsonArrowTypes.null.value,
268+
_is_objectid: _BsonArrowTypes.objectid.value,
269+
_is_decimal128: _BsonArrowTypes.decimal128.value,
270+
_is_binary: _BsonArrowTypes.binary.value,
271+
_is_code: _BsonArrowTypes.code.value,
272+
_atypes.is_string: _BsonArrowTypes.string.value,
273+
_atypes.is_boolean: _BsonArrowTypes.bool.value,
274+
_atypes.is_struct: _BsonArrowTypes.document.value,
275+
_atypes.is_list: _BsonArrowTypes.array.value,
276+
_atypes.is_date32: _BsonArrowTypes.date32.value,
277+
_atypes.is_date64: _BsonArrowTypes.date64.value,
278+
_atypes.is_large_string: _BsonArrowTypes.string.value,
279+
_atypes.is_large_list: _BsonArrowTypes.array.value,
280280
}
281281

282282

bindings/python/test/test_arrow.py

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@
2323
import pymongo
2424
from bson import Binary, Code, CodecOptions, Decimal128, ObjectId
2525
from pyarrow import (
26-
DataType,
27-
FixedSizeBinaryType,
2826
Table,
2927
bool_,
3028
csv,
@@ -405,7 +403,7 @@ def inner(i):
405403
if use_none:
406404
inner_dict["null"] = None
407405
if nested_elem:
408-
inner_dict["list"] = [nested_elem]
406+
inner_dict["list_of_objs"] = [nested_elem]
409407
return inner_dict
410408

411409
if nested_elem:
@@ -480,6 +478,18 @@ def test_auto_schema_nested(self):
480478
for name in out.column_names:
481479
self.assertEqual(data[name], out[name].cast(data[name].type))
482480

481+
def test_auto_schema_nested_null(self):
482+
# Create table with random data of various types.
483+
_, data = self._create_nested_data(use_none=True)
484+
485+
self.coll.drop()
486+
res = write(self.coll, data)
487+
self.assertEqual(len(data), res.raw_result["insertedCount"])
488+
for func in [find_arrow_all, aggregate_arrow_all]:
489+
out = func(self.coll, {} if func == find_arrow_all else []).drop(["_id"])
490+
for name in out.column_names:
491+
self.assertEqual(data[name], out[name].cast(data[name].type))
492+
483493
def test_schema_nested_null(self):
484494
schema, data = self._create_nested_data(use_none=True)
485495

@@ -514,7 +524,13 @@ def test_auto_schema_first_list_null(self):
514524
{"a": ["str"]},
515525
{"a": []},
516526
]
517-
expected = pa.Table.from_pylist(docs)
527+
expected = pa.Table.from_pylist(
528+
[
529+
{"a": []},
530+
{"a": ["str"]},
531+
{"a": []},
532+
]
533+
)
518534
self._test_auto_schema_list(docs, expected)
519535

520536
def test_auto_schema_first_list_empty(self):
@@ -525,7 +541,7 @@ def test_auto_schema_first_list_empty(self):
525541
]
526542
expected = pa.Table.from_pylist(
527543
[
528-
{"a": None}, # TODO: We incorrectly set the first empty list to null.
544+
{"a": []},
529545
{"a": ["str"]},
530546
{"a": []},
531547
]
@@ -538,20 +554,30 @@ def test_auto_schema_first_list_element_null(self):
538554
{"a": [None, None, "str"]}, # Inferred schema should use the first non-null element.
539555
{"a": []},
540556
]
541-
expected = pa.Table.from_pylist(docs)
557+
expected = pa.Table.from_pylist(
558+
[
559+
{"a": []},
560+
{"a": ["str"]}, # Inferred schema should use the first non-null element.
561+
{"a": []},
562+
]
563+
)
542564
self._test_auto_schema_list(docs, expected)
543565

544-
@unittest.expectedFailure # TODO: Our inferred value for the first a.b field differs from pyarrow's.
545566
def test_auto_schema_first_embedded_list_null(self):
546567
docs = [
547568
{"a": {"b": None}},
548569
{"a": {"b": ["str"]}},
549570
{"a": {"b": []}},
550571
]
551-
expected = pa.Table.from_pylist(docs)
572+
expected = pa.Table.from_pylist(
573+
[
574+
{"a": {"b": []}},
575+
{"a": {"b": ["str"]}},
576+
{"a": {"b": []}},
577+
]
578+
)
552579
self._test_auto_schema_list(docs, expected)
553580

554-
@unittest.expectedFailure # TODO: Our inferred value for the first a.b field differs from pyarrow's.
555581
def test_auto_schema_first_embedded_doc_null(self):
556582
docs = [
557583
{"a": {"b": None}},
@@ -747,18 +773,10 @@ def test_nested_bson_extension_types(self):
747773
out = find_arrow_all(self.coll, {})
748774
obj_schema_type = out.field("obj").type
749775

750-
self.assertIsInstance(obj_schema_type.field("obj_id").type, FixedSizeBinaryType)
751-
self.assertIsInstance(obj_schema_type.field("dec_128").type, FixedSizeBinaryType)
752-
self.assertIsInstance(obj_schema_type.field("binary").type, DataType)
753-
self.assertIsInstance(obj_schema_type.field("code").type, DataType)
754-
755-
new_types = [ObjectIdType(), Decimal128Type(), BinaryType(0), CodeType()]
756-
new_names = [f.name for f in out["obj"].type]
757-
new_obj = out["obj"].cast(struct(zip(new_names, new_types)))
758-
self.assertIsInstance(new_obj.type[0].type, ObjectIdType)
759-
self.assertIsInstance(new_obj.type[1].type, Decimal128Type)
760-
self.assertIsInstance(new_obj.type[2].type, BinaryType)
761-
self.assertIsInstance(new_obj.type[3].type, CodeType)
776+
self.assertIsInstance(obj_schema_type.field("obj_id").type, ObjectIdType)
777+
self.assertIsInstance(obj_schema_type.field("dec_128").type, Decimal128Type)
778+
self.assertIsInstance(obj_schema_type.field("binary").type, BinaryType)
779+
self.assertIsInstance(obj_schema_type.field("code").type, CodeType)
762780

763781
def test_large_string_type(self):
764782
"""Tests pyarrow._large_string() DataType"""

0 commit comments

Comments
 (0)