Commit 2f2fe9d

PYTHON-1818 TypeCodec support for ChangeStreams
1 parent fbb56a2 commit 2f2fe9d

File tree

4 files changed: +207 additions, -7 deletions


bson/__init__.py

Lines changed: 7 additions & 1 deletion
@@ -948,7 +948,13 @@ def decode_all(data, codec_options=DEFAULT_CODEC_OPTIONS):
 
 
 def _decode_selective(rawdoc, fields, codec_options):
-    doc = {}
+    if _raw_document_class(codec_options.document_class):
+        # If document_class is RawBSONDocument, use vanilla dictionary for
+        # decoding command response.
+        doc = {}
+    else:
+        # Else, use the specified document_class.
+        doc = codec_options.document_class()
     for key, value in iteritems(rawdoc):
         if key in fields:
             if fields[key] == 1:
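
This change makes _decode_selective build the partially decoded command response in the configured document_class instead of always using a plain dict, falling back to a dict only for RawBSONDocument, which cannot be populated key by key. Below is a minimal sketch of that branching using public APIs only; pick_container is a hypothetical stand-in for the private helper, and the simple "is RawBSONDocument" check stands in for the marker-based _raw_document_class test used in the real code.

from collections import OrderedDict

from bson.codec_options import CodecOptions
from bson.raw_bson import RawBSONDocument


def pick_container(codec_options):
    # Hypothetical stand-in for the logic added above: RawBSONDocument
    # cannot be built incrementally, so fall back to a plain dict;
    # otherwise instantiate the configured document_class.
    if codec_options.document_class is RawBSONDocument:
        return {}
    return codec_options.document_class()


assert isinstance(
    pick_container(CodecOptions(document_class=OrderedDict)), OrderedDict)
assert type(pick_container(CodecOptions(document_class=RawBSONDocument))) is dict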

bson/raw_bson.py

Lines changed: 3 additions & 1 deletion
@@ -19,6 +19,7 @@
 from bson.py3compat import abc, iteritems
 from bson.codec_options import (
     DEFAULT_CODEC_OPTIONS as DEFAULT, _RAW_BSON_DOCUMENT_MARKER)
+from bson.son import SON
 
 
 class RawBSONDocument(abc.Mapping):
@@ -93,8 +94,9 @@ def __inflated(self):
         if self.__inflated_doc is None:
             # We already validated the object's size when this document was
             # created, so no need to do that again.
+            # Use SON to preserve ordering of elements.
             self.__inflated_doc = _elements_to_dict(
-                self.__raw, 4, len(self.__raw)-1, self.__codec_options, {})
+                self.__raw, 4, len(self.__raw)-1, self.__codec_options, SON())
         return self.__inflated_doc
 
     def __getitem__(self, item):
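
Passing SON() as the accumulator means inflating a RawBSONDocument now preserves the element order of the underlying BSON bytes. A short sketch of the resulting behavior, using public APIs only and assuming this commit is applied:

from bson import BSON
from bson.raw_bson import RawBSONDocument
from bson.son import SON

# Encode a document whose key order differs from sorted order.
raw_bytes = BSON.encode(SON([('b', 1), ('a', 2), ('z', 3)]))
doc = RawBSONDocument(raw_bytes)

# Iteration follows the order of the elements in the BSON bytes.
assert list(doc) == ['b', 'a', 'z']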

pymongo/change_stream.py

Lines changed: 16 additions & 3 deletions
@@ -16,6 +16,8 @@
 
 import copy
 
+from bson import _bson_to_dict
+from bson.raw_bson import RawBSONDocument
 from bson.son import SON
 
 from pymongo import common
@@ -60,7 +62,16 @@ def __init__(self, target, pipeline, full_document, resume_after,
         validate_collation_or_none(collation)
         common.validate_non_negative_integer_or_none("batchSize", batch_size)
 
-        self._target = target
+        self._decode_custom = False
+        self._orig_codec_options = target.codec_options
+        if target.codec_options.type_registry._decoder_map:
+            self._decode_custom = True
+            self._target = target.with_options(
+                codec_options=target.codec_options.with_options(
+                    document_class=RawBSONDocument, type_registry=None))
+        else:
+            self._target = target
+
         self._pipeline = copy.deepcopy(pipeline)
         self._full_document = full_document
         self._resume_token = copy.deepcopy(resume_after)
@@ -145,8 +156,7 @@ def _run_aggregation_cmd(self, session, explicit_session):
             aggregation_collection, cursor, sock_info.address,
             batch_size=self._batch_size or 0,
             max_await_time_ms=self._max_await_time_ms,
-            session=session, explicit_session=explicit_session
-        )
+            session=session, explicit_session=explicit_session)
 
     def _create_cursor(self):
         with self._database.client._tmp_session(self._session, close=False) as s:
@@ -263,6 +273,9 @@ def try_next(self):
                 "token is missing.")
         self._resume_token = copy.copy(resume_token)
         self._start_at_operation_time = None
+
+        if self._decode_custom:
+            return _bson_to_dict(change.raw, self._orig_codec_options)
         return change
 
     def __enter__(self):
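
Together these changes fetch change stream batches as RawBSONDocument whenever custom decoders are registered, read the resume token from the raw document, and then re-decode each change with the caller's original codec options in try_next. The sketch below shows what this enables from the user's side; the decoder class, database and collection names, and the default MongoClient() connection are illustrative, not part of this commit, and change streams require a replica set or sharded cluster.

from bson.codec_options import CodecOptions, TypeDecoder, TypeRegistry
from bson.int64 import Int64
from pymongo import MongoClient


class Int64ToIntDecoder(TypeDecoder):
    # Decode BSON int64 values to plain Python ints.
    bson_type = Int64

    def transform_bson(self, value):
        return int(value)


codec_options = CodecOptions(
    type_registry=TypeRegistry([Int64ToIntDecoder()]))
client = MongoClient()  # placeholder connection
coll = client.example_db.get_collection('example', codec_options=codec_options)

with coll.watch() as stream:
    coll.insert_one({'x': Int64(42)})
    change = next(stream)
    # The custom decoder was applied to the change document.
    assert type(change['fullDocument']['x']) is int  # not bson.int64.Int64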

test/test_custom_types.py

Lines changed: 181 additions & 2 deletions
@@ -17,6 +17,7 @@
 import datetime
 import sys
 import tempfile
+from collections import OrderedDict
 from decimal import Decimal
 from random import random
 
@@ -36,16 +37,18 @@
     TypeEncoder, TypeRegistry)
 from bson.errors import InvalidDocument
 from bson.int64 import Int64
+from bson.raw_bson import RawBSONDocument
 from bson.py3compat import text_type
 
 from gridfs import GridIn, GridOut
 
 from pymongo.collection import ReturnDocument
 from pymongo.errors import DuplicateKeyError
+from pymongo.message import _CursorAddress
 
 from test import client_context, unittest
 from test.test_client import IntegrationTest
-from test.utils import ignore_deprecations
+from test.utils import ignore_deprecations, rs_client
 
 
 class DecimalEncoder(TypeEncoder):
@@ -115,6 +118,14 @@ def transform_bson(self, value):
     [UppercaseTextDecoder(),]))
 
 
+def type_obfuscating_decoder_factory(rt_type):
+    class ResumeTokenToNanDecoder(TypeDecoder):
+        bson_type = rt_type
+        def transform_bson(self, value):
+            return "NaN"
+    return ResumeTokenToNanDecoder
+
+
 class CustomBSONTypeTests(object):
     def roundtrip(self, doc):
         bsonbytes = BSON().encode(doc, codec_options=self.codecopts)
@@ -549,7 +560,7 @@ def test_command_errors_w_custom_type_decoder(self):
     def test_find_w_custom_type_decoder(self):
         db = self.db
         input_docs = [
-            {'x': Int64(k)} for k in [1.0, 2.0, 3.0]]
+            {'x': Int64(k)} for k in [1, 2, 3]]
         for doc in input_docs:
             db.test.insert_one(doc)
 
@@ -558,6 +569,24 @@ def test_find_w_custom_type_decoder(self):
         for doc in test.find({}, batch_size=1):
             self.assertIsInstance(doc['x'], UndecipherableInt64Type)
 
+    def test_find_w_custom_type_decoder_and_document_class(self):
+        def run_test(doc_cls):
+            db = self.db
+            input_docs = [
+                {'x': Int64(k)} for k in [1, 2, 3]]
+            for doc in input_docs:
+                db.test.insert_one(doc)
+
+            test = db.get_collection('test', codec_options=CodecOptions(
+                type_registry=TypeRegistry([UndecipherableIntDecoder()]),
+                document_class=doc_cls))
+            for doc in test.find({}, batch_size=1):
+                self.assertIsInstance(doc, doc_cls)
+                self.assertIsInstance(doc['x'], UndecipherableInt64Type)
+
+        for doc_cls in [RawBSONDocument, OrderedDict]:
+            run_test(doc_cls)
+
     @client_context.require_version_max(4, 1, 0, -1)
     def test_group_w_custom_type(self):
         db = self.db
@@ -709,5 +738,155 @@ def test_grid_out_custom_opts(self):
         self.assertRaises(AttributeError, setattr, two, attr, 5)
 
 
+class ChangeStreamsWCustomTypesTestMixin(object):
+    def change_stream(self, *args, **kwargs):
+        return self.watched_target.watch(*args, **kwargs)
+
+    def insert_and_check(self, change_stream, insert_doc,
+                         expected_doc):
+        self.input_target.insert_one(insert_doc)
+        change = next(change_stream)
+        self.assertEqual(change['fullDocument'], expected_doc)
+
+    def kill_change_stream_cursor(self, change_stream):
+        # Cause a cursor not found error on the next getMore.
+        cursor = change_stream._cursor
+        address = _CursorAddress(cursor.address, cursor._CommandCursor__ns)
+        client = self.input_target.database.client
+        client._close_cursor_now(cursor.cursor_id, address)
+
+    def test_simple(self):
+        codecopts = CodecOptions(type_registry=TypeRegistry([
+            UndecipherableIntEncoder(), UppercaseTextDecoder()]))
+        self.create_targets(codec_options=codecopts)
+
+        input_docs = [
+            {'_id': UndecipherableInt64Type(1), 'data': 'hello'},
+            {'_id': 2, 'data': 'world'},
+            {'_id': UndecipherableInt64Type(3), 'data': '!'},]
+        expected_docs = [
+            {'_id': 1, 'data': 'HELLO'},
+            {'_id': 2, 'data': 'WORLD'},
+            {'_id': 3, 'data': '!'},]
+
+        change_stream = self.change_stream()
+
+        self.insert_and_check(change_stream, input_docs[0], expected_docs[0])
+        self.kill_change_stream_cursor(change_stream)
+        self.insert_and_check(change_stream, input_docs[1], expected_docs[1])
+        self.kill_change_stream_cursor(change_stream)
+        self.insert_and_check(change_stream, input_docs[2], expected_docs[2])
+
+    def test_break_resume_token(self):
+        # Get one document from a change stream to determine resumeToken type.
+        self.create_targets()
+        change_stream = self.change_stream()
+        self.input_target.insert_one({"data": "test"})
+        change = next(change_stream)
+        resume_token_decoder = type_obfuscating_decoder_factory(
+            type(change['_id']['_data']))
+
+        # Custom-decoding the resumeToken type breaks resume tokens.
+        codecopts = CodecOptions(type_registry=TypeRegistry([
+            resume_token_decoder(), UndecipherableIntEncoder()]))
+
+        # Re-create targets, change stream and proceed.
+        self.create_targets(codec_options=codecopts)
+
+        docs = [{'_id': 1}, {'_id': 2}, {'_id': 3}]
+
+        change_stream = self.change_stream()
+        self.insert_and_check(change_stream, docs[0], docs[0])
+        self.kill_change_stream_cursor(change_stream)
+        self.insert_and_check(change_stream, docs[1], docs[1])
+        self.kill_change_stream_cursor(change_stream)
+        self.insert_and_check(change_stream, docs[2], docs[2])
+
+    def test_document_class(self):
+        def run_test(doc_cls):
+            codecopts = CodecOptions(type_registry=TypeRegistry([
+                UppercaseTextDecoder(), UndecipherableIntEncoder()]),
+                document_class=doc_cls)
+
+            self.create_targets(codec_options=codecopts)
+            change_stream = self.change_stream()
+
+            doc = {'a': UndecipherableInt64Type(101), 'b': 'xyz'}
+            self.input_target.insert_one(doc)
+            change = next(change_stream)
+
+            self.assertIsInstance(change, doc_cls)
+            self.assertEqual(change['fullDocument']['a'], 101)
+            self.assertEqual(change['fullDocument']['b'], 'XYZ')
+
+        for doc_cls in [OrderedDict, RawBSONDocument]:
+            run_test(doc_cls)
+
+
+class TestCollectionChangeStreamsWCustomTypes(
+        IntegrationTest, ChangeStreamsWCustomTypesTestMixin):
+    @classmethod
+    @client_context.require_version_min(3, 6, 0)
+    @client_context.require_no_mmap
+    @client_context.require_no_standalone
+    def setUpClass(cls):
+        super(TestCollectionChangeStreamsWCustomTypes, cls).setUpClass()
+
+    def tearDown(self):
+        self.input_target.drop()
+
+    def create_targets(self, *args, **kwargs):
+        self.watched_target = self.db.get_collection(
+            'test', *args, **kwargs)
+        self.input_target = self.watched_target
+        # Insert a record to ensure db, coll are created.
+        self.input_target.insert_one({'data': 'dummy'})
+
+
+class TestDatabaseChangeStreamsWCustomTypes(
+        IntegrationTest, ChangeStreamsWCustomTypesTestMixin):
+    @classmethod
+    @client_context.require_version_min(4, 0, 0)
+    @client_context.require_no_mmap
+    @client_context.require_no_standalone
+    def setUpClass(cls):
+        super(TestDatabaseChangeStreamsWCustomTypes, cls).setUpClass()
+
+    def tearDown(self):
+        self.input_target.drop()
+        self.client.drop_database(self.watched_target)
+
+    def create_targets(self, *args, **kwargs):
+        self.watched_target = self.client.get_database(
+            self.db.name, *args, **kwargs)
+        self.input_target = self.watched_target.test
+        # Insert a record to ensure db, coll are created.
+        self.input_target.insert_one({'data': 'dummy'})
+
+
+class TestClusterChangeStreamsWCustomTypes(
+        IntegrationTest, ChangeStreamsWCustomTypesTestMixin):
+    @classmethod
+    @client_context.require_version_min(4, 0, 0)
+    @client_context.require_no_mmap
+    @client_context.require_no_standalone
+    def setUpClass(cls):
+        super(TestClusterChangeStreamsWCustomTypes, cls).setUpClass()
+
+    def tearDown(self):
+        self.input_target.drop()
+        self.client.drop_database(self.db)
+
+    def create_targets(self, *args, **kwargs):
+        codec_options = kwargs.pop('codec_options', None)
+        if codec_options:
+            kwargs['type_registry'] = codec_options.type_registry
+            kwargs['document_class'] = codec_options.document_class
+        self.watched_target = rs_client(*args, **kwargs)
+        self.input_target = self.watched_target[self.db.name].test
+        # Insert a record to ensure db, coll are created.
+        self.input_target.insert_one({'data': 'dummy'})
+
+
 if __name__ == "__main__":
     unittest.main()
