Skip to content

Commit 9cca2a7

Browse files
committed
PYTHON-1818 Support custom type encoding in watch pipelines
1 parent 0ea5a15 commit 9cca2a7

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

pymongo/change_stream.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,11 @@ def __init__(self, target, pipeline, full_document, resume_after,
6666
self._orig_codec_options = target.codec_options
6767
if target.codec_options.type_registry._decoder_map:
6868
self._decode_custom = True
69+
# Keep the type registry so that we support encoding custom types
70+
# in the pipeline.
6971
self._target = target.with_options(
7072
codec_options=target.codec_options.with_options(
71-
document_class=RawBSONDocument, type_registry=None))
73+
document_class=RawBSONDocument))
7274
else:
7375
self._target = target
7476

test/test_custom_types.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -777,6 +777,29 @@ def test_simple(self):
777777
self.kill_change_stream_cursor(change_stream)
778778
self.insert_and_check(change_stream, input_docs[2], expected_docs[2])
779779

780+
def test_custom_type_in_pipeline(self):
781+
codecopts = CodecOptions(type_registry=TypeRegistry([
782+
UndecipherableIntEncoder(), UppercaseTextDecoder()]))
783+
self.create_targets(codec_options=codecopts)
784+
785+
input_docs = [
786+
{'_id': UndecipherableInt64Type(1), 'data': 'hello'},
787+
{'_id': 2, 'data': 'world'},
788+
{'_id': UndecipherableInt64Type(3), 'data': '!'}]
789+
expected_docs = [
790+
{'_id': 2, 'data': 'WORLD'},
791+
{'_id': 3, 'data': '!'}]
792+
793+
# UndecipherableInt64Type should be encoded with the TypeRegistry.
794+
change_stream = self.change_stream(
795+
[{'$match': {'documentKey._id': {
796+
'$gte': UndecipherableInt64Type(2)}}}])
797+
798+
self.input_target.insert_one(input_docs[0])
799+
self.insert_and_check(change_stream, input_docs[1], expected_docs[0])
800+
self.kill_change_stream_cursor(change_stream)
801+
self.insert_and_check(change_stream, input_docs[2], expected_docs[1])
802+
780803
def test_break_resume_token(self):
781804
# Get one document from a change stream to determine resumeToken type.
782805
self.create_targets()

0 commit comments

Comments
 (0)