Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions bson/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
from __future__ import annotations

import datetime
import decimal
import itertools
import os
import re
Expand Down Expand Up @@ -858,6 +859,16 @@ def _encode_decimal128(name: bytes, value: Decimal128, dummy0: Any, dummy1: Any)
return b"\x13" + name + value.bid


def _encode_python_decimal(
name: bytes, value: decimal.Decimal, dummy0: Any, opts: CodecOptions[Any]
) -> bytes:
if opts.convert_decimal:
converted = Decimal128(value)
return b"\x13" + name + converted.bid
else:
raise InvalidDocument("decimal.Decimal must be converted to bson.decimal128.Decimal128.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest mentioning the convert_decimal option in this error message.



def _encode_minkey(name: bytes, dummy0: Any, dummy1: Any, dummy2: Any) -> bytes:
"""Encode bson.min_key.MinKey."""
return b"\xFF" + name
Expand Down Expand Up @@ -885,6 +896,7 @@ def _encode_maxkey(name: bytes, dummy0: Any, dummy1: Any, dummy2: Any) -> bytes:
str: _encode_text,
tuple: _encode_list,
type(None): _encode_none,
decimal.Decimal: _encode_python_decimal,
uuid.UUID: _encode_uuid,
Binary: _encode_binary,
Int64: _encode_long,
Expand All @@ -908,8 +920,8 @@ def _encode_maxkey(name: bytes, dummy0: Any, dummy1: Any, dummy2: Any) -> bytes:
if hasattr(_typ, "_type_marker"):
_MARKERS[_typ._type_marker] = _ENCODERS[_typ]


_BUILT_IN_TYPES = tuple(t for t in _ENCODERS)
# Exclude decimal.Decimal since auto-conversion is explicitly opt-in.
_BUILT_IN_TYPES = tuple(t for t in _ENCODERS if t != decimal.Decimal)


def _name_value_to_bson(
Expand Down
74 changes: 55 additions & 19 deletions bson/_cbsonmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,27 @@ extern int cbson_long_long_to_str(long long num, char* str, size_t size) {
return 0;
}

int _check_decimal(PyObject *value) {
static PyObject *decimal_module = NULL;
static PyObject *decimal_class = NULL;

if (decimal_module == NULL) {
decimal_module = PyImport_ImportModule("decimal");
if (decimal_module == NULL) {
PyErr_SetString(PyExc_ImportError, "Failed to import decimal module");
return -1;
}
decimal_class = PyObject_GetAttrString(decimal_module, "Decimal");
if (decimal_class == NULL) {
Py_DECREF(decimal_module);
decimal_module = NULL;
PyErr_SetString(PyExc_AttributeError, "Failed to get Decimal class");
return -1;
}
}
return PyObject_IsInstance(value, decimal_class);
}

static PyObject* _test_long_long_to_str(PyObject* self, PyObject* args) {
// Test extreme values
Py_ssize_t maxNum = PY_SSIZE_T_MAX;
Expand Down Expand Up @@ -791,14 +812,15 @@ int convert_codec_options(PyObject* self, PyObject* options_obj, codec_options_t

options->unicode_decode_error_handler = NULL;

if (!PyArg_ParseTuple(options_obj, "ObbzOOb",
if (!PyArg_ParseTuple(options_obj, "ObbzOObb",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I stared at this for a while but then found: https://docs.python.org/3/c-api/arg.html. So, just confirming that's a format string for the args below.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lol i was sitting on this one too, googled it, found that same site, was still confused / overwhelemed by words so then i asked mongo-gpt HAHA

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup this is a format string.

&options->document_class,
&options->tz_aware,
&options->uuid_rep,
&options->unicode_decode_error_handler,
&options->tzinfo,
&type_registry_obj,
&options->datetime_conversion)) {
&options->datetime_conversion,
&options->convert_decimal)) {
return 0;
}

Expand Down Expand Up @@ -993,6 +1015,26 @@ static int _write_regex_to_buffer(
return 1;
}

static int _write_decimal_128_to_buffer(struct module_state *state, PyObject* value, buffer_t buffer, int type_byte) {
const char* data;
PyObject* pystring = PyObject_GetAttr(value, state->_bid_str);
if (!pystring) {
return 0;
}
data = PyBytes_AsString(pystring);
if (!data) {
Py_DECREF(pystring);
return 0;
}
if (!buffer_write_bytes(buffer, data, 16)) {
Py_DECREF(pystring);
return 0;
}
Py_DECREF(pystring);
*(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x13;
return 1;
}

/* Write a single value to the buffer (also write its type_byte, for which
* space has already been reserved.
*
Expand Down Expand Up @@ -1206,23 +1248,7 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer,
case 19:
{
/* Decimal128 */
const char* data;
PyObject* pystring = PyObject_GetAttr(value, state->_bid_str);
if (!pystring) {
return 0;
}
data = PyBytes_AsString(pystring);
if (!data) {
Py_DECREF(pystring);
return 0;
}
if (!buffer_write_bytes(buffer, data, 16)) {
Py_DECREF(pystring);
return 0;
}
Py_DECREF(pystring);
*(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x13;
return 1;
return _write_decimal_128_to_buffer(state, value, buffer, type_byte);
}
case 100:
{
Expand Down Expand Up @@ -1436,6 +1462,16 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer,
in_fallback_call);
Py_DECREF(binary_value);
return result;
} else if (options->convert_decimal && _check_decimal(value)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So after case 100 or so and if case 19 wasn't the case, auto-convert ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For consistency with how the rest of _write_element_to_buffer works, we check the converison edge cases such as this after the native BSON case statement.

/* Convert decimal.Decimal to Decimal128 */
PyObject* args = PyTuple_New(1);

Py_INCREF(value);
PyTuple_SetItem(args, 0, value);
PyObject* converted = PyObject_CallObject(state->Decimal128, args);
Py_DECREF(args);

return _write_decimal_128_to_buffer(state, converted, buffer, type_byte);
}

/* Try a custom encoder if one is provided and we have not already
Expand Down
1 change: 1 addition & 0 deletions bson/_cbsonmodule.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ typedef struct codec_options_t {
PyObject* tzinfo;
type_registry_t type_registry;
unsigned char datetime_conversion;
unsigned char convert_decimal;
PyObject* options_obj;
unsigned char is_raw_bson;
} codec_options_t;
Expand Down
13 changes: 12 additions & 1 deletion bson/codec_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ class _BaseCodecOptions(NamedTuple):
tzinfo: Optional[datetime.tzinfo]
type_registry: TypeRegistry
datetime_conversion: Optional[DatetimeConversion]
convert_decimal: Optional[bool]


if TYPE_CHECKING:
Expand All @@ -253,6 +254,7 @@ class CodecOptions(Tuple[_DocumentType], Generic[_DocumentType]):
tzinfo: Optional[datetime.tzinfo]
type_registry: TypeRegistry
datetime_conversion: Optional[int]
convert_decimal: Optional[bool]

def __new__(
cls: Type[CodecOptions[_DocumentType]],
Expand All @@ -263,6 +265,7 @@ def __new__(
tzinfo: Optional[datetime.tzinfo] = ...,
type_registry: Optional[TypeRegistry] = ...,
datetime_conversion: Optional[int] = ...,
convert_decimal: Optional[bool] = ...,
) -> CodecOptions[_DocumentType]:
...

Expand Down Expand Up @@ -362,6 +365,9 @@ def __init__(self, *args, **kwargs):
return DatetimeMS objects when the underlying datetime is
out-of-range and 'datetime_clamp' to clamp to the minimum and
maximum possible datetimes. Defaults to 'datetime'.
:param convert_decimal: If ``True``, instances of :class:`~decimal.Decimal` will
be automatically converted to :class:`~bson.decimal128.Decimal128` when encoding to BSON.
Defaults to ``False``.

.. versionchanged:: 4.0
The default for `uuid_representation` was changed from
Expand All @@ -388,6 +394,7 @@ def __new__(
tzinfo: Optional[datetime.tzinfo] = None,
type_registry: Optional[TypeRegistry] = None,
datetime_conversion: Optional[DatetimeConversion] = DatetimeConversion.DATETIME,
convert_decimal: Optional[bool] = False,
) -> CodecOptions:
doc_class = document_class or dict
# issubclass can raise TypeError for generic aliases like SON[str, Any].
Expand Down Expand Up @@ -439,6 +446,7 @@ def __new__(
tzinfo,
type_registry,
datetime_conversion,
convert_decimal,
),
)

Expand All @@ -455,14 +463,15 @@ def _arguments_repr(self) -> str:
return (
"document_class={}, tz_aware={!r}, uuid_representation={}, "
"unicode_decode_error_handler={!r}, tzinfo={!r}, "
"type_registry={!r}, datetime_conversion={!s}".format(
"type_registry={!r}, datetime_conversion={!s}, convert_decimal={!s}".format(
document_class_repr,
self.tz_aware,
uuid_rep_repr,
self.unicode_decode_error_handler,
self.tzinfo,
self.type_registry,
self.datetime_conversion,
self.convert_decimal,
)
)

Expand All @@ -477,6 +486,7 @@ def _options_dict(self) -> dict[str, Any]:
"tzinfo": self.tzinfo,
"type_registry": self.type_registry,
"datetime_conversion": self.datetime_conversion,
"convert_decimal": self.convert_decimal,
}

def __repr__(self) -> str:
Expand Down Expand Up @@ -513,6 +523,7 @@ def _parse_codec_options(options: Any) -> CodecOptions[Any]:
"tzinfo",
"type_registry",
"datetime_conversion",
"convert_decimal",
}:
if k == "uuidrepresentation":
kwargs["uuid_representation"] = options[k]
Expand Down
12 changes: 11 additions & 1 deletion test/test_bson.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import array
import collections
import datetime
import decimal
import mmap
import os
import pickle
Expand Down Expand Up @@ -1236,7 +1237,7 @@ def test_codec_options_repr(self):
"unicode_decode_error_handler='strict', "
"tzinfo=None, type_registry=TypeRegistry(type_codecs=[], "
"fallback_encoder=None), "
"datetime_conversion=DatetimeConversion.DATETIME)"
"datetime_conversion=DatetimeConversion.DATETIME, convert_decimal=False)"
)
self.assertEqual(r, repr(CodecOptions()))

Expand Down Expand Up @@ -1406,6 +1407,15 @@ def test_bson_encode_decode(self) -> None:
decoded["new_field"] = 1
self.assertTrue(decoded["_id"].generation_time)

def test_convert_decimal(self):
opts = CodecOptions(convert_decimal=True)
decimal128_doc = {"d": bson.Decimal128("1.0")}
decimal_doc = {"d": decimal.Decimal("1.0")}
decimal_128_encoded = bson.encode(decimal128_doc, codec_options=opts)
decimal_encoded = bson.encode(decimal_doc, codec_options=opts)

self.assertEqual(decimal_128_encoded, decimal_encoded)


class TestDatetimeConversion(unittest.TestCase):
def test_comps(self):
Expand Down
Loading