Skip to content

Commit 19f9639

Browse files
author
Divjot Arora
committed
GODRIVER-1751 Ensure codecs that cache are not shared across registries (#508)
1 parent 44a08b7 commit 19f9639

8 files changed

+142
-18
lines changed

bson/bsoncodec/default_value_decoders.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,16 @@ import (
2424

2525
var defaultValueDecoders DefaultValueDecoders
2626

27+
func newDefaultStructCodec() *StructCodec {
28+
codec, err := NewStructCodec(DefaultStructTagParser)
29+
if err != nil {
30+
// This function is called from the codec registration path, so errors can't be propagated. If there's an error
31+
// constructing the StructCodec, we panic to avoid losing it.
32+
panic(fmt.Errorf("error creating default StructCodec: %v", err))
33+
}
34+
return codec
35+
}
36+
2737
// DefaultValueDecoders is a namespace type for the default ValueDecoders used
2838
// when creating a registry.
2939
type DefaultValueDecoders struct{}
@@ -77,7 +87,7 @@ func (dvd DefaultValueDecoders) RegisterDefaultDecoders(rb *RegistryBuilder) {
7787
RegisterDefaultDecoder(reflect.Map, defaultMapCodec).
7888
RegisterDefaultDecoder(reflect.Slice, defaultSliceCodec).
7989
RegisterDefaultDecoder(reflect.String, defaultStringCodec).
80-
RegisterDefaultDecoder(reflect.Struct, defaultStructCodec).
90+
RegisterDefaultDecoder(reflect.Struct, newDefaultStructCodec()).
8191
RegisterDefaultDecoder(reflect.Ptr, NewPointerCodec()).
8292
RegisterTypeMapEntry(bsontype.Double, tFloat64).
8393
RegisterTypeMapEntry(bsontype.String, tString).

bson/bsoncodec/default_value_decoders_test.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ import (
2727
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
2828
)
2929

30+
var (
31+
defaultTestStructCodec = newDefaultStructCodec()
32+
)
33+
3034
func TestDefaultValueDecoders(t *testing.T) {
3135
var dvd DefaultValueDecoders
3236
var wrong = func(string, string) string { return "wrong" }
@@ -2197,7 +2201,7 @@ func TestDefaultValueDecoders(t *testing.T) {
21972201
},
21982202
{
21992203
"StructCodec.DecodeValue",
2200-
defaultStructCodec,
2204+
defaultTestStructCodec,
22012205
[]subtest{
22022206
{
22032207
"Not struct",
@@ -3497,7 +3501,7 @@ func TestDefaultValueDecoders(t *testing.T) {
34973501
emptyInterfaceStruct{},
34983502
bsonrw.NewBSONDocumentReader(docBytes),
34993503
emptyInterfaceErrorRegistry,
3500-
defaultStructCodec,
3504+
defaultTestStructCodec,
35013505
emptyInterfaceStructErr,
35023506
},
35033507
{
@@ -3508,15 +3512,15 @@ func TestDefaultValueDecoders(t *testing.T) {
35083512
stringStruct{},
35093513
bsonrw.NewBSONDocumentReader(docBytes),
35103514
NewRegistryBuilder().Build(),
3511-
defaultStructCodec,
3515+
defaultTestStructCodec,
35123516
stringStructErr,
35133517
},
35143518
{
35153519
"deeply nested struct",
35163520
outer{},
35173521
bsonrw.NewBSONDocumentReader(outerDoc),
35183522
nestedRegistry,
3519-
defaultStructCodec,
3523+
defaultTestStructCodec,
35203524
nestedErr,
35213525
},
35223526
}
@@ -3546,7 +3550,7 @@ func TestDefaultValueDecoders(t *testing.T) {
35463550
dc := DecodeContext{Registry: buildDefaultRegistry()}
35473551
vr := bsonrw.NewBSONDocumentReader(outerBytes)
35483552
val := reflect.New(reflect.TypeOf(outer{})).Elem()
3549-
err := defaultStructCodec.DecodeValue(dc, vr, val)
3553+
err := defaultTestStructCodec.DecodeValue(dc, vr, val)
35503554

35513555
decodeErr, ok := err.(*DecodeError)
35523556
assert.True(t, ok, "expected DecodeError, got %v of type %T", err, err)

bson/bsoncodec/default_value_encoders.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ func (dve DefaultValueEncoders) RegisterDefaultEncoders(rb *RegistryBuilder) {
104104
RegisterDefaultEncoder(reflect.Map, defaultMapCodec).
105105
RegisterDefaultEncoder(reflect.Slice, defaultSliceCodec).
106106
RegisterDefaultEncoder(reflect.String, defaultStringCodec).
107-
RegisterDefaultEncoder(reflect.Struct, defaultStructCodec).
107+
RegisterDefaultEncoder(reflect.Struct, newDefaultStructCodec()).
108108
RegisterDefaultEncoder(reflect.Ptr, NewPointerCodec()).
109109
RegisterHookEncoder(tValueMarshaler, ValueEncoderFunc(dve.ValueMarshalerEncodeValue)).
110110
RegisterHookEncoder(tMarshaler, ValueEncoderFunc(dve.MarshalerEncodeValue)).

bson/bsoncodec/default_value_encoders_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1084,7 +1084,7 @@ func TestDefaultValueEncoders(t *testing.T) {
10841084
},
10851085
{
10861086
"StructEncodeValue",
1087-
defaultStructCodec,
1087+
defaultTestStructCodec,
10881088
[]subtest{
10891089
{
10901090
"interface value",

bson/bsoncodec/pointer_codec.go

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,6 @@ import (
1414
"go.mongodb.org/mongo-driver/bson/bsontype"
1515
)
1616

17-
var defaultPointerCodec = &PointerCodec{
18-
ecache: make(map[reflect.Type]ValueEncoder),
19-
dcache: make(map[reflect.Type]ValueDecoder),
20-
}
21-
2217
var _ ValueEncoder = &PointerCodec{}
2318
var _ ValueDecoder = &PointerCodec{}
2419

bson/bsoncodec/struct_codec.go

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,6 @@ import (
1919
"go.mongodb.org/mongo-driver/bson/bsontype"
2020
)
2121

22-
var defaultStructCodec = &StructCodec{
23-
cache: make(map[reflect.Type]*structDescription),
24-
parser: DefaultStructTagParser,
25-
}
26-
2722
// DecodeError represents an error that occurs when unmarshalling BSON bytes into a native Go type.
2823
type DecodeError struct {
2924
keys []string

bson/marshal_test.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,17 @@ package bson
88

99
import (
1010
"bytes"
11+
"fmt"
12+
"reflect"
1113
"testing"
1214

1315
"github.com/google/go-cmp/cmp"
1416
"github.com/stretchr/testify/require"
1517
"go.mongodb.org/mongo-driver/bson/bsoncodec"
18+
"go.mongodb.org/mongo-driver/bson/bsonrw"
1619
"go.mongodb.org/mongo-driver/bson/primitive"
20+
"go.mongodb.org/mongo-driver/internal/testutil/assert"
21+
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
1722
)
1823

1924
func TestMarshalAppendWithRegistry(t *testing.T) {
@@ -207,3 +212,58 @@ func TestMarshal_roundtripFromDoc(t *testing.T) {
207212
t.Errorf("Documents to not match. got %v; want %v", after, before)
208213
}
209214
}
215+
216+
func TestCachingEncodersNotSharedAcrossRegistries(t *testing.T) {
217+
// Encoders that have caches for recursive encoder lookup should not be shared across Registry instances. Otherwise,
218+
// the first EncodeValue call would cache an encoder and a subsequent call would see that encoder even if a
219+
// different Registry is used.
220+
221+
// Create a custom Registry that negates int32 values when encoding.
222+
var encodeInt32 bsoncodec.ValueEncoderFunc = func(_ bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
223+
if val.Kind() != reflect.Int32 {
224+
return fmt.Errorf("expected kind to be int32, got %v", val.Kind())
225+
}
226+
227+
return vw.WriteInt32(int32(val.Int()) * -1)
228+
}
229+
customReg := NewRegistryBuilder().
230+
RegisterTypeEncoder(tInt32, encodeInt32).
231+
Build()
232+
233+
// Helper function to run the test and make assertions. The provided original value should result in the document
234+
// {"x": {$numberInt: 1}} when marshalled with the default registry.
235+
verifyResults := func(t *testing.T, original interface{}) {
236+
// Marshal using the default and custom registries. Assert that the result is {x: 1} and {x: -1}, respectively.
237+
238+
first, err := Marshal(original)
239+
assert.Nil(t, err, "Marshal error: %v", err)
240+
expectedFirst := Raw(bsoncore.BuildDocumentFromElements(
241+
nil,
242+
bsoncore.AppendInt32Element(nil, "x", 1),
243+
))
244+
assert.Equal(t, expectedFirst, Raw(first), "expected document %v, got %v", expectedFirst, Raw(first))
245+
246+
second, err := MarshalWithRegistry(customReg, original)
247+
assert.Nil(t, err, "Marshal error: %v", err)
248+
expectedSecond := Raw(bsoncore.BuildDocumentFromElements(
249+
nil,
250+
bsoncore.AppendInt32Element(nil, "x", -1),
251+
))
252+
assert.Equal(t, expectedSecond, Raw(second), "expected document %v, got %v", expectedSecond, Raw(second))
253+
}
254+
255+
t.Run("struct", func(t *testing.T) {
256+
type Struct struct {
257+
X int32
258+
}
259+
verifyResults(t, Struct{
260+
X: 1,
261+
})
262+
})
263+
t.Run("pointer", func(t *testing.T) {
264+
i32 := int32(1)
265+
verifyResults(t, M{
266+
"x": &i32,
267+
})
268+
})
269+
}

bson/unmarshal_test.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ import (
1313
"github.com/google/go-cmp/cmp"
1414
"go.mongodb.org/mongo-driver/bson/bsoncodec"
1515
"go.mongodb.org/mongo-driver/bson/bsonrw"
16+
"go.mongodb.org/mongo-driver/internal/testutil/assert"
17+
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
1618
)
1719

1820
func TestUnmarshal(t *testing.T) {
@@ -106,3 +108,61 @@ func TestUnmarshalExtJSONWithContext(t *testing.T) {
106108
}
107109
})
108110
}
111+
112+
func TestCachingDecodersNotSharedAcrossRegistries(t *testing.T) {
113+
// Decoders that have caches for recursive decoder lookup should not be shared across Registry instances. Otherwise,
114+
// the first DecodeValue call would cache an decoder and a subsequent call would see that decoder even if a
115+
// different Registry is used.
116+
117+
// Create a custom Registry that negates BSON int32 values when decoding.
118+
var decodeInt32 bsoncodec.ValueDecoderFunc = func(_ bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
119+
i32, err := vr.ReadInt32()
120+
if err != nil {
121+
return err
122+
}
123+
124+
val.SetInt(int64(-1 * i32))
125+
return nil
126+
}
127+
customReg := NewRegistryBuilder().
128+
RegisterTypeDecoder(tInt32, bsoncodec.ValueDecoderFunc(decodeInt32)).
129+
Build()
130+
131+
docBytes := bsoncore.BuildDocumentFromElements(
132+
nil,
133+
bsoncore.AppendInt32Element(nil, "x", 1),
134+
)
135+
136+
// For all sub-tests, unmarshal docBytes into a struct and assert that value for "x" is 1 when using the default
137+
// registry and -1 when using the custom registry.
138+
t.Run("struct", func(t *testing.T) {
139+
type Struct struct {
140+
X int32
141+
}
142+
143+
var first Struct
144+
err := Unmarshal(docBytes, &first)
145+
assert.Nil(t, err, "Unmarshal error: %v", err)
146+
assert.Equal(t, int32(1), first.X, "expected X value to be 1, got %v", first.X)
147+
148+
var second Struct
149+
err = UnmarshalWithRegistry(customReg, docBytes, &second)
150+
assert.Nil(t, err, "Unmarshal error: %v", err)
151+
assert.Equal(t, int32(-1), second.X, "expected X value to be -1, got %v", second.X)
152+
})
153+
t.Run("pointer", func(t *testing.T) {
154+
type Struct struct {
155+
X *int32
156+
}
157+
158+
var first Struct
159+
err := Unmarshal(docBytes, &first)
160+
assert.Nil(t, err, "Unmarshal error: %v", err)
161+
assert.Equal(t, int32(1), *first.X, "expected X value to be 1, got %v", *first.X)
162+
163+
var second Struct
164+
err = UnmarshalWithRegistry(customReg, docBytes, &second)
165+
assert.Nil(t, err, "Unmarshal error: %v", err)
166+
assert.Equal(t, int32(-1), *second.X, "expected X value to be -1, got %v", *second.X)
167+
})
168+
}

0 commit comments

Comments
 (0)