Skip to content

Commit 7f4e298

Browse files
[ARROW-210] Add support for large_list and large_string PyArrow DataTypes
* ARROW-210 Initial commit for pyarrow large_list and large_string DataTypes
* Updated and added further datetime tests
* Added tests of large_list and large_string to test_arrow
* Added docstrings
* Removed completed todo
* Removed reference to large_binary
1 parent 5f86218 commit 7f4e298

File tree

6 files changed

+134
-34
lines changed

6 files changed

+134
-34
lines changed

bindings/python/pymongoarrow/context.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,4 +121,6 @@ def finish(self):
121121
for fname, builder in self.builder_map.items():
122122
arrays.append(builder.finish())
123123
names.append(fname.decode("utf-8"))
124+
if self.schema is not None:
125+
return Table.from_arrays(arrays=arrays, schema=self.schema.to_arrow())
124126
return Table.from_arrays(arrays=arrays, names=names)

bindings/python/pymongoarrow/lib.pyx

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -710,12 +710,16 @@ cdef object get_field_builder(object field, object tzinfo):
710710
field_builder = DatetimeBuilder(field_type)
711711
elif _atypes.is_string(field_type):
712712
field_builder = StringBuilder()
713+
elif _atypes.is_large_string(field_type):
714+
field_builder = StringBuilder()
713715
elif _atypes.is_boolean(field_type):
714716
field_builder = BoolBuilder()
715717
elif _atypes.is_struct(field_type):
716718
field_builder = DocumentBuilder(field_type, tzinfo)
717719
elif _atypes.is_list(field_type):
718720
field_builder = ListBuilder(field_type, tzinfo)
721+
elif _atypes.is_large_list(field_type):
722+
field_builder = ListBuilder(field_type, tzinfo)
719723
elif getattr(field_type, '_type_marker') == _BsonArrowTypes.objectid:
720724
field_builder = ObjectIdBuilder()
721725
elif getattr(field_type, '_type_marker') == _BsonArrowTypes.decimal128:
@@ -799,8 +803,8 @@ cdef class ListBuilder(_ArrayBuilderBase):
799803
cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
800804
cdef shared_ptr[CArrayBuilder] grandchild_builder
801805
self.dtype = dtype
802-
if not _atypes.is_list(dtype):
803-
raise ValueError("dtype must be a list_()")
806+
if not (_atypes.is_list(dtype) or _atypes.is_large_list(dtype)):
807+
raise ValueError("dtype must be a list_() or large_list()")
804808
self.context = context = PyMongoArrowContext(None, {})
805809
self.context.tzinfo = tzinfo
806810
field_builder = <StringBuilder>get_field_builder(self.dtype.value_type, tzinfo)

bindings/python/pymongoarrow/schema.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
import collections.abc as abc
1515

16-
from pyarrow import ListType, StructType
16+
import pyarrow as pa
1717

1818
from pymongoarrow.types import _normalize_typeid
1919

@@ -73,9 +73,9 @@ def _get_projection(self):
7373

7474
def _get_field_projection_value(self, ftype):
7575
value = True
76-
if isinstance(ftype, ListType):
76+
if isinstance(ftype, pa.ListType):
7777
return self._get_field_projection_value(ftype.value_field.type)
78-
if isinstance(ftype, StructType):
78+
if isinstance(ftype, pa.StructType):
7979
projection = {}
8080
for nested_ftype in ftype:
8181
projection[nested_ftype.name] = True
@@ -86,3 +86,22 @@ def __eq__(self, other):
8686
if isinstance(other, type(self)):
8787
return self.typemap == other.typemap
8888
return False
89+
90+
@classmethod
91+
def from_arrow(cls, aschema: pa.Schema):
92+
"""Create a :class:`~pymongoarrow.schema.Schema` instance from a :class:`~pyarrow.Schema`
93+
94+
:Parameters:
95+
- `aschema`: PyArrow Schema
96+
"""
97+
self = cls({})
98+
for field in aschema:
99+
self.typemap[field.name] = field.type
100+
return self
101+
102+
def to_arrow(self):
103+
"""Output the Schema as an instance of :class:`~pyarrow.Schema`."""
104+
fields = []
105+
for name, type_ in self.typemap.items():
106+
fields.append(pa.field(name=name, type=type_))
107+
return pa.schema(fields)

bindings/python/pymongoarrow/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,8 @@ def get_numpy_type(type):
266266
_atypes.is_boolean: _BsonArrowTypes.bool,
267267
_atypes.is_struct: _BsonArrowTypes.document,
268268
_atypes.is_list: _BsonArrowTypes.array,
269+
_atypes.is_large_string: _BsonArrowTypes.string,
270+
_atypes.is_large_list: _BsonArrowTypes.array,
269271
}
270272

271273

@@ -296,6 +298,7 @@ def _get_internal_typemap(typemap):
296298
for checker, internal_id in _TYPE_CHECKER_TO_INTERNAL_TYPE.items():
297299
if checker(ftype):
298300
internal_typemap[fname] = internal_id
301+
break
299302

300303
if fname not in internal_typemap:
301304
msg = f'Unsupported data type in schema for field "{fname}" of type "{ftype}"'

bindings/python/test/test_arrow.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
field,
3333
int32,
3434
int64,
35+
large_list,
36+
large_string,
3537
list_,
3638
string,
3739
struct,
@@ -707,6 +709,46 @@ def test_nested_bson_extension_types(self):
707709
self.assertIsInstance(new_obj.type[2].type, BinaryType)
708710
self.assertIsInstance(new_obj.type[3].type, CodeType)
709711

712+
def test_large_string_type(self):
713+
"""Tests pyarrow.large_string() DataType"""
714+
data = Table.from_pydict(
715+
{"string": ["A", "B", "C"], "large_string": ["C", "D", "E"]},
716+
ArrowSchema({"string": string(), "large_string": large_string()}),
717+
)
718+
self.round_trip(data, Schema({"string": str, "large_string": large_string()}))
719+
720+
def test_large_list_type(self):
721+
"""Tests pyarrow.large_list() DataType
722+
723+
1. Test large_list of large_string
724+
- with schema in query, one gets full roundtrip consistency
725+
- without schema, normal list and string will be inferred
726+
727+
2. Test nested as well
728+
"""
729+
730+
schema = ArrowSchema([field("_id", int32()), field("txns", large_list(large_string()))])
731+
732+
data = {
733+
"_id": [1, 2, 3, 4],
734+
"txns": [["A"], ["A", "B"], ["A", "B", "C"], ["A", "B", "C", "D"]],
735+
}
736+
table_orig = pa.Table.from_pydict(data, schema)
737+
self.coll.drop()
738+
res = write(self.coll, table_orig)
739+
# 1a.
740+
self.assertEqual(len(data["_id"]), res.raw_result["insertedCount"])
741+
table_schema = find_arrow_all(self.coll, {}, schema=Schema.from_arrow(schema))
742+
self.assertTrue(table_schema, table_orig)
743+
# 1b.
744+
table_none = find_arrow_all(self.coll, {}, schema=None)
745+
self.assertTrue(table_none.schema.types == [int32(), list_(string())])
746+
self.assertTrue(table_none.to_pydict() == data)
747+
748+
# 2. Test in sublist
749+
schema, data = self._create_nested_data((large_list(int32()), list(range(3))))
750+
self.round_trip(data, Schema(schema))
751+
710752

711753
class TestArrowExplicitApi(ArrowApiTestMixin, unittest.TestCase):
712754
def run_find(self, *args, **kwargs):

bindings/python/test/test_datetime.py

Lines changed: 59 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -98,38 +98,68 @@ def test_timezone_specified_in_schema(self):
9898
self.assertEqual(table, expected)
9999

100100
def test_timezone_specified_in_codec_options(self):
101-
# 1. When specified, CodecOptions.tzinfo will modify timestamp
102-
# type specifiers in the schema to inherit the specified timezone
103-
tz = pytz.timezone("US/Pacific")
104-
codec_options = CodecOptions(tz_aware=True, tzinfo=tz)
105-
expected = Table.from_pydict(
106-
{"_id": [1, 2], "data": self.expected_times},
107-
ArrowSchema([("_id", int32()), ("data", timestamp("ms", tz=tz))]),
101+
"""Test behavior of setting tzinfo CodecOptions in Collection.with_options.
102+
103+
When provided, timestamp type specifiers in the schema inherit the specified timezone.
104+
Read values will maintain this information for timestamps whether schema is passed or not.
105+
106+
Note, this does not apply to datetimes.
107+
We also test here that if one asks for a different timezone upon reading,
108+
one gets the requested timezone.
109+
"""
110+
111+
# 1. We pass tzinfo to Collection.with_options, and same tzinfo in schema of find_arrow_all
112+
tz_west = pytz.timezone("US/Pacific")
113+
codec_options = CodecOptions(tz_aware=True, tzinfo=tz_west)
114+
coll_west = self.coll.with_options(codec_options=codec_options)
115+
116+
schema_west = ArrowSchema([("_id", int32()), ("data", timestamp("ms", tz=tz_west))])
117+
table_west = find_arrow_all(
118+
collection=coll_west,
119+
query={},
120+
schema=Schema.from_arrow(schema_west),
121+
sort=[("_id", ASCENDING)],
108122
)
109123

110-
schemas = [
111-
Schema({"_id": int32(), "data": timestamp("ms")}),
112-
Schema({"_id": int32(), "data": datetime}),
113-
]
114-
for schema in schemas:
115-
table = find_arrow_all(
116-
self.coll.with_options(codec_options=codec_options),
117-
{},
118-
schema=schema,
119-
sort=[("_id", ASCENDING)],
120-
)
124+
expected_west = Table.from_pydict(
125+
{"_id": [1, 2], "data": self.expected_times}, schema=schema_west
126+
)
127+
self.assertTrue(table_west.equals(expected_west))
121128

122-
self.assertEqual(table, expected)
129+
# 2. We pass tzinfo to Collection.with_options, but do NOT include a schema in find_arrow_all
130+
table_none = find_arrow_all(
131+
collection=coll_west,
132+
query={},
133+
schema=None,
134+
sort=[("_id", ASCENDING)],
135+
)
136+
self.assertTrue(table_none.equals(expected_west))
123137

124-
# 2. CodecOptions.tzinfo will be ignored when tzinfo is specified
125-
# in the original schema type specifier.
126-
tz_east = pytz.timezone("US/Eastern")
127-
codec_options = CodecOptions(tz_aware=True, tzinfo=tz_east)
128-
schema = Schema({"_id": int32(), "data": timestamp("ms", tz=tz)})
129-
table = find_arrow_all(
130-
self.coll.with_options(codec_options=codec_options),
131-
{},
132-
schema=schema,
138+
# 3. Now we pass a DIFFERENT timezone to the schema in find_arrow_all than we did to the Collection
139+
schema_east = Schema(
140+
{"_id": int32(), "data": timestamp("ms", tz=pytz.timezone("US/Eastern"))}
141+
)
142+
table_east = find_arrow_all(
143+
collection=coll_west,
144+
query={},
145+
schema=schema_east,
133146
sort=[("_id", ASCENDING)],
134147
)
135-
self.assertEqual(table, expected)
148+
# Confirm that we get the timezone we requested
149+
self.assertTrue(table_east.schema.types == [int32(), timestamp(unit="ms", tz="US/Eastern")])
150+
# Confirm that the times have been adjusted
151+
times_west = table_west["data"].to_pylist()
152+
times_east = table_east["data"].to_pylist()
153+
self.assertTrue(all([times_west[i] == times_east[i] for i in range(len(table_east))]))
154+
155+
# 4. Test behavior of datetime. Output will be pyarrow.timestamp("ms") without timezone
156+
schema_dt = Schema({"_id": int32(), "data": datetime})
157+
table_dt = find_arrow_all(
158+
collection=coll_west,
159+
query={},
160+
schema=schema_dt,
161+
sort=[("_id", ASCENDING)],
162+
)
163+
self.assertTrue(table_dt.schema.types == [int32(), timestamp(unit="ms")])
164+
times = table_dt["data"].to_pylist()
165+
self.assertTrue(times == self.expected_times)

0 commit comments

Comments
 (0)