Skip to content

Commit 8f19bdb

Browse files
GODRIVER-2407 Add DocumentDecodeType to DecodeContext (#951)
Co-authored-by: Benjamin Rewis <[email protected]> Co-authored-by: Benjamin Rewis <[email protected]>
1 parent baa2dac commit 8f19bdb

File tree

4 files changed

+141
-9
lines changed

4 files changed

+141
-9
lines changed

bson/bsoncodec/bsoncodec.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313

1414
"go.mongodb.org/mongo-driver/bson/bsonrw"
1515
"go.mongodb.org/mongo-driver/bson/bsontype"
16+
"go.mongodb.org/mongo-driver/bson/primitive"
1617
)
1718

1819
var (
@@ -118,11 +119,32 @@ type EncodeContext struct {
118119
type DecodeContext struct {
119120
*Registry
120121
Truncate bool
122+
121123
// Ancestor is the type of a containing document. This is mainly used to determine what type
122124
// should be used when decoding an embedded document into an empty interface. For example, if
123125
// Ancestor is a bson.M, BSON embedded document values being decoded into an empty interface
124126
// will be decoded into a bson.M.
127+
//
128+
// Deprecated: Use DefaultDocumentM or DefaultDocumentD instead.
125129
Ancestor reflect.Type
130+
131+
// defaultDocumentType specifies the Go type to decode top-level and nested BSON documents into. In particular, the
132+
// usage for this field is restricted to data typed as "interface{}" or "map[string]interface{}". If DocumentType is
133+
// set to a type that a BSON document cannot be unmarshaled into (e.g. "string"), unmarshalling will result in an
134+
// error. DocumentType overrides the Ancestor field.
135+
defaultDocumentType reflect.Type
136+
}
137+
138+
// DefaultDocumentM will decode empty documents using the primitive.M type. This behavior is restricted to data typed as
139+
// "interface{}" or "map[string]interface{}".
140+
func (dc *DecodeContext) DefaultDocumentM() {
141+
dc.defaultDocumentType = reflect.TypeOf(primitive.M{})
142+
}
143+
144+
// DefaultDocumentD will decode empty documents using the primitive.D type. This behavior is restricted to data typed as
145+
// "interface{}" or "map[string]interface{}".
146+
func (dc *DecodeContext) DefaultDocumentD() {
147+
dc.defaultDocumentType = reflect.TypeOf(primitive.D{})
126148
}
127149

128150
// ValueCodec is the interface that groups the methods to encode and decode

bson/bsoncodec/empty_interface_codec.go

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,18 @@ func (eic EmptyInterfaceCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWrit
5757

5858
func (eic EmptyInterfaceCodec) getEmptyInterfaceDecodeType(dc DecodeContext, valueType bsontype.Type) (reflect.Type, error) {
5959
isDocument := valueType == bsontype.Type(0) || valueType == bsontype.EmbeddedDocument
60-
if isDocument && dc.Ancestor != nil {
61-
// Using ancestor information rather than looking up the type map entry forces consistent decoding.
62-
// If we're decoding into a bson.D, subdocuments should also be decoded as bson.D, even if a type map entry
63-
// has been registered.
64-
return dc.Ancestor, nil
60+
if isDocument {
61+
if dc.defaultDocumentType != nil {
62+
// If the bsontype is an embedded document and the DocumentType is set on the DecodeContext, then return
63+
// that type.
64+
return dc.defaultDocumentType, nil
65+
}
66+
if dc.Ancestor != nil {
67+
// Using ancestor information rather than looking up the type map entry forces consistent decoding.
68+
// If we're decoding into a bson.D, subdocuments should also be decoded as bson.D, even if a type map entry
69+
// has been registered.
70+
return dc.Ancestor, nil
71+
}
6572
}
6673

6774
rtype, err := dc.LookupTypeMapEntry(valueType)

bson/decoder.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@ var decPool = sync.Pool{
3333
type Decoder struct {
3434
dc bsoncodec.DecodeContext
3535
vr bsonrw.ValueReader
36+
37+
// We persist defaultDocumentM and defaultDocumentD on the Decoder to prevent overwriting from
38+
// (*Decoder).SetContext.
39+
defaultDocumentM bool
40+
defaultDocumentD bool
3641
}
3742

3843
// NewDecoder returns a new decoder that uses the DefaultRegistry to read from vr.
@@ -95,6 +100,12 @@ func (d *Decoder) Decode(val interface{}) error {
95100
if err != nil {
96101
return err
97102
}
103+
if d.defaultDocumentM {
104+
d.dc.DefaultDocumentM()
105+
}
106+
if d.defaultDocumentD {
107+
d.dc.DefaultDocumentD()
108+
}
98109
return decoder.DecodeValue(d.dc, d.vr, rval)
99110
}
100111

@@ -116,3 +127,15 @@ func (d *Decoder) SetContext(dc bsoncodec.DecodeContext) error {
116127
d.dc = dc
117128
return nil
118129
}
130+
131+
// DefaultDocumentM will decode empty documents using the primitive.M type. This behavior is restricted to data typed as
132+
// "interface{}" or "map[string]interface{}".
133+
func (d *Decoder) DefaultDocumentM() {
134+
d.defaultDocumentM = true
135+
}
136+
137+
// DefaultDocumentD will decode empty documents using the primitive.D type. This behavior is restricted to data typed as
138+
// "interface{}" or "map[string]interface{}".
139+
func (d *Decoder) DefaultDocumentD() {
140+
d.defaultDocumentD = true
141+
}

bson/decoder_test.go

Lines changed: 84 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"go.mongodb.org/mongo-driver/bson/bsonrw"
1919
"go.mongodb.org/mongo-driver/bson/bsonrw/bsonrwtest"
2020
"go.mongodb.org/mongo-driver/bson/bsontype"
21+
"go.mongodb.org/mongo-driver/bson/primitive"
2122
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
2223
)
2324

@@ -199,12 +200,12 @@ func TestDecoderv2(t *testing.T) {
199200
dc2 := bsoncodec.DecodeContext{Registry: NewRegistryBuilder().Build()}
200201
dec, err := NewDecoderWithContext(dc1, bsonrw.NewBSONDocumentReader([]byte{}))
201202
noerr(t, err)
202-
if dec.dc != dc1 {
203+
if !reflect.DeepEqual(dec.dc, dc1) {
203204
t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.dc, dc1)
204205
}
205206
err = dec.SetContext(dc2)
206207
noerr(t, err)
207-
if dec.dc != dc2 {
208+
if !reflect.DeepEqual(dec.dc, dc2) {
208209
t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.dc, dc2)
209210
}
210211
})
@@ -214,12 +215,12 @@ func TestDecoderv2(t *testing.T) {
214215
dc2 := bsoncodec.DecodeContext{Registry: r2}
215216
dec, err := NewDecoder(bsonrw.NewBSONDocumentReader([]byte{}))
216217
noerr(t, err)
217-
if dec.dc != dc1 {
218+
if !reflect.DeepEqual(dec.dc, dc1) {
218219
t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.dc, dc1)
219220
}
220221
err = dec.SetRegistry(r2)
221222
noerr(t, err)
222-
if dec.dc != dc2 {
223+
if !reflect.DeepEqual(dec.dc, dc2) {
223224
t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.dc, dc2)
224225
}
225226
})
@@ -235,6 +236,85 @@ func TestDecoderv2(t *testing.T) {
235236
t.Fatalf("Decode error mismatch; expected %v, got %v", ErrDecodeToNil, err)
236237
}
237238
})
239+
t.Run("SetDocumentType embedded map as empty interface", func(t *testing.T) {
240+
type someMap map[string]interface{}
241+
242+
in := make(someMap)
243+
in["foo"] = map[string]interface{}{"bar": "baz"}
244+
245+
bytes, err := Marshal(in)
246+
if err != nil {
247+
t.Fatal(err)
248+
}
249+
250+
var bsonOut someMap
251+
dec, err := NewDecoder(bsonrw.NewBSONDocumentReader(bytes))
252+
if err != nil {
253+
t.Fatal(err)
254+
}
255+
dec.DefaultDocumentM()
256+
if err := dec.Decode(&bsonOut); err != nil {
257+
t.Fatal(err)
258+
}
259+
260+
// Ensure that interface{}-typed top-level data is converted to the document type.
261+
bsonOutType := reflect.TypeOf(bsonOut)
262+
inType := reflect.TypeOf(in)
263+
assert.Equal(t, inType, bsonOutType, "expected %v to equal %v", inType.String(), bsonOutType.String())
264+
265+
// Ensure that the embedded type is a primitive map.
266+
mType := reflect.TypeOf(primitive.M{})
267+
bsonFooOutType := reflect.TypeOf(bsonOut["foo"])
268+
assert.Equal(t, mType, bsonFooOutType, "expected %v to equal %v", mType.String(), bsonFooOutType.String())
269+
})
270+
t.Run("SetDocumentType for decoding into interface{} alias", func(t *testing.T) {
271+
var in interface{} = map[string]interface{}{"bar": "baz"}
272+
273+
bytes, err := Marshal(in)
274+
if err != nil {
275+
t.Fatal(err)
276+
}
277+
278+
var bsonOut interface{}
279+
dec, err := NewDecoder(bsonrw.NewBSONDocumentReader(bytes))
280+
if err != nil {
281+
t.Fatal(err)
282+
}
283+
dec.DefaultDocumentD()
284+
if err := dec.Decode(&bsonOut); err != nil {
285+
t.Fatal(err)
286+
}
287+
288+
// Ensure that interface{}-typed top-level data is converted to the document type.
289+
dType := reflect.TypeOf(primitive.D{})
290+
bsonOutType := reflect.TypeOf(bsonOut)
291+
assert.Equal(t, dType, bsonOutType,
292+
"expected %v to equal %v", dType.String(), bsonOutType.String())
293+
})
294+
t.Run("SetDocumentType for decoding into non-interface{} alias", func(t *testing.T) {
295+
var in interface{} = map[string]interface{}{"bar": "baz"}
296+
297+
bytes, err := Marshal(in)
298+
if err != nil {
299+
t.Fatal(err)
300+
}
301+
302+
var bsonOut struct{}
303+
dec, err := NewDecoder(bsonrw.NewBSONDocumentReader(bytes))
304+
if err != nil {
305+
t.Fatal(err)
306+
}
307+
dec.DefaultDocumentD()
308+
if err := dec.Decode(&bsonOut); err != nil {
309+
t.Fatal(err)
310+
}
311+
312+
// Ensure that typed top-level data is not converted to the document type.
313+
dType := reflect.TypeOf(primitive.D{})
314+
bsonOutType := reflect.TypeOf(bsonOut)
315+
assert.NotEqual(t, dType, bsonOutType,
316+
"expected %v to not equal %v", dType.String(), bsonOutType.String())
317+
})
238318
}
239319

240320
type testUnmarshaler struct {

0 commit comments

Comments
 (0)