Skip to content

Commit b24288e

Browse files
committed
PYTHON-5395 - Add convert_decimal to CodecOptions
1 parent bbb6f88 commit b24288e

File tree

5 files changed

+90
-20
lines changed

5 files changed

+90
-20
lines changed

bson/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
from __future__ import annotations
7373

7474
import datetime
75+
import decimal
7576
import itertools
7677
import os
7778
import re
@@ -858,6 +859,16 @@ def _encode_decimal128(name: bytes, value: Decimal128, dummy0: Any, dummy1: Any)
858859
return b"\x13" + name + value.bid
859860

860861

862+
def _encode_python_decimal(
863+
name: bytes, value: decimal.Decimal, dummy0: Any, opts: CodecOptions[Any]
864+
) -> bytes:
865+
if opts.convert_decimal:
866+
converted = Decimal128(value)
867+
return b"\x13" + name + converted.bid
868+
else:
869+
raise InvalidDocument("decimal.Decimal must be converted to bson.decimal128.Decimal128.")
870+
871+
861872
def _encode_minkey(name: bytes, dummy0: Any, dummy1: Any, dummy2: Any) -> bytes:
862873
"""Encode bson.min_key.MinKey."""
863874
return b"\xFF" + name
@@ -885,6 +896,7 @@ def _encode_maxkey(name: bytes, dummy0: Any, dummy1: Any, dummy2: Any) -> bytes:
885896
str: _encode_text,
886897
tuple: _encode_list,
887898
type(None): _encode_none,
899+
decimal.Decimal: _encode_python_decimal,
888900
uuid.UUID: _encode_uuid,
889901
Binary: _encode_binary,
890902
Int64: _encode_long,

bson/_cbsonmodule.c

Lines changed: 55 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,27 @@ extern int cbson_long_long_to_str(long long num, char* str, size_t size) {
159159
return 0;
160160
}
161161

162+
int _check_decimal(PyObject *value) {
163+
static PyObject *decimal_module = NULL;
164+
static PyObject *decimal_class = NULL;
165+
166+
if (decimal_module == NULL) {
167+
decimal_module = PyImport_ImportModule("decimal");
168+
if (decimal_module == NULL) {
169+
PyErr_SetString(PyExc_ImportError, "Failed to import decimal module");
170+
return -1;
171+
}
172+
decimal_class = PyObject_GetAttrString(decimal_module, "Decimal");
173+
if (decimal_class == NULL) {
174+
Py_DECREF(decimal_module);
175+
decimal_module = NULL;
176+
PyErr_SetString(PyExc_AttributeError, "Failed to get Decimal class");
177+
return -1;
178+
}
179+
}
180+
return PyObject_IsInstance(value, decimal_class);
181+
}
182+
162183
static PyObject* _test_long_long_to_str(PyObject* self, PyObject* args) {
163184
// Test extreme values
164185
Py_ssize_t maxNum = PY_SSIZE_T_MAX;
@@ -791,14 +812,15 @@ int convert_codec_options(PyObject* self, PyObject* options_obj, codec_options_t
791812

792813
options->unicode_decode_error_handler = NULL;
793814

794-
if (!PyArg_ParseTuple(options_obj, "ObbzOOb",
815+
if (!PyArg_ParseTuple(options_obj, "ObbzOObb",
795816
&options->document_class,
796817
&options->tz_aware,
797818
&options->uuid_rep,
798819
&options->unicode_decode_error_handler,
799820
&options->tzinfo,
800821
&type_registry_obj,
801-
&options->datetime_conversion)) {
822+
&options->datetime_conversion,
823+
&options->convert_decimal)) {
802824
return 0;
803825
}
804826

@@ -993,6 +1015,26 @@ static int _write_regex_to_buffer(
9931015
return 1;
9941016
}
9951017

1018+
static int _write_decimal_128_to_buffer(struct module_state *state, PyObject* value, buffer_t buffer, int type_byte) {
1019+
const char* data;
1020+
PyObject* pystring = PyObject_GetAttr(value, state->_bid_str);
1021+
if (!pystring) {
1022+
return 0;
1023+
}
1024+
data = PyBytes_AsString(pystring);
1025+
if (!data) {
1026+
Py_DECREF(pystring);
1027+
return 0;
1028+
}
1029+
if (!buffer_write_bytes(buffer, data, 16)) {
1030+
Py_DECREF(pystring);
1031+
return 0;
1032+
}
1033+
Py_DECREF(pystring);
1034+
*(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x13;
1035+
return 1;
1036+
}
1037+
9961038
/* Write a single value to the buffer (also write its type_byte, for which
9971039
* space has already been reserved.
9981040
*
@@ -1206,23 +1248,7 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer,
12061248
case 19:
12071249
{
12081250
/* Decimal128 */
1209-
const char* data;
1210-
PyObject* pystring = PyObject_GetAttr(value, state->_bid_str);
1211-
if (!pystring) {
1212-
return 0;
1213-
}
1214-
data = PyBytes_AsString(pystring);
1215-
if (!data) {
1216-
Py_DECREF(pystring);
1217-
return 0;
1218-
}
1219-
if (!buffer_write_bytes(buffer, data, 16)) {
1220-
Py_DECREF(pystring);
1221-
return 0;
1222-
}
1223-
Py_DECREF(pystring);
1224-
*(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x13;
1225-
return 1;
1251+
return _write_decimal_128_to_buffer(state, value, buffer, type_byte);
12261252
}
12271253
case 100:
12281254
{
@@ -1436,6 +1462,16 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer,
14361462
in_fallback_call);
14371463
Py_DECREF(binary_value);
14381464
return result;
1465+
} else if (options->convert_decimal && _check_decimal(value)) {
1466+
/* Convert decimal.Decimal to Decimal128 */
1467+
PyObject* args = PyTuple_New(1);
1468+
1469+
Py_INCREF(value);
1470+
PyTuple_SetItem(args, 0, value);
1471+
PyObject* converted = PyObject_CallObject(state->Decimal128, args);
1472+
Py_DECREF(args);
1473+
1474+
return _write_decimal_128_to_buffer(state, converted, buffer, type_byte);
14391475
}
14401476

14411477
/* Try a custom encoder if one is provided and we have not already

bson/_cbsonmodule.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ typedef struct codec_options_t {
7070
PyObject* tzinfo;
7171
type_registry_t type_registry;
7272
unsigned char datetime_conversion;
73+
unsigned char convert_decimal;
7374
PyObject* options_obj;
7475
unsigned char is_raw_bson;
7576
} codec_options_t;

bson/codec_options.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,7 @@ class _BaseCodecOptions(NamedTuple):
241241
tzinfo: Optional[datetime.tzinfo]
242242
type_registry: TypeRegistry
243243
datetime_conversion: Optional[DatetimeConversion]
244+
convert_decimal: Optional[bool]
244245

245246

246247
if TYPE_CHECKING:
@@ -253,6 +254,7 @@ class CodecOptions(Tuple[_DocumentType], Generic[_DocumentType]):
253254
tzinfo: Optional[datetime.tzinfo]
254255
type_registry: TypeRegistry
255256
datetime_conversion: Optional[int]
257+
convert_decimal: Optional[bool]
256258

257259
def __new__(
258260
cls: Type[CodecOptions[_DocumentType]],
@@ -263,6 +265,7 @@ def __new__(
263265
tzinfo: Optional[datetime.tzinfo] = ...,
264266
type_registry: Optional[TypeRegistry] = ...,
265267
datetime_conversion: Optional[int] = ...,
268+
convert_decimal: Optional[bool] = ...,
266269
) -> CodecOptions[_DocumentType]:
267270
...
268271

@@ -362,6 +365,9 @@ def __init__(self, *args, **kwargs):
362365
return DatetimeMS objects when the underlying datetime is
363366
out-of-range and 'datetime_clamp' to clamp to the minimum and
364367
maximum possible datetimes. Defaults to 'datetime'.
368+
:param convert_decimal: If ``True``, instances of :class:`~decimal.Decimal` will
369+
be automatically converted to :class:`~bson.decimal128.Decimal128` when encoding to BSON.
370+
Defaults to ``False``.
365371
366372
.. versionchanged:: 4.0
367373
The default for `uuid_representation` was changed from
@@ -388,6 +394,7 @@ def __new__(
388394
tzinfo: Optional[datetime.tzinfo] = None,
389395
type_registry: Optional[TypeRegistry] = None,
390396
datetime_conversion: Optional[DatetimeConversion] = DatetimeConversion.DATETIME,
397+
convert_decimal: Optional[bool] = False,
391398
) -> CodecOptions:
392399
doc_class = document_class or dict
393400
# issubclass can raise TypeError for generic aliases like SON[str, Any].
@@ -439,6 +446,7 @@ def __new__(
439446
tzinfo,
440447
type_registry,
441448
datetime_conversion,
449+
convert_decimal,
442450
),
443451
)
444452

@@ -455,14 +463,15 @@ def _arguments_repr(self) -> str:
455463
return (
456464
"document_class={}, tz_aware={!r}, uuid_representation={}, "
457465
"unicode_decode_error_handler={!r}, tzinfo={!r}, "
458-
"type_registry={!r}, datetime_conversion={!s}".format(
466+
"type_registry={!r}, datetime_conversion={!s}, convert_decimal={!s}".format(
459467
document_class_repr,
460468
self.tz_aware,
461469
uuid_rep_repr,
462470
self.unicode_decode_error_handler,
463471
self.tzinfo,
464472
self.type_registry,
465473
self.datetime_conversion,
474+
self.convert_decimal,
466475
)
467476
)
468477

@@ -477,6 +486,7 @@ def _options_dict(self) -> dict[str, Any]:
477486
"tzinfo": self.tzinfo,
478487
"type_registry": self.type_registry,
479488
"datetime_conversion": self.datetime_conversion,
489+
"convert_decimal": self.convert_decimal,
480490
}
481491

482492
def __repr__(self) -> str:
@@ -513,6 +523,7 @@ def _parse_codec_options(options: Any) -> CodecOptions[Any]:
513523
"tzinfo",
514524
"type_registry",
515525
"datetime_conversion",
526+
"convert_decimal",
516527
}:
517528
if k == "uuidrepresentation":
518529
kwargs["uuid_representation"] = options[k]

test/test_bson.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import array
2020
import collections
2121
import datetime
22+
import decimal
2223
import mmap
2324
import os
2425
import pickle
@@ -1406,6 +1407,15 @@ def test_bson_encode_decode(self) -> None:
14061407
decoded["new_field"] = 1
14071408
self.assertTrue(decoded["_id"].generation_time)
14081409

1410+
def test_convert_decimal(self):
1411+
opts = CodecOptions(convert_decimal=True)
1412+
decimal128_doc = {"d": bson.Decimal128("1.0")}
1413+
decimal_doc = {"d": decimal.Decimal("1.0")}
1414+
decimal_128_encoded = bson.encode(decimal128_doc, codec_options=opts)
1415+
decimal_encoded = bson.encode(decimal_doc, codec_options=opts)
1416+
1417+
self.assertEqual(decimal_128_encoded, decimal_encoded)
1418+
14091419

14101420
class TestDatetimeConversion(unittest.TestCase):
14111421
def test_comps(self):

0 commit comments

Comments
 (0)