Skip to content

Commit f03474e

Browse files
authored
ARROW-144 Support BSON Code Type (#145)
1 parent 398d222 commit f03474e

File tree

10 files changed

+243
-15
lines changed

10 files changed

+243
-15
lines changed

bindings/python/pymongoarrow/context.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from pymongoarrow.lib import (
1717
BinaryBuilder,
1818
BoolBuilder,
19+
CodeBuilder,
1920
DatetimeBuilder,
2021
Decimal128Builder,
2122
DocumentBuilder,
@@ -40,6 +41,7 @@
4041
_BsonArrowTypes.document: DocumentBuilder,
4142
_BsonArrowTypes.array: ListBuilder,
4243
_BsonArrowTypes.binary: BinaryBuilder,
44+
_BsonArrowTypes.code: CodeBuilder,
4345
}
4446

4547

bindings/python/pymongoarrow/lib.pyx

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,14 @@ from pyarrow.lib import (
3333

3434
from pymongoarrow.errors import InvalidBSON, PyMongoArrowError
3535
from pymongoarrow.context import PyMongoArrowContext
36-
from pymongoarrow.types import _BsonArrowTypes, _atypes, ObjectIdType, Decimal128Type as Decimal128Type_, BinaryType
36+
from pymongoarrow.types import _BsonArrowTypes, _atypes, ObjectIdType, Decimal128Type as Decimal128Type_, BinaryType, CodeType
3737

3838
# Cython imports
3939
from cpython cimport PyBytes_Size, object
4040
from cython.operator cimport dereference
4141
from libcpp cimport bool as cbool
4242
from libcpp.map cimport map
43+
from libc.string cimport strlen
4344
from libcpp.vector cimport vector
4445
from pyarrow.lib cimport *
4546
from pymongoarrow.libarrow cimport *
@@ -69,7 +70,8 @@ _builder_type_map = {
6970
BSON_TYPE_DOCUMENT: DocumentBuilder,
7071
BSON_TYPE_DECIMAL128: Decimal128Builder,
7172
BSON_TYPE_ARRAY: ListBuilder,
72-
BSON_TYPE_BINARY: BinaryBuilder
73+
BSON_TYPE_BINARY: BinaryBuilder,
74+
BSON_TYPE_CODE: CodeBuilder,
7375
}
7476

7577
_field_type_map = {
@@ -80,8 +82,10 @@ _field_type_map = {
8082
BSON_TYPE_UTF8: string(),
8183
BSON_TYPE_BOOL: bool_(),
8284
BSON_TYPE_DECIMAL128: Decimal128Type_(),
85+
BSON_TYPE_CODE: CodeType(),
8386
}
8487

88+
8589
cdef extract_field_dtype(bson_iter_t * doc_iter, bson_iter_t * child_iter, bson_type_t value_t, context):
8690
"""Get the appropriate data type for a specific field"""
8791
cdef const uint8_t *val_buf = NULL
@@ -148,7 +152,7 @@ def process_bson_stream(bson_stream, context, arr_value_builder=None):
148152
cdef uint32_t val_buf_len = 0
149153
cdef bson_decimal128_t dec128
150154
cdef bson_type_t value_t
151-
cdef const char * bson_str
155+
cdef const char * bson_str = NULL
152156
cdef StructType struct_dtype
153157
cdef const bson_t * doc = NULL
154158
cdef bson_iter_t doc_iter
@@ -171,6 +175,7 @@ def process_bson_stream(bson_stream, context, arr_value_builder=None):
171175
t_array = _BsonArrowTypes.array
172176
t_binary = _BsonArrowTypes.binary
173177
t_decimal128 = _BsonArrowTypes.decimal128
178+
t_code = _BsonArrowTypes.code
174179

175180
# initialize count to current length of builders
176181
for _, builder in builder_map.items():
@@ -256,6 +261,12 @@ def process_bson_stream(bson_stream, context, arr_value_builder=None):
256261
builder.append(<bytes>(bson_str)[:str_len])
257262
else:
258263
builder.append_null()
264+
elif ftype == t_code:
265+
if value_t == BSON_TYPE_CODE:
266+
bson_str = bson_iter_code(&doc_iter, &str_len)
267+
builder.append(<bytes>(bson_str)[:str_len])
268+
else:
269+
builder.append_null()
259270
elif ftype == t_decimal128:
260271
if value_t == BSON_TYPE_DECIMAL128:
261272
bson_iter_decimal128(&doc_iter, &dec128)
@@ -359,6 +370,16 @@ cdef class StringBuilder(_ArrayBuilderBase):
359370
return self.builder
360371

361372

373+
cdef class CodeBuilder(StringBuilder):
    """String builder whose finished array is cast to the ``CodeType`` extension type."""

    type_marker = _BsonArrowTypes.code

    cpdef finish(self):
        """Finish the underlying builder and return the result cast to ``CodeType()``."""
        cdef shared_ptr[CArray] finished
        with nogil:
            self.builder.get().Finish(&finished)
        # Wrap the C++ array as a Python object first, then apply the
        # extension type on top of it.
        wrapped = pyarrow_wrap_array(finished)
        return wrapped.cast(CodeType())
381+
382+
362383
cdef class ObjectIdBuilder(_ArrayBuilderBase):
363384
type_marker = _BsonArrowTypes.objectid
364385
cdef:

bindings/python/pymongoarrow/libbson.pxd

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,9 @@ cdef extern from "<bson/bson.h>":
130130

131131
bint bson_iter_decimal128 (const bson_iter_t *iter, bson_decimal128_t *dec)
132132

133+
const char * bson_iter_code (const bson_iter_t *iter, # IN
134+
uint32_t *length) # OUT
135+
133136
bint bson_iter_recurse (const bson_iter_t *iter, # IN
134137
bson_iter_t *child) # OUT
135138

bindings/python/pymongoarrow/pandas_types.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import numpy as np
2222
import pandas as pd
2323
import pyarrow as pa
24-
from bson import Binary, Decimal128, ObjectId
24+
from bson import Binary, Code, Decimal128, ObjectId
2525
from pandas.api.extensions import (
2626
ExtensionArray,
2727
ExtensionDtype,
@@ -262,7 +262,7 @@ def __arrow_array__(self, type=None):
262262

263263
@register_extension_dtype
264264
class PandasDecimal128(PandasBSONDtype):
265-
"""A pandas extension type for BSON ObjectId data type."""
265+
"""A pandas extension type for BSON Decimal128 data type."""
266266

267267
type = Decimal128
268268

@@ -282,3 +282,33 @@ def __arrow_array__(self, type=None):
282282
from pymongoarrow.types import Decimal128Type
283283

284284
return pa.array(self.data, type=Decimal128Type())
285+
286+
287+
@register_extension_dtype
class PandasCode(PandasBSONDtype):
    """Pandas extension dtype whose scalar type is ``bson.Code``."""

    # Scalar type of the elements held by arrays of this dtype.
    type = Code

    @classmethod
    def construct_array_type(cls) -> Type["PandasCodeArray"]:
        """Return the array class associated with this dtype."""
        return PandasCodeArray
296+
297+
298+
class PandasCodeArray(PandasBSONExtensionArray):
    """Pandas extension array holding BSON ``Code`` values."""

    @property
    def _default_dtype(self):
        """Return a ``PandasCode`` instance, the dtype paired with this array."""
        return PandasCode()

    def __eq__(self, other):
        """Compare with ``other`` via the parent class.

        A single ``Code`` operand is first wrapped in an object-dtype NumPy
        array, because Code types do not support element-wise comparison.
        """
        rhs = np.array(other, dtype=object) if isinstance(other, Code) else other
        return super().__eq__(rhs)

    def __arrow_array__(self, type=None):
        """Convert the data to a pyarrow array typed as ``CodeType``.

        The ``type`` argument is accepted but not consulted.
        """
        # Imported locally: pymongoarrow.types itself imports from this module.
        from pymongoarrow.types import CodeType

        arrow_type = CodeType()
        return pa.array(self.data, type=arrow_type)

bindings/python/pymongoarrow/types.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import numpy as np
1818
import pyarrow as pa
1919
import pyarrow.types as _atypes
20-
from bson import Binary, Decimal128, Int64, ObjectId
20+
from bson import Binary, Code, Decimal128, Int64, ObjectId
2121
from pyarrow import DataType as _ArrowDataType
2222
from pyarrow import (
2323
ExtensionScalar,
@@ -31,7 +31,12 @@
3131
struct,
3232
timestamp,
3333
)
34-
from pymongoarrow.pandas_types import PandasBinary, PandasDecimal128, PandasObjectId
34+
from pymongoarrow.pandas_types import (
35+
PandasBinary,
36+
PandasCode,
37+
PandasDecimal128,
38+
PandasObjectId,
39+
)
3540

3641

3742
class _BsonArrowTypes(enum.Enum):
@@ -46,18 +51,23 @@ class _BsonArrowTypes(enum.Enum):
4651
document = 9
4752
array = 10
4853
binary = 11
54+
code = 12
4955

5056

5157
# Custom Extension Types.
5258
# See https://arrow.apache.org/docs/python/extending_types.html#defining-extension-types-user-defined-types
5359
# for details.
5460

5561

56-
class ObjectIdScalar(ExtensionScalar):
62+
class BSONExtensionScalar(ExtensionScalar):
    """Base extension scalar that converts its storage value to a BSON object.

    Subclasses set ``_bson_class`` to the class used to wrap the Python
    value of the underlying storage scalar.
    """

    def as_py(self):
        """Return ``None`` for a missing value, else the value wrapped in ``_bson_class``."""
        storage = self.value
        if storage is None:
            return None
        return self._bson_class(storage.as_py())
67+
68+
69+
class ObjectIdScalar(BSONExtensionScalar):
    """Extension scalar whose Python value is wrapped in ``ObjectId``."""

    # Class applied to the storage value by BSONExtensionScalar.as_py.
    _bson_class = ObjectId
6171

6272

6373
class ObjectIdType(PyExtensionType):
@@ -128,6 +138,26 @@ def to_pandas_dtype(self):
128138
return PandasBinary(self.subtype)
129139

130140

141+
class CodeScalar(BSONExtensionScalar):
    """Extension scalar whose Python value is wrapped in ``Code``."""

    # Class applied to the storage value by BSONExtensionScalar.as_py.
    _bson_class = Code
143+
144+
145+
class CodeType(PyExtensionType):
    """Arrow extension type for BSON ``Code``, built on ``string()`` storage."""

    # Marker compared by _is_code to recognise this type.
    _type_marker = _BsonArrowTypes.code

    def __init__(self):
        """Initialise the extension type with ``string()`` as its storage type."""
        storage_type = string()
        super().__init__(storage_type)

    def __reduce__(self):
        """Support pickling: the type is rebuilt by calling ``CodeType()`` with no arguments."""
        return (CodeType, ())

    def __arrow_ext_scalar_class__(self):
        """Return the scalar class used for elements of this type."""
        return CodeScalar

    def to_pandas_dtype(self):
        """Return the matching pandas extension dtype, ``PandasCode()``."""
        return PandasCode()
159+
160+
131161
# Internal Type Handling.
132162

133163

@@ -146,6 +176,11 @@ def _is_binary(obj):
146176
return type_marker == BinaryType._type_marker
147177

148178

179+
def _is_code(obj):
    """Return True when ``obj`` has the same ``_type_marker`` as ``CodeType``.

    Objects without a ``_type_marker`` attribute fall back to ``""``.
    """
    return getattr(obj, "_type_marker", "") == CodeType._type_marker
182+
183+
149184
_TYPE_NORMALIZER_FACTORY = {
150185
Int64: lambda _: int64(),
151186
float: lambda _: float64(),
@@ -159,6 +194,7 @@ def _is_binary(obj):
159194
str: lambda _: string(),
160195
bool: lambda _: bool_(),
161196
Binary: lambda subtype: BinaryType(subtype),
197+
Code: lambda _: CodeType(),
162198
}
163199

164200

@@ -188,6 +224,7 @@ def get_numpy_type(type):
188224
_is_objectid: _BsonArrowTypes.objectid,
189225
_is_decimal128: _BsonArrowTypes.decimal128,
190226
_is_binary: _BsonArrowTypes.binary,
227+
_is_code: _BsonArrowTypes.code,
191228
_atypes.is_string: _BsonArrowTypes.string,
192229
_atypes.is_boolean: _BsonArrowTypes.bool,
193230
_atypes.is_struct: _BsonArrowTypes.document,
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# Copyright 2023-present MongoDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import numpy as np
15+
import pytest
16+
from bson import Code
17+
from pandas.tests.extension import base
18+
from pymongoarrow.pandas_types import PandasCode, PandasCodeArray
19+
20+
# Skip the whole module when the installed pandas extension test suite does
# not provide BaseIndexTests.
if not hasattr(base, "BaseIndexTests"):
    pytest.skip("Not available", allow_module_level=True)
24+
25+
26+
def make_datum():
    """Build a ``Code`` whose text is a random float rendered as a string."""
    random_text = str(np.random.rand())
    return Code(random_text)
28+
29+
30+
@pytest.fixture
def dtype():
    """Fixture: a fresh ``PandasCode`` dtype instance."""
    code_dtype = PandasCode()
    return code_dtype
33+
34+
35+
def make_data():
    """Return a list of 100 items: ``Code`` values with NaN at positions 8 and 97."""
    values = [make_datum() for _ in range(8)]
    values.append(np.nan)
    values.extend(make_datum() for _ in range(88))
    values.append(np.nan)
    values.extend([make_datum(), make_datum()])
    return values
43+
44+
45+
@pytest.fixture
def data(dtype):
    """Fixture: a ``PandasCodeArray`` built from ``make_data()``."""
    raw_values = np.array(make_data(), dtype=object)
    return PandasCodeArray(raw_values, dtype=dtype)
48+
49+
50+
@pytest.fixture
def data_missing(dtype):
    """Fixture: a two-element array, a missing value followed by a valid one."""
    raw_values = np.array([np.nan, make_datum()], dtype=object)
    return PandasCodeArray(raw_values, dtype=dtype)
53+
54+
55+
@pytest.fixture
def data_for_sorting(dtype):
    """Fixture: a three-element array of randomly generated ``Code`` values."""
    items = [make_datum(), make_datum(), make_datum()]
    raw_values = np.array(items, dtype=object)
    return PandasCodeArray(raw_values, dtype=dtype)
60+
61+
62+
@pytest.fixture
def data_missing_for_sorting(dtype):
    """Fixture: a three-element array with a missing value in the middle."""
    items = [make_datum(), np.nan, make_datum()]
    raw_values = np.array(items, dtype=object)
    return PandasCodeArray(raw_values, dtype=dtype)
67+
68+
69+
class TestDtype(base.BaseDtypeTests):
    """Dtype tests from the pandas extension base suite, run against ``PandasCode``."""

    def test_is_not_string_type(self, data):
        """Run the base test without returning its value, which raises a warning."""
        super().test_is_not_string_type(data)

    def test_is_not_object_type(self, data):
        """Run the base test without returning its value, which raises a warning."""
        super().test_is_not_object_type(data)
77+
78+
79+
class TestInterface(base.BaseInterfaceTests):
    """Interface tests from the pandas extension base suite."""

    def test_contains(self):
        """Disabled override: we cannot compare a Code object to an array."""
        return
83+
84+
85+
class TestConstructors(base.BaseConstructorsTests):
    """Constructor tests inherited unchanged from the pandas extension base suite."""
87+
88+
89+
class TestGetitem(base.BaseGetitemTests):
    """Getitem tests inherited unchanged from the pandas extension base suite."""
91+
92+
93+
class TestSetitem(base.BaseSetitemTests):
    """Setitem tests from the pandas extension base suite."""

    def test_setitem_frame_2d_values(self):
        """Disabled override.

        The base test results in passing an integer as a value, which
        cannot be converted to Code type.
        """
        return
98+
99+
100+
class TestIndex(base.BaseIndexTests):
    """Index tests inherited unchanged from the pandas extension base suite."""
102+
103+
104+
class TestMissing(base.BaseMissingTests):
    """Missing-value tests inherited unchanged from the pandas extension base suite."""

0 commit comments

Comments
 (0)