Skip to content

Commit 1b070d4

Browse files
authored
ARROW-243 Handle column of fields with "null" values only (#241)
1 parent a314632 commit 1b070d4

File tree

6 files changed

+121
-36
lines changed

6 files changed

+121
-36
lines changed

bindings/python/pymongoarrow/api.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ def find_arrow_all(collection, query, *, schema=None, **kwargs):
8484
against which to run the ``find`` operation.
8585
- `query`: A mapping containing the query to use for the find operation.
8686
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
87-
If the schema is not given, it will be inferred using the first
88-
document in the result set.
87+
If the schema is not given, it will be inferred using the data in the
88+
result set.
8989
9090
Additional keyword-arguments passed to this method will be passed
9191
directly to the underlying ``find`` operation.
@@ -122,8 +122,8 @@ def aggregate_arrow_all(collection, pipeline, *, schema=None, **kwargs):
122122
against which to run the ``aggregate`` operation.
123123
- `pipeline`: A list of aggregation pipeline stages.
124124
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
125-
If the schema is not given, it will be inferred using the first
126-
document in the result set.
125+
If the schema is not given, it will be inferred using the data in the
126+
result set.
127127
128128
Additional keyword-arguments passed to this method will be passed
129129
directly to the underlying ``aggregate`` operation.
@@ -177,8 +177,8 @@ def find_pandas_all(collection, query, *, schema=None, **kwargs):
177177
against which to run the ``find`` operation.
178178
- `query`: A mapping containing the query to use for the find operation.
179179
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
180-
If the schema is not given, it will be inferred using the first
181-
document in the result set.
180+
If the schema is not given, it will be inferred using the data in the
181+
result set.
182182
183183
Additional keyword-arguments passed to this method will be passed
184184
directly to the underlying ``find`` operation.
@@ -198,8 +198,8 @@ def aggregate_pandas_all(collection, pipeline, *, schema=None, **kwargs):
198198
against which to run the ``find`` operation.
199199
- `pipeline`: A list of aggregation pipeline stages.
200200
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
201-
If the schema is not given, it will be inferred using the first
202-
document in the result set.
201+
If the schema is not given, it will be inferred using the data in the
202+
result set.
203203
204204
Additional keyword-arguments passed to this method will be passed
205205
directly to the underlying ``aggregate`` operation.
@@ -240,8 +240,8 @@ def find_numpy_all(collection, query, *, schema=None, **kwargs):
240240
against which to run the ``find`` operation.
241241
- `query`: A mapping containing the query to use for the find operation.
242242
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
243-
If the schema is not given, it will be inferred using the first
244-
document in the result set.
243+
If the schema is not given, it will be inferred using the data in the
244+
result set.
245245
246246
Additional keyword-arguments passed to this method will be passed
247247
directly to the underlying ``find`` operation.
@@ -271,8 +271,8 @@ def aggregate_numpy_all(collection, pipeline, *, schema=None, **kwargs):
271271
against which to run the ``find`` operation.
272272
- `query`: A mapping containing the query to use for the find operation.
273273
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
274-
If the schema is not given, it will be inferred using the first
275-
document in the result set.
274+
If the schema is not given, it will be inferred using the data in the
275+
result set.
276276
277277
Additional keyword-arguments passed to this method will be passed
278278
directly to the underlying ``aggregate`` operation.
@@ -338,8 +338,8 @@ def find_polars_all(collection, query, *, schema=None, **kwargs):
338338
against which to run the ``find`` operation.
339339
- `query`: A mapping containing the query to use for the find operation.
340340
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
341-
If the schema is not given, it will be inferred using the first
342-
document in the result set.
341+
If the schema is not given, it will be inferred using the data in the
342+
result set.
343343
344344
Additional keyword-arguments passed to this method will be passed
345345
directly to the underlying ``find`` operation.
@@ -361,8 +361,8 @@ def aggregate_polars_all(collection, pipeline, *, schema=None, **kwargs):
361361
against which to run the ``find`` operation.
362362
- `pipeline`: A list of aggregation pipeline stages.
363363
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
364-
If the schema is not given, it will be inferred using the first
365-
document in the result set.
364+
If the schema is not given, it will be inferred using the data in the
365+
result set.
366366
367367
Additional keyword-arguments passed to this method will be passed
368368
directly to the underlying ``aggregate`` operation.

bindings/python/pymongoarrow/context.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
Int32Builder,
3131
Int64Builder,
3232
ListBuilder,
33+
NullBuilder,
3334
ObjectIdBuilder,
3435
StringBuilder,
3536
)
@@ -49,6 +50,7 @@
4950
_BsonArrowTypes.code: CodeBuilder,
5051
_BsonArrowTypes.date32: Date32Builder,
5152
_BsonArrowTypes.date64: Date64Builder,
53+
_BsonArrowTypes.null: NullBuilder,
5254
}
5355
except ImportError:
5456
pass

bindings/python/pymongoarrow/lib.pyx

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ cdef const bson_t* bson_reader_read_safe(bson_reader_t* stream_reader) except? N
6464
# Placeholder numbers for the date types.
6565
cdef uint8_t ARROW_TYPE_DATE32 = 100
6666
cdef uint8_t ARROW_TYPE_DATE64 = 101
67+
cdef uint8_t ARROW_TYPE_NULL = 102
6768

6869
_builder_type_map = {
6970
BSON_TYPE_INT32: Int32Builder,
@@ -80,6 +81,7 @@ _builder_type_map = {
8081
BSON_TYPE_CODE: CodeBuilder,
8182
ARROW_TYPE_DATE32: Date32Builder,
8283
ARROW_TYPE_DATE64: Date64Builder,
84+
ARROW_TYPE_NULL: NullBuilder
8385
}
8486

8587
_field_type_map = {
@@ -177,6 +179,7 @@ cdef void process_raw_bson_stream(const uint8_t * docstream, size_t length, obje
177179
cdef Py_ssize_t count = 0
178180
cdef uint8_t byte_order_status = 0
179181
cdef map[cstring, void *] builder_map
182+
cdef map[cstring, void *] missing_builders
180183
cdef map[cstring, void*].iterator it
181184
cdef bson_subtype_t subtype
182185
cdef int32_t val32
@@ -197,6 +200,7 @@ cdef void process_raw_bson_stream(const uint8_t * docstream, size_t length, obje
197200
cdef DocumentBuilder doc_builder
198201
cdef Date32Builder date32_builder
199202
cdef Date64Builder date64_builder
203+
cdef NullBuilder null_builder
200204

201205
# Build up a map of the builders.
202206
for key, value in context.builder_map.items():
@@ -219,10 +223,6 @@ cdef void process_raw_bson_stream(const uint8_t * docstream, size_t length, obje
219223
builder = None
220224
if arr_value_builder is not None:
221225
builder = arr_value_builder
222-
else:
223-
it = builder_map.find(key)
224-
if it != builder_map.end():
225-
builder = <_ArrayBuilderBase>builder_map[key]
226226

227227
if builder is None:
228228
it = builder_map.find(key)
@@ -233,9 +233,16 @@ cdef void process_raw_bson_stream(const uint8_t * docstream, size_t length, obje
233233
# Get the appropriate builder for the current field.
234234
value_t = bson_iter_type(&doc_iter)
235235
builder_type = _builder_type_map.get(value_t)
236+
237+
# Keep the key in missing builders until we find it.
236238
if builder_type is None:
239+
missing_builders[key] = <void *>None
237240
continue
238241

242+
it = missing_builders.find(key)
243+
if it != builder_map.end():
244+
missing_builders.erase(key)
245+
239246
# Handle the parameterized builders.
240247
if builder_type == DatetimeBuilder and context.tzinfo is not None:
241248
arrow_type = timestamp('ms', tz=context.tzinfo)
@@ -410,6 +417,9 @@ cdef void process_raw_bson_stream(const uint8_t * docstream, size_t length, obje
410417
binary_builder.append_null()
411418
else:
412419
binary_builder.append_raw(<char*>val_buf, val_buf_len)
420+
elif ftype == ARROW_TYPE_NULL:
421+
null_builder = builder
422+
null_builder.append_null()
413423
else:
414424
raise PyMongoArrowError('unknown ftype {}'.format(ftype))
415425

@@ -422,6 +432,17 @@ cdef void process_raw_bson_stream(const uint8_t * docstream, size_t length, obje
422432
if len(builder) != count:
423433
builder.append_null()
424434
preincrement(it)
435+
436+
# Any missing fields that are left must be null fields.
437+
it = missing_builders.begin()
438+
while it != missing_builders.end():
439+
builder = NullBuilder()
440+
context.builder_map[key] = builder
441+
null_builder = builder
442+
for _ in range(count):
443+
null_builder.append_null()
444+
preincrement(it)
445+
425446
finally:
426447
bson_reader_destroy(stream_reader)
427448

@@ -724,6 +745,37 @@ cdef class Date32Builder(_ArrayBuilderBase):
724745
return self.builder
725746

726747

748+
cdef class NullBuilder(_ArrayBuilderBase):
749+
cdef:
750+
shared_ptr[CNullBuilder] builder
751+
752+
def __cinit__(self, MemoryPool memory_pool=None):
753+
cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
754+
self.builder.reset(new CNullBuilder(pool))
755+
self.type_marker = ARROW_TYPE_NULL
756+
757+
cdef append_raw(self, void* value):
758+
self.builder.get().AppendNull()
759+
760+
cpdef append(self, value):
761+
self.builder.get().AppendNull()
762+
763+
cpdef append_null(self):
764+
self.builder.get().AppendNull()
765+
766+
def __len__(self):
767+
return self.builder.get().length()
768+
769+
cpdef finish(self):
770+
cdef shared_ptr[CArray] out
771+
with nogil:
772+
self.builder.get().Finish(&out)
773+
return pyarrow_wrap_array(out)
774+
775+
cdef shared_ptr[CNullBuilder] unwrap(self):
776+
return self.builder
777+
778+
727779
cdef class BoolBuilder(_ArrayBuilderBase):
728780
cdef:
729781
shared_ptr[CBooleanBuilder] builder
@@ -817,6 +869,8 @@ cdef object get_field_builder(object field, object tzinfo):
817869
field_builder = ListBuilder(field_type, tzinfo)
818870
elif _atypes.is_large_list(field_type):
819871
field_builder = ListBuilder(field_type, tzinfo)
872+
elif _atypes.is_null(field_type):
873+
field_builder = NullBuilder()
820874
elif getattr(field_type, '_type_marker') == _BsonArrowTypes.objectid:
821875
field_builder = ObjectIdBuilder()
822876
elif getattr(field_type, '_type_marker') == _BsonArrowTypes.decimal128:

bindings/python/pymongoarrow/libarrow.pxd

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ cdef extern from "arrow/builder.h" namespace "arrow" nogil:
4545
int32_t num_values()
4646
shared_ptr[CDataType] type()
4747

48+
cdef cppclass CNullBuilder" arrow::NullBuilder"(CArrayBuilder):
49+
CNullBuilder(CMemoryPool* pool)
50+
4851

4952
cdef extern from "arrow/type_fwd.h" namespace "arrow" nogil:
5053
shared_ptr[CDataType] fixed_size_binary(int32_t byte_width)

bindings/python/pymongoarrow/types.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
float64,
2828
int64,
2929
list_,
30+
null,
3031
string,
3132
struct,
3233
timestamp,
@@ -55,6 +56,7 @@ class _BsonArrowTypes(enum.Enum):
5556
code = 12
5657
date32 = 13
5758
date64 = 14
59+
null = 15
5860

5961

6062
# Custom Extension Types.
@@ -260,6 +262,7 @@ def get_numpy_type(type):
260262
_atypes.is_int64: _BsonArrowTypes.int64,
261263
_atypes.is_float64: _BsonArrowTypes.double,
262264
_atypes.is_timestamp: _BsonArrowTypes.datetime,
265+
_atypes.is_null: _BsonArrowTypes.null,
263266
_is_objectid: _BsonArrowTypes.objectid,
264267
_is_decimal128: _BsonArrowTypes.decimal128,
265268
_is_binary: _BsonArrowTypes.binary,
@@ -276,7 +279,7 @@ def get_numpy_type(type):
276279

277280

278281
def _is_typeid_supported(typeid):
279-
return typeid in _TYPE_NORMALIZER_FACTORY
282+
return typeid in _TYPE_NORMALIZER_FACTORY or typeid is None
280283

281284

282285
def _normalize_typeid(typeid, field_name):
@@ -293,7 +296,10 @@ def _normalize_typeid(typeid, field_name):
293296
raise ValueError(msg)
294297
return list_(_normalize_typeid(typeid[0], "0"))
295298
if _is_typeid_supported(typeid):
296-
normalizer = _TYPE_NORMALIZER_FACTORY[typeid]
299+
if typeid is None: # noqa: SIM108
300+
normalizer = lambda _: null() # noqa: E731
301+
else:
302+
normalizer = _TYPE_NORMALIZER_FACTORY[typeid]
297303
return normalizer(typeid)
298304
msg = f"Unsupported type identifier {typeid} for field {field_name}"
299305
raise ValueError(msg)

bindings/python/test/test_arrow.py

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -295,23 +295,26 @@ def test_pymongo_error(self):
295295

296296
def _create_data(self):
297297
schema = {k.__name__: v(True) for k, v in _TYPE_NORMALIZER_FACTORY.items()}
298+
schema["null"] = pa.null()
298299
schema["Binary"] = BinaryType(10)
299300
schema["ObjectId"] = ObjectIdType()
300301
schema["Decimal128"] = Decimal128Type()
301302
schema["Code"] = CodeType()
303+
pydict = {
304+
"Int64": [i for i in range(2)],
305+
"float": [i for i in range(2)],
306+
"datetime": [i for i in range(2)],
307+
"str": [str(i) for i in range(2)],
308+
"int": [i for i in range(2)],
309+
"bool": [True, False],
310+
"null": [None for _ in range(2)],
311+
"Binary": [b"1", b"23"],
312+
"ObjectId": [ObjectId().binary, ObjectId().binary],
313+
"Decimal128": [Decimal128(str(i)).bid for i in range(2)],
314+
"Code": [str(i) for i in range(2)],
315+
}
302316
data = Table.from_pydict(
303-
{
304-
"Int64": [i for i in range(2)],
305-
"float": [i for i in range(2)],
306-
"datetime": [i for i in range(2)],
307-
"str": [str(i) for i in range(2)],
308-
"int": [i for i in range(2)],
309-
"bool": [True, False],
310-
"Binary": [b"1", b"23"],
311-
"ObjectId": [ObjectId().binary, ObjectId().binary],
312-
"Decimal128": [Decimal128(str(i)).bid for i in range(2)],
313-
"Code": [str(i) for i in range(2)],
314-
},
317+
pydict,
315318
ArrowSchema(schema),
316319
)
317320
return schema, data
@@ -355,8 +358,10 @@ def test_write_batching(self, mock):
355358
self.round_trip(data, Schema(schema), coll=self.coll)
356359
self.assertEqual(mock.call_count, 2)
357360

358-
def _create_nested_data(self, nested_elem=None):
361+
def _create_nested_data(self, nested_elem=None, use_none=False):
359362
schema = {k.__name__: v(0) for k, v in _TYPE_NORMALIZER_FACTORY.items()}
363+
if use_none:
364+
schema["null"] = pa.null()
360365
if nested_elem:
361366
schem_ent, nested_elem = nested_elem
362367
schema["list"] = list_(schem_ent)
@@ -379,10 +384,12 @@ def _create_nested_data(self, nested_elem=None):
379384
"date32": [date(2012, 1, 1) for i in range(3)],
380385
"date64": [date(2012, 1, 1) for i in range(3)],
381386
}
387+
if use_none:
388+
raw_data["null"] = [None for _ in range(3)]
382389

383390
def inner(i):
384391
inner_dict = dict(
385-
str=str(i),
392+
str=None if use_none and i == 0 else str(i),
386393
bool=bool(i),
387394
float=i + 0.1,
388395
Int64=i,
@@ -395,6 +402,8 @@ def inner(i):
395402
date32=date(2012, 1, 1),
396403
date64=date(2014, 1, 1),
397404
)
405+
if use_none:
406+
inner_dict["null"] = None
398407
if nested_elem:
399408
inner_dict["list"] = [nested_elem]
400409
return inner_dict
@@ -471,6 +480,17 @@ def test_auto_schema_nested(self):
471480
for name in out.column_names:
472481
self.assertEqual(data[name], out[name].cast(data[name].type))
473482

483+
def test_schema_nested_null(self):
484+
schema, data = self._create_nested_data(use_none=True)
485+
486+
self.coll.drop()
487+
res = write(self.coll, data)
488+
self.assertEqual(len(data), res.raw_result["insertedCount"])
489+
for func in [find_arrow_all, aggregate_arrow_all]:
490+
out = func(self.coll, {} if func == find_arrow_all else [], schema=Schema(schema))
491+
for name in out.column_names:
492+
self.assertEqual(data[name], out[name].cast(data[name].type))
493+
474494
def test_auto_schema(self):
475495
_, data = self._create_data()
476496
self.coll.drop()

0 commit comments

Comments
 (0)