Skip to content

Commit 888cbf8

Browse files
authored
ARROW-54 Make ObjectId an Extension Type (#50)
1 parent 4e0cd6b commit 888cbf8

File tree

3 files changed

+54
-20
lines changed

3 files changed

+54
-20
lines changed

bindings/python/docs/source/supported_types.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ Support for additional types will be added in subsequent releases.
1818
* - String
1919
- :class:`py.str`, an instance of :class:`pyarrow.string`
2020
* - ObjectId
21-
- :class:`py.bytes`, :class:`bson.ObjectId`, an instance of :class:`pyarrow.FixedSizeBinaryScalar`
21+
- :class:`py.bytes`, :class:`bson.ObjectId`, an instance of :class:`pymongoarrow.types.ObjectIdType`, an instance of :class:`pyarrow.FixedSizeBinaryScalar`
2222
* - 64-bit binary floating point
2323
- :class:`py.float`, an instance of :meth:`pyarrow.float64`
2424
* - 32-bit integer

bindings/python/pymongoarrow/types.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616

1717
from bson import Int64, ObjectId
1818

19-
from pyarrow import timestamp, float64, int64, int32, string
19+
from pyarrow import timestamp, binary, float64, int64, int32, string
20+
from pyarrow import PyExtensionType
2021
from pyarrow import DataType as _ArrowDataType
2122
import pyarrow.types as _atypes
2223

@@ -30,20 +31,37 @@ class _BsonArrowTypes(enum.Enum):
3031
string = 6
3132

3233

34+
# Custom Extension Types.
35+
# See https://arrow.apache.org/docs/python/extending_types.html#defining-extension-types-user-defined-types
36+
# for details.
37+
38+
class ObjectIdType(PyExtensionType):
39+
_type_marker = _BsonArrowTypes.objectid
40+
41+
def __init__(self):
42+
super().__init__(binary(12))
43+
44+
def __reduce__(self):
45+
return ObjectIdType, ()
46+
47+
48+
# Internal Type Handling.
49+
50+
def _is_objectid(obj):
51+
type_marker = getattr(obj, '_type_marker', '')
52+
return type_marker == ObjectIdType._type_marker
53+
54+
3355
_TYPE_NORMALIZER_FACTORY = {
3456
Int64: lambda _: int64(),
3557
float: lambda _: float64(),
3658
int: lambda _: int64(),
3759
datetime: lambda _: timestamp('ms'), # TODO: add tzinfo support
38-
ObjectId: lambda _: ObjectId,
60+
ObjectId: lambda _: ObjectIdType(),
3961
str: lambda: string(),
4062
}
4163

4264

43-
def _is_objectid(obj):
44-
return obj == ObjectId
45-
46-
4765
_TYPE_CHECKER_TO_INTERNAL_TYPE = {
4866
_atypes.is_int32: _BsonArrowTypes.int32,
4967
_atypes.is_int64: _BsonArrowTypes.int64,
@@ -74,15 +92,9 @@ def _get_internal_typemap(typemap):
7492
internal_typemap = {}
7593
for fname, ftype in typemap.items():
7694
for checker, internal_id in _TYPE_CHECKER_TO_INTERNAL_TYPE.items():
77-
# Catch error where the pyarrow checkers are looking for an `id`
78-
# attribute that might not exist on non-pyarrow types
79-
# (like ObjectId). For example, `is_int32()` checks for
80-
# `t.id == lib.Type_INT32`.
81-
try:
82-
if checker(ftype):
83-
internal_typemap[fname] = internal_id
84-
except AttributeError:
85-
pass
95+
if checker(ftype):
96+
internal_typemap[fname] = internal_id
97+
8698
if fname not in internal_typemap:
8799
raise ValueError('Unsupported data type in schema for ' +
88100
f'field "{fname}" of type "{ftype}"')

bindings/python/test/test_bson.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515

1616
from bson import encode, InvalidBSON
1717

18-
import pyarrow
18+
import pyarrow as pa
1919
from pymongoarrow.context import PyMongoArrowContext
2020
from pymongoarrow.lib import process_bson_stream
2121
from pymongoarrow.schema import Schema
22-
from pymongoarrow.types import int32, int64, string, ObjectId
22+
from pymongoarrow.types import int32, int64, string, ObjectId, ObjectIdType
2323

2424

2525
class TestBsonToArrowConversionBase(TestCase):
@@ -105,7 +105,7 @@ class TestUnsupportedDataType(TestBsonToArrowConversionBase):
105105
def test_simple(self):
106106
schema = Schema({'_id': ObjectId,
107107
'data': int64(),
108-
'fake': pyarrow.float16() })
108+
'fake': pa.float16() })
109109
msg = ("Unsupported data type in schema for field " +
110110
'"fake" of type "halffloat"')
111111
with self.assertRaisesRegex(ValueError, msg):
@@ -115,7 +115,7 @@ def test_simple(self):
115115
class TestNonAsciiFieldName(TestBsonToArrowConversionBase):
116116

117117
def setUp(self):
118-
self.schema = Schema({'_id': ObjectId,
118+
self.schema = Schema({'_id': ObjectIdType(),
119119
'dätá': int64()})
120120
self.context = PyMongoArrowContext.from_schema(
121121
self.schema)
@@ -132,3 +132,25 @@ def test_simple(self):
132132
}
133133

134134
self._run_test(docs, as_dict)
135+
136+
137+
class TestSerializeExtensions(TestCase):
138+
# Follows example in
139+
# https://arrow.apache.org/docs/python/extending_types.html#defining-extension-types-user-defined-types
140+
141+
def serialize_array(self, arr):
142+
batch = pa.RecordBatch.from_arrays([arr], ["ext"])
143+
sink = pa.BufferOutputStream()
144+
with pa.RecordBatchStreamWriter(sink, batch.schema) as writer:
145+
writer.write_batch(batch)
146+
buf = sink.getvalue()
147+
with pa.ipc.open_stream(buf) as reader:
148+
result = reader.read_all()
149+
return result.column('ext')
150+
151+
def test_object_id_type(self):
152+
oids = [ObjectId().binary for _ in range(4)]
153+
storage_array = pa.array(oids, pa.binary(12))
154+
arr = pa.ExtensionArray.from_storage(ObjectIdType(), storage_array)
155+
result = self.serialize_array(arr)
156+
assert result.type._type_marker == ObjectIdType._type_marker

0 commit comments

Comments
 (0)