Skip to content

Commit 599e2d7

Browse files
committed
PYTHON-1782 Allow MongoClient to be initialized with type_registry
1 parent cda0b71 commit 599e2d7

File tree

3 files changed

+27
-3
lines changed

3 files changed

+27
-3
lines changed

pymongo/common.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from bson import SON
2222
from bson.binary import (STANDARD, PYTHON_LEGACY,
2323
JAVA_LEGACY, CSHARP_LEGACY)
24-
from bson.codec_options import CodecOptions
24+
from bson.codec_options import CodecOptions, TypeRegistry
2525
from bson.py3compat import abc, integer_types, iteritems, string_type
2626
from bson.raw_bson import RawBSONDocument
2727
from pymongo.auth import MECHANISMS
@@ -423,6 +423,14 @@ def validate_document_class(option, value):
423423
return value
424424

425425

426+
def validate_type_registry(option, value):
427+
"""Validate the type_registry option."""
428+
if value is not None and not isinstance(value, TypeRegistry):
429+
raise TypeError("%s must be an instance of %s" % (
430+
option, TypeRegistry))
431+
return value
432+
433+
426434
def validate_list(option, value):
427435
"""Validates that 'value' is a list."""
428436
if not isinstance(value, list):
@@ -600,6 +608,7 @@ def validate_tzinfo(dummy, value):
600608
# values for those options.
601609
KW_VALIDATORS = {
602610
'document_class': validate_document_class,
611+
'type_registry': validate_type_registry,
603612
'read_preference': validate_read_preference,
604613
'event_listeners': _validate_event_listeners,
605614
'tzinfo': validate_tzinfo,

pymongo/mongo_client.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040
from collections import defaultdict
4141

42-
from bson.codec_options import DEFAULT_CODEC_OPTIONS
42+
from bson.codec_options import DEFAULT_CODEC_OPTIONS, TypeRegistry
4343
from bson.py3compat import (integer_types,
4444
string_type)
4545
from bson.son import SON
@@ -98,6 +98,7 @@ def __init__(
9898
host=None,
9999
port=None,
100100
document_class=dict,
101+
type_registry=None,
101102
tz_aware=None,
102103
connect=None,
103104
**kwargs):
@@ -190,6 +191,9 @@ def __init__(
190191
- `port` (optional): port number on which to connect
191192
- `document_class` (optional): default class to use for
192193
documents returned from queries on this client
194+
- `type_registry` (optional): instance of
195+
:class:`~bson.codec_options.TypeRegistry` to enable encoding
196+
and decoding of custom types.
193197
- `tz_aware` (optional): if ``True``,
194198
:class:`~datetime.datetime` instances returned as values
195199
in a document by this :class:`MongoClient` will be timezone
@@ -454,6 +458,7 @@ def __init__(
454458
455459
.. versionchanged:: 3.8
456460
Added the ``server_selector`` keyword argument.
461+
Added the ``type_registry`` keyword argument.
457462
458463
.. versionchanged:: 3.7
459464
Added the ``driver`` keyword argument.
@@ -564,6 +569,8 @@ def __init__(
564569

565570
keyword_opts = kwargs
566571
keyword_opts['document_class'] = document_class
572+
if type_registry is not None:
573+
keyword_opts['type_registry'] = type_registry
567574
if tz_aware is None:
568575
tz_aware = opts.get('tz_aware', False)
569576
if connect is None:

test/test_client.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
sys.path[0:0] = [""]
3030

3131
from bson import BSON
32-
from bson.codec_options import CodecOptions
32+
from bson.codec_options import CodecOptions, TypeEncoder, TypeRegistry
3333
from bson.py3compat import thread
3434
from bson.son import SON
3535
from bson.tz_util import utc
@@ -248,14 +248,21 @@ def test_metadata(self):
248248
self.assertEqual(options.pool_options.metadata, metadata)
249249

250250
def test_kwargs_codec_options(self):
251+
class FloatAsIntEncoder(TypeEncoder):
252+
python_type = float
253+
def transform_python(self, value):
254+
return int(value)
255+
251256
# Ensure codec options are passed in correctly
252257
document_class = SON
258+
type_registry = TypeRegistry([FloatAsIntEncoder()])
253259
tz_aware = True
254260
uuid_representation_label = 'javaLegacy'
255261
unicode_decode_error_handler = 'ignore'
256262
tzinfo = utc
257263
c = MongoClient(
258264
document_class=document_class,
265+
type_registry=type_registry,
259266
tz_aware=tz_aware,
260267
uuidrepresentation=uuid_representation_label,
261268
unicode_decode_error_handler=unicode_decode_error_handler,
@@ -264,6 +271,7 @@ def test_kwargs_codec_options(self):
264271
)
265272

266273
self.assertEqual(c.codec_options.document_class, document_class)
274+
self.assertEqual(c.codec_options.type_registry, type_registry)
267275
self.assertEqual(c.codec_options.tz_aware, tz_aware)
268276
self.assertEqual(
269277
c.codec_options.uuid_representation,

0 commit comments

Comments
 (0)