17
17
import datetime
18
18
import sys
19
19
import tempfile
20
+ from collections import OrderedDict
20
21
from decimal import Decimal
21
22
from random import random
22
23
36
37
TypeEncoder , TypeRegistry )
37
38
from bson .errors import InvalidDocument
38
39
from bson .int64 import Int64
40
+ from bson .raw_bson import RawBSONDocument
39
41
from bson .py3compat import text_type
40
42
41
43
from gridfs import GridIn , GridOut
42
44
43
45
from pymongo .collection import ReturnDocument
44
46
from pymongo .errors import DuplicateKeyError
47
+ from pymongo .message import _CursorAddress
45
48
46
49
from test import client_context , unittest
47
50
from test .test_client import IntegrationTest
48
- from test .utils import ignore_deprecations
51
+ from test .utils import ignore_deprecations , rs_client
49
52
50
53
51
54
class DecimalEncoder (TypeEncoder ):
@@ -115,6 +118,14 @@ def transform_bson(self, value):
115
118
[UppercaseTextDecoder (),]))
116
119
117
120
121
+ def type_obfuscating_decoder_factory (rt_type ):
122
+ class ResumeTokenToNanDecoder (TypeDecoder ):
123
+ bson_type = rt_type
124
+ def transform_bson (self , value ):
125
+ return "NaN"
126
+ return ResumeTokenToNanDecoder
127
+
128
+
118
129
class CustomBSONTypeTests (object ):
119
130
def roundtrip (self , doc ):
120
131
bsonbytes = BSON ().encode (doc , codec_options = self .codecopts )
@@ -549,7 +560,7 @@ def test_command_errors_w_custom_type_decoder(self):
549
560
def test_find_w_custom_type_decoder (self ):
550
561
db = self .db
551
562
input_docs = [
552
- {'x' : Int64 (k )} for k in [1.0 , 2.0 , 3.0 ]]
563
+ {'x' : Int64 (k )} for k in [1 , 2 , 3 ]]
553
564
for doc in input_docs :
554
565
db .test .insert_one (doc )
555
566
@@ -558,6 +569,24 @@ def test_find_w_custom_type_decoder(self):
558
569
for doc in test .find ({}, batch_size = 1 ):
559
570
self .assertIsInstance (doc ['x' ], UndecipherableInt64Type )
560
571
572
+ def test_find_w_custom_type_decoder_and_document_class (self ):
573
+ def run_test (doc_cls ):
574
+ db = self .db
575
+ input_docs = [
576
+ {'x' : Int64 (k )} for k in [1 , 2 , 3 ]]
577
+ for doc in input_docs :
578
+ db .test .insert_one (doc )
579
+
580
+ test = db .get_collection ('test' , codec_options = CodecOptions (
581
+ type_registry = TypeRegistry ([UndecipherableIntDecoder ()]),
582
+ document_class = doc_cls ))
583
+ for doc in test .find ({}, batch_size = 1 ):
584
+ self .assertIsInstance (doc , doc_cls )
585
+ self .assertIsInstance (doc ['x' ], UndecipherableInt64Type )
586
+
587
+ for doc_cls in [RawBSONDocument , OrderedDict ]:
588
+ run_test (doc_cls )
589
+
561
590
@client_context .require_version_max (4 , 1 , 0 , - 1 )
562
591
def test_group_w_custom_type (self ):
563
592
db = self .db
@@ -709,5 +738,155 @@ def test_grid_out_custom_opts(self):
709
738
self .assertRaises (AttributeError , setattr , two , attr , 5 )
710
739
711
740
741
+ class ChangeStreamsWCustomTypesTestMixin (object ):
742
+ def change_stream (self , * args , ** kwargs ):
743
+ return self .watched_target .watch (* args , ** kwargs )
744
+
745
+ def insert_and_check (self , change_stream , insert_doc ,
746
+ expected_doc ):
747
+ self .input_target .insert_one (insert_doc )
748
+ change = next (change_stream )
749
+ self .assertEqual (change ['fullDocument' ], expected_doc )
750
+
751
+ def kill_change_stream_cursor (self , change_stream ):
752
+ # Cause a cursor not found error on the next getMore.
753
+ cursor = change_stream ._cursor
754
+ address = _CursorAddress (cursor .address , cursor ._CommandCursor__ns )
755
+ client = self .input_target .database .client
756
+ client ._close_cursor_now (cursor .cursor_id , address )
757
+
758
+ def test_simple (self ):
759
+ codecopts = CodecOptions (type_registry = TypeRegistry ([
760
+ UndecipherableIntEncoder (), UppercaseTextDecoder ()]))
761
+ self .create_targets (codec_options = codecopts )
762
+
763
+ input_docs = [
764
+ {'_id' : UndecipherableInt64Type (1 ), 'data' : 'hello' },
765
+ {'_id' : 2 , 'data' : 'world' },
766
+ {'_id' : UndecipherableInt64Type (3 ), 'data' : '!' },]
767
+ expected_docs = [
768
+ {'_id' : 1 , 'data' : 'HELLO' },
769
+ {'_id' : 2 , 'data' : 'WORLD' },
770
+ {'_id' : 3 , 'data' : '!' },]
771
+
772
+ change_stream = self .change_stream ()
773
+
774
+ self .insert_and_check (change_stream , input_docs [0 ], expected_docs [0 ])
775
+ self .kill_change_stream_cursor (change_stream )
776
+ self .insert_and_check (change_stream , input_docs [1 ], expected_docs [1 ])
777
+ self .kill_change_stream_cursor (change_stream )
778
+ self .insert_and_check (change_stream , input_docs [2 ], expected_docs [2 ])
779
+
780
+ def test_break_resume_token (self ):
781
+ # Get one document from a change stream to determine resumeToken type.
782
+ self .create_targets ()
783
+ change_stream = self .change_stream ()
784
+ self .input_target .insert_one ({"data" : "test" })
785
+ change = next (change_stream )
786
+ resume_token_decoder = type_obfuscating_decoder_factory (
787
+ type (change ['_id' ]['_data' ]))
788
+
789
+ # Custom-decoding the resumeToken type breaks resume tokens.
790
+ codecopts = CodecOptions (type_registry = TypeRegistry ([
791
+ resume_token_decoder (), UndecipherableIntEncoder ()]))
792
+
793
+ # Re-create targets, change stream and proceed.
794
+ self .create_targets (codec_options = codecopts )
795
+
796
+ docs = [{'_id' : 1 }, {'_id' : 2 }, {'_id' : 3 }]
797
+
798
+ change_stream = self .change_stream ()
799
+ self .insert_and_check (change_stream , docs [0 ], docs [0 ])
800
+ self .kill_change_stream_cursor (change_stream )
801
+ self .insert_and_check (change_stream , docs [1 ], docs [1 ])
802
+ self .kill_change_stream_cursor (change_stream )
803
+ self .insert_and_check (change_stream , docs [2 ], docs [2 ])
804
+
805
+ def test_document_class (self ):
806
+ def run_test (doc_cls ):
807
+ codecopts = CodecOptions (type_registry = TypeRegistry ([
808
+ UppercaseTextDecoder (), UndecipherableIntEncoder ()]),
809
+ document_class = doc_cls )
810
+
811
+ self .create_targets (codec_options = codecopts )
812
+ change_stream = self .change_stream ()
813
+
814
+ doc = {'a' : UndecipherableInt64Type (101 ), 'b' : 'xyz' }
815
+ self .input_target .insert_one (doc )
816
+ change = next (change_stream )
817
+
818
+ self .assertIsInstance (change , doc_cls )
819
+ self .assertEqual (change ['fullDocument' ]['a' ], 101 )
820
+ self .assertEqual (change ['fullDocument' ]['b' ], 'XYZ' )
821
+
822
+ for doc_cls in [OrderedDict , RawBSONDocument ]:
823
+ run_test (doc_cls )
824
+
825
+
826
+ class TestCollectionChangeStreamsWCustomTypes (
827
+ IntegrationTest , ChangeStreamsWCustomTypesTestMixin ):
828
+ @classmethod
829
+ @client_context .require_version_min (3 , 6 , 0 )
830
+ @client_context .require_no_mmap
831
+ @client_context .require_no_standalone
832
+ def setUpClass (cls ):
833
+ super (TestCollectionChangeStreamsWCustomTypes , cls ).setUpClass ()
834
+
835
+ def tearDown (self ):
836
+ self .input_target .drop ()
837
+
838
+ def create_targets (self , * args , ** kwargs ):
839
+ self .watched_target = self .db .get_collection (
840
+ 'test' , * args , ** kwargs )
841
+ self .input_target = self .watched_target
842
+ # Insert a record to ensure db, coll are created.
843
+ self .input_target .insert_one ({'data' : 'dummy' })
844
+
845
+
846
+ class TestDatabaseChangeStreamsWCustomTypes (
847
+ IntegrationTest , ChangeStreamsWCustomTypesTestMixin ):
848
+ @classmethod
849
+ @client_context .require_version_min (4 , 0 , 0 )
850
+ @client_context .require_no_mmap
851
+ @client_context .require_no_standalone
852
+ def setUpClass (cls ):
853
+ super (TestDatabaseChangeStreamsWCustomTypes , cls ).setUpClass ()
854
+
855
+ def tearDown (self ):
856
+ self .input_target .drop ()
857
+ self .client .drop_database (self .watched_target )
858
+
859
+ def create_targets (self , * args , ** kwargs ):
860
+ self .watched_target = self .client .get_database (
861
+ self .db .name , * args , ** kwargs )
862
+ self .input_target = self .watched_target .test
863
+ # Insert a record to ensure db, coll are created.
864
+ self .input_target .insert_one ({'data' : 'dummy' })
865
+
866
+
867
+ class TestClusterChangeStreamsWCustomTypes (
868
+ IntegrationTest , ChangeStreamsWCustomTypesTestMixin ):
869
+ @classmethod
870
+ @client_context .require_version_min (4 , 0 , 0 )
871
+ @client_context .require_no_mmap
872
+ @client_context .require_no_standalone
873
+ def setUpClass (cls ):
874
+ super (TestClusterChangeStreamsWCustomTypes , cls ).setUpClass ()
875
+
876
+ def tearDown (self ):
877
+ self .input_target .drop ()
878
+ self .client .drop_database (self .db )
879
+
880
+ def create_targets (self , * args , ** kwargs ):
881
+ codec_options = kwargs .pop ('codec_options' , None )
882
+ if codec_options :
883
+ kwargs ['type_registry' ] = codec_options .type_registry
884
+ kwargs ['document_class' ] = codec_options .document_class
885
+ self .watched_target = rs_client (* args , ** kwargs )
886
+ self .input_target = self .watched_target [self .db .name ].test
887
+ # Insert a record to ensure db, coll are created.
888
+ self .input_target .insert_one ({'data' : 'dummy' })
889
+
890
+
712
891
if __name__ == "__main__" :
713
892
unittest .main ()
0 commit comments