50 changes: 37 additions & 13 deletions bindings/python/pymongoarrow/lib.pyx
@@ -59,6 +59,8 @@ cdef const bson_t* bson_reader_read_safe(bson_reader_t* stream_reader) except? N
cdef class BuilderManager:
cdef:
dict builder_map
dict parent_types
dict parent_names
uint64_t count
bint has_schema
object tzinfo
@@ -69,6 +71,8 @@ cdef class BuilderManager:
self.tzinfo = tzinfo
self.count = 0
self.builder_map = {}
self.parent_names = {}
self.parent_types = {}
self.pool = default_memory_pool()
# Unpack the schema map.
for fname, (ftype, arrow_type) in schema_map.items():
@@ -146,6 +150,7 @@ cdef class BuilderManager:
cdef bson_iter_t child_iter
cdef uint64_t count = self.count
cdef _ArrayBuilderBase builder = None
cdef _ArrayBuilderBase parent_builder = None

while bson_iter_next(doc_iter):
# Get the key and value.
@@ -156,12 +161,15 @@
if parent_type == BSON_TYPE_ARRAY:
full_key = base_key
full_key.append(b"[]")
self.parent_types[full_key] = BSON_TYPE_ARRAY

elif parent_type == BSON_TYPE_DOCUMENT:
full_key = base_key
full_key.append(b".")
full_key.append(key)
(<DocumentBuilder>self.builder_map[base_key]).add_field(key)
self.parent_types[full_key] = BSON_TYPE_DOCUMENT
self.parent_names[full_key] = base_key

else:
full_key = key
@@ -174,8 +182,13 @@
continue

# Append nulls to catch up.
# For lists, the nulls are stored in the parent.
# For list children, the nulls are stored in the parent.
if parent_type != BSON_TYPE_ARRAY:
# For document children, catch up with the parent doc.
# Root fields will use the base document count
if parent_type == BSON_TYPE_DOCUMENT:
parent_builder = <_ArrayBuilderBase>self.builder_map.get(base_key, None)
count = parent_builder.length() - 1
if count > builder.length():
status = builder.append_nulls_raw(count - builder.length())
if not status.ok():
@@ -222,27 +235,36 @@ cdef class BuilderManager:
cdef dict return_map = {}
cdef bytes key
cdef str field
cdef uint64_t count
cdef CStatus status
cdef _ArrayBuilderBase value
cdef _ArrayBuilderBase builder
cdef _ArrayBuilderBase parent_builder

# Move the builders to a new dict with string keys.
for key, value in self.builder_map.items():
return_map[key.decode('utf-8')] = value
for key, builder in self.builder_map.items():
return_map[key.decode('utf-8')] = builder

# Insert null fields.
for field in list(return_map):
if return_map[field] is None:
return_map[field] = NullBuilder(memory_pool=self.pool)

# Pad fields as needed.
for field, value in return_map.items():
# If it isn't a list item, append nulls as needed.
# For lists, the nulls are stored in the parent.
if not field.endswith('[]'):
if value.length() < self.count:
status = value.append_nulls_raw(self.count - value.length())
if not status.ok():
raise ValueError("Failed to append nulls to", field)
for field, builder in return_map.items():
# For list children, the nulls are stored in the parent.
key = field.encode('utf-8')
parent_type = self.parent_types.get(key, None)
if parent_type == BSON_TYPE_ARRAY:
continue
if parent_type == BSON_TYPE_DOCUMENT:
parent_builder = self.builder_map[self.parent_names[key]]
count = parent_builder.length()
else:
count = self.count
if builder.length() < count:
status = builder.append_nulls_raw(count - builder.length())
if not status.ok():
raise ValueError("Failed to append nulls to", field)

return return_map

@@ -688,13 +710,15 @@ cdef class ListBuilder(_ArrayBuilderBase):
self.type_marker = BSON_TYPE_ARRAY

cdef CStatus append_raw(self, bson_iter_t * doc_iter, bson_type_t value_t):
if value_t == BSON_TYPE_NULL:
return self.append_null_raw()
return self.builder.get().Append(self.count)

cpdef void append_count(self):
self.count += 1

cdef CStatus append_null_raw(self):
return self.builder.get().Append(self.count)
return self.builder.get().AppendNull()

cdef shared_ptr[CArrayBuilder] get_builder(self):
return <shared_ptr[CArrayBuilder]>self.builder
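For context (not part of this diff): the ListBuilder change above makes a missing or BSON-null list surface as a null entry rather than an empty list, because append_null_raw now calls AppendNull() instead of Append(count). A minimal pyarrow-only sketch of that distinction, assuming only that pyarrow is installed:

import pyarrow as pa

# A null list entry and an empty list are different things in Arrow; the
# nested_data_out.json fixture below relies on missing lists becoming nulls.
arr = pa.array([[1.5, 2.5], None, []], type=pa.list_(pa.float64()))
print(arr.to_pylist())  # [[1.5, 2.5], None, []]
print(arr.null_count)   # 1 -- only the None entry is null; [] is a valid empty list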
173 changes: 173 additions & 0 deletions bindings/python/test/nested_data_in.json
@@ -0,0 +1,173 @@
[{
"object1": {
"object11": {
"object111": {}
}
}
},
{
"object1": {
"object11": {
"object111": {
"list1111": [
{"field11111": 13.2},
{"field11111": 41.6}
]
}
}
}
},
{
"object1": {
"object11": {
"object111": {
"list1111": [
{"field11111": 3.9},
{"field11111": 69.5}
]
}
}
}
},
{
"object1": {
"object11": {
"object111": {
"list1111": [
{"field11111": 147.2}
],
"list1112": [
{"field11121": "Barrier"},
{"field11121": "Barrier"},
{"field11121": "Barrier"},
{"field11121": "Barrier"}
]
}
}
}
},
{
"object1": {
"object11": {
"object111": {}
}
}
},
{
"object1": {
"object11": {
"object111": {
"list1111": [
{"field11111": 90.4}
]
}
}
}
},
{
"object1": {
"object11": {
"object111": {
"list1111": [
{"field11111": 1.7},
{"field11111": 53.9}
]
}
}
}
},
{
"object1": {
"object11": {
"object111": {
"list1111": [
{"field11111": 15.6}
]
}
}
}
},
{
"object1": {
"object11": {
"object111": {}
}
}
},
{
"object1": {
"object11": {
"object111": {
"list1111": [
{"field11111": 6.7},
{"field11111": 12.3}
]
}
}
}
},
{
"object1": {
"object11": {
"object111": {
"list1111": [
{"field11111": 57.1}
]
}
}
}
},
{
"object1": {
"object11": {
"object111": {
"list1111": [
{"field11111": 60.5}
]
}
}
}
},
{
"object1": {
"object11": {
"object111": {
"list1111": [
{"field11111": 1.2}
]
}
}
}
},
{
"object1": {
"object11": {
"object111": {
"list1111": [
{"field11111": 14.9}
]
}
}
}
},
{
"object1": {
"object11": {
"object111": {}
}
}
},
{
"object1": {
"object11": {
"object111": {}
}
}
},
{
"object1": {
"object11": {
"object111": {}
}
}
}]
1 change: 1 addition & 0 deletions bindings/python/test/nested_data_out.json
@@ -0,0 +1 @@
[{"col": null}, {"col": [{"field11111": 13.2}, {"field11111": 41.6}]}, {"col": [{"field11111": 3.9}, {"field11111": 69.5}]}, {"col": [{"field11111": 147.2}]}, {"col": null}, {"col": [{"field11111": 90.4}]}, {"col": [{"field11111": 1.7}, {"field11111": 53.9}]}, {"col": [{"field11111": 15.6}]}, {"col": null}, {"col": [{"field11111": 6.7}, {"field11111": 12.3}]}, {"col": [{"field11111": 57.1}]}, {"col": [{"field11111": 60.5}]}, {"col": [{"field11111": 1.2}]}, {"col": [{"field11111": 14.9}]}, {"col": null}, {"col": null}, {"col": null}]
35 changes: 35 additions & 0 deletions bindings/python/test/test_arrow.py
@@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import tempfile
import unittest
import unittest.mock as mock
from datetime import date, datetime
from pathlib import Path
from test import client_context
from test.utils import AllowListEventListener, NullsTestMixin

@@ -56,6 +58,8 @@
ObjectIdType,
)

HERE = Path(__file__).absolute().parent


class ArrowApiTestMixin:
@classmethod
@@ -493,11 +497,42 @@ def test_schema_missing_field(self):
self.assertEqual(out["list_field"].to_pylist(), expected)

def test_schema_incorrect_data_type(self):
# From https://github.com/mongodb-labs/mongo-arrow/issues/260.
self.coll.delete_many({})
self.coll.insert_one({"x": {"y": 1}})
out = find_arrow_all(self.coll, {}, schema=Schema({"x": str}))
assert out.to_pylist() == [{"x": None}]

def test_schema_arrays_of_documents(self):
# From https://github.com/mongodb-labs/mongo-arrow/issues/258.
coll = self.coll
coll.delete_many({})
coll.insert_one({"list1": [{"field1": 13.2, "field2": 13.2}, {"field1": 41.6}]})
coll.insert_one({"list1": [{"field1": 1.6}]})
schema = Schema(
{"col": pa.list_(pa.struct({"field1": pa.float64(), "field2": pa.float64()}))}
)
df = aggregate_arrow_all(coll, [{"$project": {"col": "$list1"}}], schema=schema)
assert df.to_pylist() == [
{"col": [{"field1": 13.2, "field2": 13.2}, {"field1": 41.6, "field2": None}]},
{"col": [{"field1": 1.6, "field2": None}]},
]

def test_schema_arrays_of_documents_with_nulls(self):
# From https://github.com/mongodb-labs/mongo-arrow/issues/257.
coll = self.coll
coll.delete_many({})
with (HERE / "nested_data_in.json").open() as fid:
coll.insert_many(json.load(fid))
df = aggregate_arrow_all(
coll,
[{"$project": {"col": "$object1.object11.object111.list1111"}}],
schema=Schema({"col": pa.list_(pa.struct({"field11111": pa.float64()}))}),
)
with (HERE / "nested_data_out.json").open() as fid:
expected = json.load(fid)
assert df.to_pylist() == expected

def test_auto_schema_nested(self):
# Create table with random data of various types.
_, data = self._create_nested_data()
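For readers who want to try the behavior these tests exercise outside the suite, a minimal sketch mirroring test_schema_arrays_of_documents (the connection string, database, and collection names are assumptions; it expects a local mongod):

import pyarrow as pa
from pymongo import MongoClient
from pymongoarrow.api import Schema, find_arrow_all

# Insert one document with the list field and one without it.
coll = MongoClient().test.arrays_of_documents
coll.delete_many({})
coll.insert_many([
    {"list1": [{"field1": 13.2}, {"field1": 41.6}]},
    {"unrelated": 1},  # no "list1" -- expected to surface as a null list
])
schema = Schema({"list1": pa.list_(pa.struct({"field1": pa.float64()}))})
table = find_arrow_all(coll, {}, schema=schema)
print(table.to_pylist())
# Expected, per the tests above:
# [{'list1': [{'field1': 13.2}, {'field1': 41.6}]}, {'list1': None}]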