Skip to content

Commit caf1cc5

Browse files
authored
INTPYTHON-538 Add support for PyArrow Decimal128 type (#278)
1 parent 9bbbed7 commit caf1cc5

File tree

3 files changed

+28
-2
lines changed

3 files changed

+28
-2
lines changed

bindings/python/pymongoarrow/api.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import warnings
15+
from decimal import Decimal
1516

1617
import numpy as np
1718
import pandas as pd
@@ -25,6 +26,7 @@
2526
import pymongo.errors
2627
from bson import encode
2728
from bson.codec_options import TypeEncoder, TypeRegistry
29+
from bson.decimal128 import Decimal128
2830
from bson.raw_bson import RawBSONDocument
2931
from numpy import ndarray
3032
from pyarrow import Schema as ArrowSchema
@@ -416,6 +418,18 @@ def transform_python(self, _):
416418
return
417419

418420

421+
class _DecimalCodec(TypeEncoder):
422+
"""A custom type codec for Decimal objects."""
423+
424+
@property
425+
def python_type(self):
426+
return Decimal
427+
428+
def transform_python(self, value):
429+
"""Transform a Decimal object into a BSON Decimal128 object"""
430+
return Decimal128(value)
431+
432+
419433
def write(collection, tabular, *, exclude_none: bool = False):
420434
"""Write data from `tabular` into the given MongoDB `collection`.
421435
@@ -469,9 +483,9 @@ def write(collection, tabular, *, exclude_none: bool = False):
469483

470484
tabular_gen = _tabular_generator(tabular, exclude_none=exclude_none)
471485

472-
# Handle Pandas NA objects.
486+
# Add handling for special case types.
473487
codec_options = collection.codec_options
474-
type_registry = TypeRegistry([_PandasNACodec()])
488+
type_registry = TypeRegistry([_PandasNACodec(), _DecimalCodec()])
475489
codec_options = codec_options.with_options(type_registry=type_registry)
476490

477491
while cur_offset < tab_size:

bindings/python/pymongoarrow/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,7 @@ def get_numpy_type(type):
277277
_atypes.is_date64: _BsonArrowTypes.date64.value,
278278
_atypes.is_large_string: _BsonArrowTypes.string.value,
279279
_atypes.is_large_list: _BsonArrowTypes.array.value,
280+
_atypes.is_decimal128: _BsonArrowTypes.decimal128.value,
280281
}
281282

282283

bindings/python/test/test_arrow.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -966,6 +966,17 @@ def test_exclude_none(self):
966966
col_data = list(self.coll.find({}))
967967
assert "b" not in col_data[2]
968968

969+
def test_decimal128(self):
970+
import decimal
971+
972+
a = decimal.Decimal("123.45")
973+
arr = pa.array([a], pa.decimal128(5, 2))
974+
data = Table.from_arrays([arr], names=["data"])
975+
self.coll.drop()
976+
write(self.coll, data)
977+
coll_data = list(self.coll.find({}))
978+
assert coll_data[0]["data"] == Decimal128(a)
979+
969980

970981
class TestArrowExplicitApi(ArrowApiTestMixin, unittest.TestCase):
971982
def run_find(self, *args, **kwargs):

0 commit comments

Comments
 (0)