@@ -22,12 +22,15 @@ import (
22
22
23
23
"time"
24
24
25
+ "github.com/mongodb/mongo-go-driver/bson/bsoncodec"
26
+ "github.com/mongodb/mongo-go-driver/bson/bsonrw"
25
27
"github.com/mongodb/mongo-go-driver/mongo/options"
26
28
"github.com/mongodb/mongo-go-driver/mongo/readpref"
27
29
"github.com/mongodb/mongo-go-driver/mongo/writeconcern"
28
30
"github.com/mongodb/mongo-go-driver/x/mongo/driver/session"
29
31
"github.com/mongodb/mongo-go-driver/x/mongo/driver/uuid"
30
32
"github.com/mongodb/mongo-go-driver/x/network/connstring"
33
+ "reflect"
31
34
)
32
35
33
36
func createTestClient (t * testing.T ) * Client {
@@ -94,6 +97,62 @@ func TestClientOptions(t *testing.T) {
94
97
require .Equal (t , "test" , c .connString .ReplicaSet )
95
98
}
96
99
100
+ type NewCodec struct {
101
+ ID int64 `bson:"_id"`
102
+ }
103
+
104
+ func (e * NewCodec ) EncodeValue (ectx bsoncodec.EncodeContext , vw bsonrw.ValueWriter , val reflect.Value ) error {
105
+ return vw .WriteInt64 (val .Int ())
106
+ }
107
+
108
+ // DecodeValue negates the value of ID when reading
109
+ func (e * NewCodec ) DecodeValue (ectx bsoncodec.DecodeContext , vr bsonrw.ValueReader , val reflect.Value ) error {
110
+ i , err := vr .ReadInt64 ()
111
+ if err != nil {
112
+ return err
113
+ }
114
+
115
+ val .SetInt (i * - 1 )
116
+ return nil
117
+ }
118
+
119
+ func TestClientRegistryPassedToCursors (t * testing.T ) {
120
+ // register a new codec for the int64 type that does the default encoding for an int64 and negates the value when
121
+ // decoding
122
+
123
+ rb := bson .NewRegistryBuilder ()
124
+ cod := & NewCodec {}
125
+ rb .RegisterCodec (reflect .TypeOf (int64 (0 )), cod )
126
+
127
+ cs := testutil .ConnString (t )
128
+ client , err := NewClientWithOptions (cs .String (), options .Client ().SetRegistry (rb .Build ()))
129
+ require .NoError (t , err )
130
+ err = client .Connect (ctx )
131
+ require .NoError (t , err )
132
+
133
+ db := client .Database ("TestRegistryDB" )
134
+ defer func () {
135
+ _ = db .Drop (ctx )
136
+ _ = client .Disconnect (ctx )
137
+ }()
138
+
139
+ coll := db .Collection ("TestRegistryColl" )
140
+
141
+ _ , err = coll .InsertOne (ctx , NewCodec {ID : 10 })
142
+ require .NoError (t , err )
143
+
144
+ c , err := coll .Find (ctx , nil )
145
+ require .NoError (t , err )
146
+
147
+ require .True (t , c .Next (ctx ))
148
+
149
+ var foundDoc NewCodec
150
+ err = c .Decode (& foundDoc )
151
+ require .NoError (t , err )
152
+
153
+ require .Equal (t , foundDoc .ID , int64 (- 10 ))
154
+ }
155
+
97
156
func TestClient_TLSConnection (t * testing.T ) {
98
157
skipIfBelow30 (t ) // 3.0 doesn't return a security field in the serverStatus response
99
158
t .Parallel ()
0 commit comments