Skip to content

Commit 29332c3

Browse files
committed
Add json_util.load for extended JSON
1 parent 72e1c2e commit 29332c3

File tree

2 files changed

+25
-14
lines changed

2 files changed

+25
-14
lines changed

bson/json_util.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,29 @@ def loads(s: Union[str, bytes, bytearray], *args: Any, **kwargs: Any) -> Any:
507507
return json.loads(s, *args, **kwargs)
508508

509509

510+
def load(fp: Any, *args: Any, **kwargs: Any) -> Any:
511+
"""Helper function that wraps :func:`json.load`.
512+
513+
Automatically passes the object_hook for BSON type conversion.
514+
515+
Raises ``TypeError``, ``ValueError``, ``KeyError``, or
516+
:exc:`~bson.errors.InvalidId` on invalid MongoDB Extended JSON.
517+
518+
:param json_options: A :class:`JSONOptions` instance used to modify the
519+
decoding of MongoDB Extended JSON types. Defaults to
520+
:const:`DEFAULT_JSON_OPTIONS`.
521+
522+
.. versionadded:: 4.12
523+
"""
524+
json_options = kwargs.pop("json_options", DEFAULT_JSON_OPTIONS)
525+
# Execution time optimization if json_options.document_class is dict
526+
if json_options.document_class is dict:
527+
kwargs["object_hook"] = lambda obj: object_hook(obj, json_options)
528+
else:
529+
kwargs["object_pairs_hook"] = lambda pairs: object_pairs_hook(pairs, json_options)
530+
return json.load(fp, *args, **kwargs)
531+
532+
510533
def _json_convert(obj: Any, json_options: JSONOptions = DEFAULT_JSON_OPTIONS) -> Any:
511534
"""Recursive helper method that converts BSON types so they can be
512535
converted into json.

test/test_bson_binary_vector.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,21 +21,12 @@
2121
from pathlib import Path
2222
from test import unittest
2323

24-
from bson import decode, encode
24+
from bson import decode, encode, json_util
2525
from bson.binary import Binary, BinaryVectorDtype
2626

2727
_TEST_PATH = Path(__file__).parent / "bson_binary_vector"
2828

2929

30-
def convert_extended_json(vector) -> float:
31-
if isinstance(vector, dict) and "$numberDouble" in vector:
32-
if vector["$numberDouble"] == "Infinity":
33-
return float("inf")
34-
elif vector["$numberDouble"] == "-Infinity":
35-
return float("-inf")
36-
return float(vector)
37-
38-
3930
class TestBSONBinaryVector(unittest.TestCase):
4031
"""Runs Binary Vector subtype tests.
4132
@@ -71,9 +62,6 @@ def run_test(self):
7162
cB_exp = binascii.unhexlify(canonical_bson_exp.encode("utf8"))
7263
decoded_doc = decode(cB_exp)
7364
binary_obs = decoded_doc[test_key]
74-
# Handle special extended JSON cases like 'Infinity'
75-
if dtype_exp in [BinaryVectorDtype.FLOAT32]:
76-
vector_exp = [convert_extended_json(x) for x in vector_exp]
7765

7866
# Test round-tripping canonical bson.
7967
self.assertEqual(encode(decoded_doc), cB_exp, description)
@@ -113,7 +101,7 @@ def run_test(self):
113101
def create_tests():
114102
for filename in _TEST_PATH.glob("*.json"):
115103
with codecs.open(str(filename), encoding="utf-8") as test_file:
116-
test_method = create_test(json.load(test_file))
104+
test_method = create_test(json_util.load(test_file))
117105
setattr(TestBSONBinaryVector, "test_" + filename.stem, test_method)
118106

119107

0 commit comments

Comments
 (0)