Skip to content

Commit 2cb34e4

Browse files
committed
PYTHON-1814 Support custom type decoder with distinct
Fix pure python custom type decoding of bson arrays.
1 parent 2f06e8a commit 2cb34e4

File tree

7 files changed

+57
-37
lines changed

7 files changed

+57
-37
lines changed

bson/__init__.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ def _get_array(data, position, obj_end, opts, element_name):
219219
append = result.append
220220
index = data.index
221221
getter = _ELEMENT_GETTER
222+
decoder_map = opts.type_registry._decoder_map
222223

223224
while position < end:
224225
element_type = data[position:position + 1]
@@ -229,6 +230,12 @@ def _get_array(data, position, obj_end, opts, element_name):
229230
data, position, obj_end, opts, element_name)
230231
except KeyError:
231232
_raise_unknown_type(element_type, element_name)
233+
234+
if decoder_map:
235+
custom_decoder = decoder_map.get(type(value))
236+
if custom_decoder is not None:
237+
value = custom_decoder(value)
238+
232239
append(value)
233240

234241
if position != end + 1:
@@ -941,17 +948,15 @@ def decode_all(data, codec_options=DEFAULT_CODEC_OPTIONS):
941948

942949

943950
def _decode_selective(rawdoc, fields, codec_options):
944-
doc = codec_options.document_class()
951+
doc = {}
945952
for key, value in iteritems(rawdoc):
946953
if key in fields:
947-
if fields[key] == list:
948-
doc[key] = [_bson_to_dict(r.raw, codec_options) for r in value]
949-
elif fields[key] == dict:
950-
doc[key] = _bson_to_dict(value.raw, codec_options)
954+
if fields[key] == 1:
955+
doc[key] = _bson_to_dict(rawdoc.raw, codec_options)[key]
951956
else:
952957
doc[key] = _decode_selective(value, fields[key], codec_options)
953-
continue
954-
doc[key] = value
958+
else:
959+
doc[key] = value
955960
return doc
956961

957962

@@ -970,9 +975,8 @@ def _decode_all_selective(data, codec_options, fields):
970975
- `fields`: Map of document namespaces where data that needs
971976
to be custom decoded lives or None. For example, to custom decode a
972977
list of objects in 'field1.subfield1', the specified value should be
973-
``{'field1': {'subfield1': list}}``. Use ``dict`` instead of ``list``
974-
if the field contains a single object to custom decode. If ``fields``
975-
is an empty map or None, this method is the same as ``decode_all``.
978+
``{'field1': {'subfield1': 1}}``. If ``fields`` is an empty map or
979+
None, this method is the same as ``decode_all``.
976980
977981
:Returns:
978982
- `document_list`: Single-member list containing the decoded document.

pymongo/collection.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353

5454
_NO_OBJ_ERROR = "No matching object found"
5555
_UJOIN = u"%s.%s"
56-
_FIND_AND_MODIFY_DOC_FIELDS = {'value': dict}
56+
_FIND_AND_MODIFY_DOC_FIELDS = {'value': 1}
5757

5858

5959
class ReturnDocument(object):
@@ -225,6 +225,11 @@ def _command(self, sock_info, command, slave_ok=False,
225225
:class:`~pymongo.collation.Collation`.
226226
- `session` (optional): a
227227
:class:`~pymongo.client_session.ClientSession`.
228+
- `retryable_write` (optional): True if this command is a retryable
229+
write.
230+
- `user_fields` (optional): Response fields that should be decoded
231+
using the TypeDecoders from codec_options, passed to
232+
bson._decode_all_selective.
228233
229234
:Returns:
230235
The result document.
@@ -2313,7 +2318,7 @@ def _aggregate(self, pipeline, cursor_class, first_batch_size, session,
23132318
collation=collation,
23142319
session=session,
23152320
client=self.__database.client,
2316-
user_fields={'cursor': {'firstBatch': list}})
2321+
user_fields={'cursor': {'firstBatch': 1}})
23172322

23182323
if "cursor" in result:
23192324
cursor = result["cursor"]
@@ -2575,7 +2580,7 @@ def group(self, key, condition, initial, reduce, finalize=None, **kwargs):
25752580
with self._socket_for_reads(session=None) as (sock_info, slave_ok):
25762581
return self._command(sock_info, cmd, slave_ok,
25772582
collation=collation,
2578-
user_fields={'retval': list})["retval"]
2583+
user_fields={'retval': 1})["retval"]
25792584

25802585
def rename(self, new_name, session=None, **kwargs):
25812586
"""Rename this collection.
@@ -2680,7 +2685,8 @@ def distinct(self, key, filter=None, session=None, **kwargs):
26802685
return self._command(sock_info, cmd, slave_ok,
26812686
read_concern=self.read_concern,
26822687
collation=collation,
2683-
session=session)["values"]
2688+
session=session,
2689+
user_fields={"values": 1})["values"]
26842690

26852691
def map_reduce(self, map, reduce, out, full_response=False, session=None,
26862692
**kwargs):
@@ -2761,7 +2767,7 @@ def map_reduce(self, map, reduce, out, full_response=False, session=None,
27612767
else:
27622768
write_concern = None
27632769
if inline:
2764-
user_fields = {'results': list}
2770+
user_fields = {'results': 1}
27652771
else:
27662772
user_fields = None
27672773

@@ -2820,7 +2826,7 @@ def inline_map_reduce(self, map, reduce, full_response=False, session=None,
28202826
("map", map),
28212827
("reduce", reduce),
28222828
("out", {"inline": 1})])
2823-
user_fields = {'results': list}
2829+
user_fields = {'results': 1}
28242830
collation = validate_collation_or_none(kwargs.pop('collation', None))
28252831
cmd.update(kwargs)
28262832
with self._socket_for_reads(session) as (sock_info, slave_ok):

pymongo/command_cursor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def duration(): return datetime.datetime.now() - start
153153
user_fields = None
154154
legacy_response = True
155155
if from_command:
156-
user_fields = {'cursor': {'nextBatch': list}}
156+
user_fields = {'cursor': {'nextBatch': 1}}
157157
legacy_response = False
158158
docs = self._unpack_response(
159159
reply, self.__id, self.__collection.codec_options,

pymongo/cursor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
"await_data": 32,
5151
"exhaust": 64,
5252
"partial": 128}
53-
_CURSOR_DOC_FIELDS = {'cursor': {'firstBatch': list, 'nextBatch': list}}
53+
_CURSOR_DOC_FIELDS = {'cursor': {'firstBatch': 1, 'nextBatch': 1}}
5454

5555

5656
class CursorType(object):

pymongo/network.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,12 @@ def command(sock, dbname, spec, slave_ok, is_mongos,
8282
- `parse_write_concern_error`: Whether to parse the ``writeConcernError``
8383
field in the command response.
8484
- `collation`: The collation for this command.
85+
- `compression_ctx`: optional compression Context.
86+
- `use_op_msg`: True if we should use OP_MSG.
87+
- `unacknowledged`: True if this is an unacknowledged command.
88+
- `user_fields` (optional): Response fields that should be decoded
89+
using the TypeDecoders from codec_options, passed to
90+
bson._decode_all_selective.
8591
"""
8692
name = next(iter(spec))
8793
ns = dbname + '.$cmd'

pymongo/pool.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,9 @@ def command(self, dbname, spec, slave_ok=False,
534534
- `client`: optional MongoClient for gossipping $clusterTime.
535535
- `retryable_write`: True if this command is a retryable write.
536536
- `publish_events`: Should we publish events for this command?
537+
- `user_fields` (optional): Response fields that should be decoded
538+
using the TypeDecoders from codec_options, passed to
539+
bson._decode_all_selective.
537540
"""
538541
self.validate_session(client, session)
539542
session = _validate_session_write_concern(session, write_concern)

test/test_custom_types.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ def __init__(self, value):
8181
def __eq__(self, other):
8282
if isinstance(other, type(self)):
8383
return self.value == other.value
84-
return self.value == other
84+
# Does not compare equal to integers.
85+
return False
8586

8687

8788
class UndecipherableIntDecoder(TypeDecoder):
@@ -115,11 +116,17 @@ def transform_bson(self, value):
115116

116117

117118
class CustomBSONTypeTests(object):
118-
def test_encode_decode_roundtrip(self):
119-
document = {'average': Decimal('56.47')}
120-
bsonbytes = BSON().encode(document, codec_options=self.codecopts)
119+
def roundtrip(self, doc):
120+
bsonbytes = BSON().encode(doc, codec_options=self.codecopts)
121121
rt_document = BSON(bsonbytes).decode(codec_options=self.codecopts)
122-
self.assertEqual(document, rt_document)
122+
self.assertEqual(doc, rt_document)
123+
124+
def test_encode_decode_roundtrip(self):
125+
self.roundtrip({'average': Decimal('56.47')})
126+
self.roundtrip({'average': {'b': Decimal('56.47')}})
127+
self.roundtrip({'average': [Decimal('56.47')]})
128+
self.roundtrip({'average': [[Decimal('56.47')]]})
129+
self.roundtrip({'average': [{'b': Decimal('56.47')}]})
123130

124131
def test_decode_all(self):
125132
documents = []
@@ -585,24 +592,18 @@ def test_aggregate_w_custom_type_decoder(self):
585592
self.assertIsInstance(res['total_qty'], UndecipherableInt64Type)
586593
self.assertEqual(res['total_qty'].value, 20)
587594

588-
# collection.distinct does not support custom type decoding
589595
def test_distinct_w_custom_type(self):
590596
self.db.drop_collection("test")
591597

592598
test = self.db.get_collection('test', codec_options=UNINT_CODECOPTS)
593-
test.insert_many([
594-
{"a": UndecipherableInt64Type(1)},
595-
{"a": UndecipherableInt64Type(2)},
596-
{"a": UndecipherableInt64Type(2)},
597-
{"a": UndecipherableInt64Type(2)},
598-
{"a": UndecipherableInt64Type(3)}])
599-
600-
distinct = test.distinct("a")
601-
distinct.sort()
602-
603-
self.assertEqual([
604-
UndecipherableInt64Type(1), UndecipherableInt64Type(2),
605-
UndecipherableInt64Type(3)], distinct)
599+
values = [
600+
UndecipherableInt64Type(1),
601+
UndecipherableInt64Type(2),
602+
UndecipherableInt64Type(3),
603+
{"b": UndecipherableInt64Type(3)}]
604+
test.insert_many({"a": val} for val in values)
605+
606+
self.assertEqual(values, test.distinct("a"))
606607

607608
def test_map_reduce_w_custom_type(self):
608609
test = self.db.get_collection(

0 commit comments

Comments
 (0)