diff --git a/bson/__init__.py b/bson/__init__.py
index b655e30c2c..ef220d1ace 100644
--- a/bson/__init__.py
+++ b/bson/__init__.py
@@ -72,6 +72,7 @@
 from __future__ import annotations
 
 import datetime
+import decimal
 import itertools
 import os
 import re
@@ -858,6 +859,15 @@ def _encode_decimal128(name: bytes, value: Decimal128, dummy0: Any, dummy1: Any)
     return b"\x13" + name + value.bid
 
 
+def _encode_python_decimal(
+    name: bytes, value: decimal.Decimal, dummy0: Any, opts: CodecOptions[Any]
+) -> bytes:
+    """Encode decimal.Decimal as Decimal128 when convert_decimal is enabled."""
+    if not opts.convert_decimal:
+        raise InvalidDocument("decimal.Decimal must be converted to bson.decimal128.Decimal128.")
+    return b"\x13" + name + Decimal128(value).bid
+
+
 def _encode_minkey(name: bytes, dummy0: Any, dummy1: Any, dummy2: Any) -> bytes:
     """Encode bson.min_key.MinKey."""
     return b"\xFF" + name
@@ -937,6 +947,11 @@ def _name_value_to_bson(
         # Give the fallback_encoder a chance
         was_integer_overflow = True
 
+    # isinstance (not an exact type check) so Decimal subclasses are handled
+    # the same way as in the C encoder, which uses PyObject_IsInstance.
+    if opts.convert_decimal and isinstance(value, decimal.Decimal):
+        return _encode_python_decimal(name, value, check_keys, opts)
+
     # Second, fall back to trying _type_marker. This has to be done
     # before the loop below since users could subclass one of our
     # custom types that subclasses a python built-in (e.g. Binary)
diff --git a/bson/_cbsonmodule.c b/bson/_cbsonmodule.c
index be91e41734..d7befee6b4 100644
--- a/bson/_cbsonmodule.c
+++ b/bson/_cbsonmodule.c
@@ -159,6 +159,28 @@ extern int cbson_long_long_to_str(long long num, char* str, size_t size) {
     return 0;
 }
 
+/* Return 1 if value is an instance of decimal.Decimal, 0 if it is not, and
+ * -1 with an exception set if the check itself fails. The Decimal class is
+ * imported on first use and cached for later calls. */
+int _check_decimal(PyObject *value) {
+    static PyObject *decimal_class = NULL;
+
+    if (decimal_class == NULL) {
+        PyObject *decimal_module = PyImport_ImportModule("decimal");
+        if (decimal_module == NULL) {
+            /* Keep the original import error rather than masking it. */
+            return -1;
+        }
+        decimal_class = PyObject_GetAttrString(decimal_module, "Decimal");
+        /* decimal_class is its own new reference; drop the module one. */
+        Py_DECREF(decimal_module);
+        if (decimal_class == NULL) {
+            return -1;
+        }
+    }
+    return PyObject_IsInstance(value, decimal_class);
+}
+
 static PyObject* _test_long_long_to_str(PyObject* self, PyObject* args) {
     // Test extreme values
     Py_ssize_t maxNum = PY_SSIZE_T_MAX;
@@ -791,14 +813,15 @@ int convert_codec_options(PyObject* self, PyObject* options_obj, codec_options_t
 
     options->unicode_decode_error_handler = NULL;
 
-    if (!PyArg_ParseTuple(options_obj, "ObbzOOb",
+    if (!PyArg_ParseTuple(options_obj, "ObbzOObb",
                           &options->document_class,
                           &options->tz_aware,
                           &options->uuid_rep,
                           &options->unicode_decode_error_handler,
                           &options->tzinfo,
                           &type_registry_obj,
-                          &options->datetime_conversion)) {
+                          &options->datetime_conversion,
+                          &options->convert_decimal)) {
         return 0;
     }
 
@@ -993,6 +1016,29 @@ static int _write_regex_to_buffer(
     return 1;
 }
 
+/* Write the 16-byte payload of a Decimal128 (the attribute looked up through
+ * state->_bid_str) to the buffer and set the reserved type byte to 0x13.
+ * Returns 1 on success and 0 on failure. */
+static int _write_decimal_128_to_buffer(struct module_state *state, PyObject* value, buffer_t buffer, int type_byte) {
+    const char* data;
+    PyObject* pystring = PyObject_GetAttr(value, state->_bid_str);
+    if (!pystring) {
+        return 0;
+    }
+    data = PyBytes_AsString(pystring);
+    if (!data) {
+        Py_DECREF(pystring);
+        return 0;
+    }
+    if (!buffer_write_bytes(buffer, data, 16)) {
+        Py_DECREF(pystring);
+        return 0;
+    }
+    Py_DECREF(pystring);
+    *(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x13;
+    return 1;
+}
+
 /* Write a single value to the buffer (also write its type_byte, for which
  * space has already been reserved.
  *
@@ -1206,23 +1252,7 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer,
     case 19:
         {
             /* Decimal128 */
-            const char* data;
-            PyObject* pystring = PyObject_GetAttr(value, state->_bid_str);
-            if (!pystring) {
-                return 0;
-            }
-            data = PyBytes_AsString(pystring);
-            if (!data) {
-                Py_DECREF(pystring);
-                return 0;
-            }
-            if (!buffer_write_bytes(buffer, data, 16)) {
-                Py_DECREF(pystring);
-                return 0;
-            }
-            Py_DECREF(pystring);
-            *(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x13;
-            return 1;
+            return _write_decimal_128_to_buffer(state, value, buffer, type_byte);
         }
     case 100:
         {
@@ -1436,6 +1466,22 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer,
                                           in_fallback_call);
         Py_DECREF(binary_value);
         return result;
+    } else if (options->convert_decimal && _check_decimal(value) != 0) {
+        /* Convert decimal.Decimal to Decimal128 */
+        PyObject* converted;
+        int status;
+
+        /* _check_decimal returns -1 with an exception set on failure. */
+        if (PyErr_Occurred()) {
+            return 0;
+        }
+        converted = PyObject_CallFunctionObjArgs(state->Decimal128, value, NULL);
+        if (!converted) {
+            return 0;
+        }
+        status = _write_decimal_128_to_buffer(state, converted, buffer, type_byte);
+        Py_DECREF(converted);
+        return status;
     }
 
     /* Try a custom encoder if one is provided and we have not already
diff --git a/bson/_cbsonmodule.h b/bson/_cbsonmodule.h
index 3be2b74427..5a333a2b8f 100644
--- a/bson/_cbsonmodule.h
+++ b/bson/_cbsonmodule.h
@@ -70,6 +70,7 @@ typedef struct codec_options_t {
     PyObject* tzinfo;
     type_registry_t type_registry;
     unsigned char datetime_conversion;
+    unsigned char convert_decimal;
     PyObject* options_obj;
     unsigned char is_raw_bson;
 } codec_options_t;
diff --git a/bson/codec_options.py b/bson/codec_options.py
index add5416a5b..a9c42a200d 100644
--- a/bson/codec_options.py
+++ b/bson/codec_options.py
@@ -241,6 +241,7 @@ class _BaseCodecOptions(NamedTuple):
     tzinfo: Optional[datetime.tzinfo]
     type_registry: TypeRegistry
     datetime_conversion: Optional[DatetimeConversion]
+    convert_decimal: Optional[bool]
 
 
 if TYPE_CHECKING:
@@ -253,6 +254,7 @@ class CodecOptions(Tuple[_DocumentType], Generic[_DocumentType]):
         tzinfo: Optional[datetime.tzinfo]
         type_registry: TypeRegistry
         datetime_conversion: Optional[int]
+        convert_decimal: Optional[bool]
 
         def __new__(
             cls: Type[CodecOptions[_DocumentType]],
@@ -263,6 +265,7 @@ def __new__(
             tzinfo: Optional[datetime.tzinfo] = ...,
             type_registry: Optional[TypeRegistry] = ...,
             datetime_conversion: Optional[int] = ...,
+            convert_decimal: Optional[bool] = ...,
         ) -> CodecOptions[_DocumentType]:
             ...
 
@@ -362,6 +365,12 @@ def __init__(self, *args, **kwargs):
                 return DatetimeMS objects when the underlying datetime is
                 out-of-range and 'datetime_clamp' to clamp to the minimum and
                 maximum possible datetimes. Defaults to 'datetime'.
+            :param convert_decimal: If ``True``, instances of :class:`~decimal.Decimal` will
+                be automatically converted to :class:`~bson.decimal128.Decimal128` when encoding to BSON.
+                Defaults to ``False``.
+
+            .. versionadded:: 4.15
+                `convert_decimal` attribute.
 
             .. versionchanged:: 4.0
                The default for `uuid_representation` was changed from
@@ -388,6 +397,7 @@ def __new__(
             tzinfo: Optional[datetime.tzinfo] = None,
             type_registry: Optional[TypeRegistry] = None,
             datetime_conversion: Optional[DatetimeConversion] = DatetimeConversion.DATETIME,
+            convert_decimal: Optional[bool] = False,
         ) -> CodecOptions:
             doc_class = document_class or dict
             # issubclass can raise TypeError for generic aliases like SON[str, Any].
@@ -439,6 +449,7 @@ def __new__(
                     tzinfo,
                     type_registry,
                     datetime_conversion,
+                    bool(convert_decimal),  # the C parser's "b" unit rejects None
                 ),
             )
 
@@ -455,7 +466,7 @@ def _arguments_repr(self) -> str:
             return (
                 "document_class={}, tz_aware={!r}, uuid_representation={}, "
                 "unicode_decode_error_handler={!r}, tzinfo={!r}, "
-                "type_registry={!r}, datetime_conversion={!s}".format(
+                "type_registry={!r}, datetime_conversion={!s}, convert_decimal={!s}".format(
                     document_class_repr,
                     self.tz_aware,
                     uuid_rep_repr,
@@ -463,6 +474,7 @@ def _arguments_repr(self) -> str:
                     self.tzinfo,
                     self.type_registry,
                     self.datetime_conversion,
+                    self.convert_decimal,
                 )
             )
 
@@ -477,6 +489,7 @@ def _options_dict(self) -> dict[str, Any]:
                 "tzinfo": self.tzinfo,
                 "type_registry": self.type_registry,
                 "datetime_conversion": self.datetime_conversion,
+                "convert_decimal": self.convert_decimal,
             }
 
         def __repr__(self) -> str:
@@ -513,6 +526,7 @@ def _parse_codec_options(options: Any) -> CodecOptions[Any]:
         "tzinfo",
         "type_registry",
         "datetime_conversion",
+        "convert_decimal",
     }:
         if k == "uuidrepresentation":
             kwargs["uuid_representation"] = options[k]
diff --git a/test/test_bson.py b/test/test_bson.py
index e4cf85c46c..fd25a9351d 100644
--- a/test/test_bson.py
+++ b/test/test_bson.py
@@ -19,6 +19,7 @@
 import array
 import collections
 import datetime
+import decimal
 import mmap
 import os
 import pickle
@@ -1236,7 +1237,7 @@ def test_codec_options_repr(self):
             "unicode_decode_error_handler='strict', "
             "tzinfo=None, type_registry=TypeRegistry(type_codecs=[], "
             "fallback_encoder=None), "
-            "datetime_conversion=DatetimeConversion.DATETIME)"
+            "datetime_conversion=DatetimeConversion.DATETIME, convert_decimal=False)"
         )
         self.assertEqual(r, repr(CodecOptions()))
 
@@ -1406,6 +1407,18 @@ def test_bson_encode_decode(self) -> None:
         decoded["new_field"] = 1
         self.assertTrue(decoded["_id"].generation_time)
 
+    def test_convert_decimal(self):
+        opts = CodecOptions(convert_decimal=True)
+        decimal128_doc = {"d": bson.Decimal128("1.0")}
+        decimal_doc = {"d": decimal.Decimal("1.0")}
+        decimal_128_encoded = bson.encode(decimal128_doc, codec_options=opts)
+        decimal_encoded = bson.encode(decimal_doc, codec_options=opts)
+
+        self.assertEqual(decimal_128_encoded, decimal_encoded)
+        # Without the option, decimal.Decimal is still rejected.
+        with self.assertRaises(InvalidDocument):
+            bson.encode(decimal_doc)
+
 
 class TestDatetimeConversion(unittest.TestCase):
     def test_comps(self):