Skip to content

Commit 745ff0c

Browse files
authored
ARROW-139 Fully support ObjectId extension type (#129)
1 parent 63557b8 commit 745ff0c

File tree

11 files changed

+201
-43
lines changed

11 files changed

+201
-43
lines changed

bindings/python/pymongoarrow/lib.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ cdef class ObjectIdBuilder(_ArrayBuilderBase):
365365
cdef shared_ptr[CArray] out
366366
with nogil:
367367
self.builder.get().Finish(&out)
368-
return pyarrow_wrap_array(out)
368+
return pyarrow_wrap_array(out).cast(ObjectIdType())
369369

370370
cdef shared_ptr[CFixedSizeBinaryBuilder] unwrap(self):
371371
return self.builder

bindings/python/pymongoarrow/pandas_types.py

Lines changed: 55 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import numpy as np
2222
import pandas as pd
2323
import pyarrow as pa
24-
from bson import Binary
24+
from bson import Binary, ObjectId
2525
from pandas.api.extensions import (
2626
ExtensionArray,
2727
ExtensionDtype,
@@ -34,16 +34,9 @@ class PandasBSONDtype(ExtensionDtype):
3434

3535
na_value = np.nan
3636

37-
def __init__(self, subtype):
38-
self._subtype = subtype
39-
40-
@property
41-
def subtype(self) -> int:
42-
return self._subtype
43-
4437
@property
4538
def name(self) -> str:
46-
return f"bson_{self.type.__name__}[{self.subtype}]"
39+
return f"bson_{self.__class__.__name__}"
4740

4841
def __from_arrow__(self, array: Union[pa.Array, pa.ChunkedArray]) -> ExtensionArray:
4942

@@ -65,7 +58,7 @@ def __from_arrow__(self, array: Union[pa.Array, pa.ChunkedArray]) -> ExtensionAr
6558
if not pd.isna(val) and not isinstance(val, typ):
6659
val = typ(val)
6760
vals.append(val)
68-
arr = np.array(vals)
61+
arr = np.array(vals, dtype=object)
6962
# using _from_sequence to ensure None is converted to NA
7063
to_append = arr_type._from_sequence(arr, dtype=dtype)
7164
results.append(to_append)
@@ -75,13 +68,25 @@ def __from_arrow__(self, array: Union[pa.Array, pa.ChunkedArray]) -> ExtensionAr
7568
else:
7669
return arr_type(np.array([], dtype="object"))
7770

71+
@classmethod
72+
def construct_from_string(cls, string):
73+
if not isinstance(string, str):
74+
raise TypeError(f"'construct_from_string' expects a string, got {type(string)}")
75+
default = cls()
76+
if string != default.name:
77+
raise TypeError(f"Cannot construct a '{cls.__name__}' from '{string}'")
78+
return default
79+
7880

7981
class PandasBSONExtensionArray(ExtensionArray):
8082
"""The base class for Pandas BSON extension arrays."""
8183

84+
_default_dtype = None
85+
8286
def __init__(self, values, dtype, copy=False) -> None:
8387
if not isinstance(values, np.ndarray):
8488
raise TypeError("Need to pass a numpy array as values")
89+
dtype = dtype or self._default_dtype
8590
if dtype is None:
8691
raise ValueError("dtype must be a valid data type")
8792
for val in values:
@@ -132,7 +137,7 @@ def __len__(self) -> int:
132137
def isna(self):
133138
return np.array(
134139
[
135-
x is not None and not isinstance(x, self.dtype.type) and np.isnan(x)
140+
x is None or (x is not None and not isinstance(x, self.dtype.type) and np.isnan(x))
136141
for x in self.data
137142
],
138143
dtype=bool,
@@ -181,14 +186,25 @@ def _concat_same_type(cls, to_concat):
181186

182187

183188
@register_extension_dtype
184-
class PandasBSONBinary(PandasBSONDtype):
189+
class PandasBinary(PandasBSONDtype):
185190
"""A pandas extension type for BSON Binary data type."""
186191

187192
type = Binary
188193

194+
def __init__(self, subtype):
195+
self._subtype = subtype
196+
197+
@property
198+
def subtype(self) -> int:
199+
return self._subtype
200+
201+
@property
202+
def name(self) -> str:
203+
return f"bson_{self.type.__name__}[{self.subtype}]"
204+
189205
@classmethod
190-
def construct_array_type(cls) -> Type["PandasBSONArray"]:
191-
return PandasBSONArray
206+
def construct_array_type(cls) -> Type["PandasBinaryArray"]:
207+
return PandasBinaryArray
192208

193209
@classmethod
194210
def construct_from_string(cls, string):
@@ -202,10 +218,34 @@ def construct_from_string(cls, string):
202218
raise TypeError(f"Cannot construct a '{cls.__name__}' from '{string}'")
203219

204220

205-
class PandasBSONArray(PandasBSONExtensionArray):
221+
class PandasBinaryArray(PandasBSONExtensionArray):
206222
"""A pandas extension type for BSON Binary data arrays."""
207223

208224
def __arrow_array__(self, type=None):
209225
from pymongoarrow.types import BinaryType
210226

211227
return pa.array(self.data, type=BinaryType(self.dtype.subtype))
228+
229+
230+
@register_extension_dtype
231+
class PandasObjectId(PandasBSONDtype):
232+
"""A pandas extension type for BSON ObjectId data type."""
233+
234+
type = ObjectId
235+
236+
@classmethod
237+
def construct_array_type(cls) -> Type["PandasObjectIdArray"]:
238+
return PandasObjectIdArray
239+
240+
241+
class PandasObjectIdArray(PandasBSONExtensionArray):
242+
"""A pandas extension type for BSON Binary data arrays."""
243+
244+
@property
245+
def _default_dtype(self):
246+
return PandasObjectId()
247+
248+
def __arrow_array__(self, type=None):
249+
from pymongoarrow.types import ObjectIdType
250+
251+
return pa.array(self.data, type=ObjectIdType())

bindings/python/pymongoarrow/types.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
struct,
3232
timestamp,
3333
)
34-
from pymongoarrow.pandas_types import PandasBSONBinary
34+
from pymongoarrow.pandas_types import PandasBinary, PandasObjectId
3535

3636

3737
class _BsonArrowTypes(enum.Enum):
@@ -55,6 +55,8 @@ class _BsonArrowTypes(enum.Enum):
5555

5656
class ObjectIdScalar(ExtensionScalar):
5757
def as_py(self):
58+
if self.value is None:
59+
return None
5860
return ObjectId(self.value.as_py())
5961

6062

@@ -70,6 +72,9 @@ def __reduce__(self):
7072
def __arrow_ext_scalar_class__(self):
7173
return ObjectIdScalar
7274

75+
def to_pandas_dtype(self):
76+
return PandasObjectId()
77+
7378

7479
class Decimal128Scalar(ExtensionScalar):
7580
def as_py(self):
@@ -115,7 +120,7 @@ def __arrow_ext_scalar_class__(self):
115120
return BinaryScalar
116121

117122
def to_pandas_dtype(self):
118-
return PandasBSONBinary(self.subtype)
123+
return PandasBinary(self.subtype)
119124

120125

121126
# Internal Type Handling.

bindings/python/test/pandas_types/test_binary.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import pytest
1616
from bson import Binary
1717
from pandas.tests.extension import base
18-
from pymongoarrow.pandas_types import PandasBSONArray, PandasBSONBinary
18+
from pymongoarrow.pandas_types import PandasBinary, PandasBinaryArray
1919

2020
try:
2121
base.BaseIndexTests
@@ -30,7 +30,7 @@ def make_datum():
3030

3131
@pytest.fixture
3232
def dtype():
33-
return PandasBSONBinary(10)
33+
return PandasBinary(10)
3434

3535

3636
def make_data():
@@ -45,24 +45,24 @@ def make_data():
4545

4646
@pytest.fixture
4747
def data(dtype):
48-
return PandasBSONArray(np.array(make_data(), dtype=object), dtype=dtype)
48+
return PandasBinaryArray(np.array(make_data(), dtype=object), dtype=dtype)
4949

5050

5151
@pytest.fixture
5252
def data_missing(dtype):
53-
return PandasBSONArray(np.array([np.nan, make_datum()], dtype=object), dtype=dtype)
53+
return PandasBinaryArray(np.array([np.nan, make_datum()], dtype=object), dtype=dtype)
5454

5555

5656
@pytest.fixture
5757
def data_for_sorting(dtype):
58-
return PandasBSONArray(
58+
return PandasBinaryArray(
5959
np.array([make_datum(), make_datum(), make_datum()], dtype=object), dtype=dtype
6060
)
6161

6262

6363
@pytest.fixture
6464
def data_missing_for_sorting(dtype):
65-
return PandasBSONArray(
65+
return PandasBinaryArray(
6666
np.array([make_datum(), np.nan, make_datum()], dtype=object), dtype=dtype
6767
)
6868

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright 2023-present MongoDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import numpy as np
15+
import pytest
16+
from bson import ObjectId
17+
from pandas.tests.extension import base
18+
from pymongoarrow.pandas_types import PandasObjectId, PandasObjectIdArray
19+
20+
try:
21+
base.BaseIndexTests
22+
except AttributeError:
23+
pytest.skip("Not available", allow_module_level=True)
24+
25+
26+
def make_datum():
27+
return ObjectId()
28+
29+
30+
@pytest.fixture
31+
def dtype():
32+
return PandasObjectId()
33+
34+
35+
def make_data():
36+
return (
37+
[make_datum() for _ in range(8)]
38+
+ [np.nan]
39+
+ [make_datum() for _ in range(88)]
40+
+ [np.nan]
41+
+ [make_datum(), make_datum()]
42+
)
43+
44+
45+
@pytest.fixture
46+
def data(dtype):
47+
return PandasObjectIdArray(np.array(make_data(), dtype=object), dtype=dtype)
48+
49+
50+
@pytest.fixture
51+
def data_missing(dtype):
52+
return PandasObjectIdArray(np.array([np.nan, make_datum()], dtype=object), dtype=dtype)
53+
54+
55+
@pytest.fixture
56+
def data_for_sorting(dtype):
57+
return PandasObjectIdArray(
58+
np.array([make_datum(), make_datum(), make_datum()], dtype=object), dtype=dtype
59+
)
60+
61+
62+
@pytest.fixture
63+
def data_missing_for_sorting(dtype):
64+
return PandasObjectIdArray(
65+
np.array([make_datum(), np.nan, make_datum()], dtype=object), dtype=dtype
66+
)
67+
68+
69+
class TestDtype(base.BaseDtypeTests):
70+
def test_is_not_string_type(self, data):
71+
# Override so that this test returns None; returning a value raises a warning.
72+
super().test_is_not_string_type(data)
73+
74+
def test_is_not_object_type(self, data):
75+
# Override so that this test returns None; returning a value raises a warning.
76+
super().test_is_not_object_type(data)
77+
78+
79+
class TestInterface(base.BaseInterfaceTests):
80+
pass
81+
82+
83+
class TestConstructors(base.BaseConstructorsTests):
84+
pass
85+
86+
87+
class TestGetitem(base.BaseGetitemTests):
88+
pass
89+
90+
91+
class TestSetitem(base.BaseSetitemTests):
92+
pass
93+
94+
95+
class TestIndex(base.BaseIndexTests):
96+
pass
97+
98+
99+
class TestMissing(base.BaseMissingTests):
100+
pass

0 commit comments

Comments
 (0)