Skip to content

Commit 2383f21

Browse files
author
iwysiu
authored
GODRIVER-1349 implement mgocompat StructCodec options (#231)
1 parent b4d25d2 commit 2383f21

File tree

5 files changed

+184
-38
lines changed

5 files changed

+184
-38
lines changed

bson/bsoncodec/struct_codec.go

Lines changed: 87 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ import (
1212
"reflect"
1313
"strings"
1414
"sync"
15+
"time"
1516

17+
"go.mongodb.org/mongo-driver/bson/bsonoptions"
1618
"go.mongodb.org/mongo-driver/bson/bsonrw"
1719
"go.mongodb.org/mongo-driver/bson/bsontype"
1820
)
@@ -31,24 +33,41 @@ type Zeroer interface {
3133

3234
// StructCodec is the Codec used for struct values.
3335
type StructCodec struct {
34-
cache map[reflect.Type]*structDescription
35-
l sync.RWMutex
36-
parser StructTagParser
36+
cache map[reflect.Type]*structDescription
37+
l sync.RWMutex
38+
parser StructTagParser
39+
DecodeZeroStruct bool
40+
DecodeDeepZeroInline bool
41+
EncodeOmitDefaultStruct bool
3742
}
3843

3944
var _ ValueEncoder = &StructCodec{}
4045
var _ ValueDecoder = &StructCodec{}
4146

4247
// NewStructCodec returns a StructCodec that uses p for struct tag parsing.
43-
func NewStructCodec(p StructTagParser) (*StructCodec, error) {
48+
func NewStructCodec(p StructTagParser, opts ...*bsonoptions.StructCodecOptions) (*StructCodec, error) {
4449
if p == nil {
4550
return nil, errors.New("a StructTagParser must be provided to NewStructCodec")
4651
}
4752

48-
return &StructCodec{
53+
structOpt := bsonoptions.MergeStructCodecOptions(opts...)
54+
55+
codec := &StructCodec{
4956
cache: make(map[reflect.Type]*structDescription),
5057
parser: p,
51-
}, nil
58+
}
59+
60+
if structOpt.DecodeZeroStruct != nil {
61+
codec.DecodeZeroStruct = *structOpt.DecodeZeroStruct
62+
}
63+
if structOpt.DecodeDeepZeroInline != nil {
64+
codec.DecodeDeepZeroInline = *structOpt.DecodeDeepZeroInline
65+
}
66+
if structOpt.EncodeOmitDefaultStruct != nil {
67+
codec.EncodeOmitDefaultStruct = *structOpt.EncodeOmitDefaultStruct
68+
}
69+
70+
return codec, nil
5271
}
5372

5473
// EncodeValue handles encoding generic struct types.
@@ -138,6 +157,13 @@ func (sc *StructCodec) DecodeValue(r DecodeContext, vr bsonrw.ValueReader, val r
138157
return err
139158
}
140159

160+
if sc.DecodeZeroStruct {
161+
val.Set(reflect.Zero(val.Type()))
162+
}
163+
if sc.DecodeDeepZeroInline && sd.inline {
164+
val.Set(deepZero(val.Type()))
165+
}
166+
141167
var decoder ValueDecoder
142168
var inlineMap reflect.Value
143169
if sd.inlineMap >= 0 {
@@ -257,6 +283,23 @@ func (sc *StructCodec) isZero(i interface{}) bool {
257283
return v.Float() == 0
258284
case reflect.Interface, reflect.Ptr:
259285
return v.IsNil()
286+
case reflect.Struct:
287+
if sc.EncodeOmitDefaultStruct {
288+
vt := v.Type()
289+
if vt == tTime {
290+
return v.Interface().(time.Time).IsZero()
291+
}
292+
for i := 0; i < v.NumField(); i++ {
293+
if vt.Field(i).PkgPath != "" && !vt.Field(i).Anonymous {
294+
continue // Private field
295+
}
296+
fld := v.Field(i)
297+
if !sc.isZero(fld.Interface()) {
298+
return false
299+
}
300+
}
301+
return true
302+
}
260303
}
261304

262305
return false
@@ -266,6 +309,7 @@ type structDescription struct {
266309
fm map[string]fieldDescription
267310
fl []fieldDescription
268311
inlineMap int
312+
inline bool
269313
}
270314

271315
type fieldDescription struct {
@@ -328,6 +372,7 @@ func (sc *StructCodec) describeStruct(r *Registry, t reflect.Type) (*structDescr
328372
description.truncate = stags.Truncate
329373

330374
if stags.Inline {
375+
sd.inline = true
331376
switch sfType.Kind() {
332377
case reflect.Map:
333378
if sd.inlineMap >= 0 {
@@ -416,3 +461,39 @@ func getInlineField(val reflect.Value, index []int) (reflect.Value, error) {
416461

417462
return fieldByIndexErr(val, index)
418463
}
464+
465+
// DeepZero returns recursive zero object
466+
func deepZero(st reflect.Type) (result reflect.Value) {
467+
result = reflect.Indirect(reflect.New(st))
468+
469+
if result.Kind() == reflect.Struct {
470+
for i := 0; i < result.NumField(); i++ {
471+
if f := result.Field(i); f.Kind() == reflect.Ptr {
472+
if f.CanInterface() {
473+
if ft := reflect.TypeOf(f.Interface()); ft.Elem().Kind() == reflect.Struct {
474+
result.Field(i).Set(recursivePointerTo(deepZero(ft.Elem())))
475+
}
476+
}
477+
}
478+
}
479+
}
480+
481+
return
482+
}
483+
484+
// recursivePointerTo calls reflect.New(v.Type) but recursively for its fields inside
485+
func recursivePointerTo(v reflect.Value) reflect.Value {
486+
v = reflect.Indirect(v)
487+
result := reflect.New(v.Type())
488+
if v.Kind() == reflect.Struct {
489+
for i := 0; i < v.NumField(); i++ {
490+
if f := v.Field(i); f.Kind() == reflect.Ptr {
491+
if f.Elem().Kind() == reflect.Struct {
492+
result.Elem().Field(i).Set(recursivePointerTo(f))
493+
}
494+
}
495+
}
496+
}
497+
498+
return result
499+
}

bson/bsonoptions/string_codec_options.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ package bsonoptions
88

99
var defaultDecodeOIDAsHex = true
1010

11-
// StringCodecOptions represents all possible options for time.Time encoding and decoding.
11+
// StringCodecOptions represents all possible options for string encoding and decoding.
1212
type StringCodecOptions struct {
1313
DecodeObjectIDAsHex *bool // Specifies if we should decode ObjectID as the hex value. Defaults to true.
1414
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
// Copyright (C) MongoDB, Inc. 2017-present.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"); you may
4+
// not use this file except in compliance with the License. You may obtain
5+
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6+
7+
package bsonoptions
8+
9+
// StructCodecOptions represents all possible options for struct encoding and decoding.
10+
type StructCodecOptions struct {
11+
DecodeZeroStruct *bool // Specifies if structs should be zeroed before decoding into them. Defaults to false.
12+
DecodeDeepZeroInline *bool // Specifies if structs should be recursively zeroed when a inline value is decoded. Defaults to false.
13+
EncodeOmitDefaultStruct *bool // Specifies if default structs should be considered empty by omitempty. Defaults to false.
14+
}
15+
16+
// StructCodec creates a new *StructCodecOptions
17+
func StructCodec() *StructCodecOptions {
18+
return &StructCodecOptions{}
19+
}
20+
21+
// SetDecodeZeroStruct specifies if structs should be zeroed before decoding into them. Defaults to false.
22+
func (t *StructCodecOptions) SetDecodeZeroStruct(b bool) *StructCodecOptions {
23+
t.DecodeZeroStruct = &b
24+
return t
25+
}
26+
27+
// SetDecodeDeepZeroInline specifies if structs should be zeroed before decoding into them. Defaults to false.
28+
func (t *StructCodecOptions) SetDecodeDeepZeroInline(b bool) *StructCodecOptions {
29+
t.DecodeDeepZeroInline = &b
30+
return t
31+
}
32+
33+
// SetEncodeOmitDefaultStruct specifies if default structs should be considered empty by omitempty. A default struct has all
34+
// its values set to their default value. Defaults to false.
35+
func (t *StructCodecOptions) SetEncodeOmitDefaultStruct(b bool) *StructCodecOptions {
36+
t.EncodeOmitDefaultStruct = &b
37+
return t
38+
}
39+
40+
// MergeStructCodecOptions combines the given *StructCodecOptions into a single *StructCodecOptions in a last one wins fashion.
41+
func MergeStructCodecOptions(opts ...*StructCodecOptions) *StructCodecOptions {
42+
s := StructCodec()
43+
for _, opt := range opts {
44+
if opt == nil {
45+
continue
46+
}
47+
48+
if opt.DecodeZeroStruct != nil {
49+
s.DecodeZeroStruct = opt.DecodeZeroStruct
50+
}
51+
if opt.DecodeDeepZeroInline != nil {
52+
s.DecodeDeepZeroInline = opt.DecodeDeepZeroInline
53+
}
54+
if opt.EncodeOmitDefaultStruct != nil {
55+
s.EncodeOmitDefaultStruct = opt.EncodeOmitDefaultStruct
56+
}
57+
}
58+
59+
return s
60+
}

bson/mgocompat/bson_test.go

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -204,16 +204,16 @@ func TestUnmarshalRawIncompatible(t *testing.T) {
204204
assert.NotNil(t, err, "expected an error")
205205
}
206206

207-
// func TestUnmarshalZeroesStruct(t *testing.T) {
208-
// data, err := bson.MarshalWithRegistry(mgoRegistry, bson.M{"b": 2})
209-
// assert.Nil(t, err, "expected nil error, got: %v", err)
210-
// type T struct{ A, B int }
211-
// v := T{A: 1}
212-
// err = bson.UnmarshalWithRegistry(mgoRegistry, data, &v)
213-
// assert.Nil(t, err, "expected nil error, got: %v", err)
214-
// assert.Equal(t, 0, v.A, "expected: 0, got: %v", v.A)
215-
// assert.Equal(t, 2, v.B, "expected: 2, got: %v", v.B)
216-
// }
207+
func TestUnmarshalZeroesStruct(t *testing.T) {
208+
data, err := bson.MarshalWithRegistry(mgoRegistry, bson.M{"b": 2})
209+
assert.Nil(t, err, "expected nil error, got: %v", err)
210+
type T struct{ A, B int }
211+
v := T{A: 1}
212+
err = bson.UnmarshalWithRegistry(mgoRegistry, data, &v)
213+
assert.Nil(t, err, "expected nil error, got: %v", err)
214+
assert.Equal(t, 0, v.A, "expected: 0, got: %v", v.A)
215+
assert.Equal(t, 2, v.B, "expected: 2, got: %v", v.B)
216+
}
217217

218218
// func TestUnmarshalZeroesMap(t *testing.T) {
219219
// data, err := bson.MarshalWithRegistry(mgoRegistry, bson.M{"b": 2})
@@ -226,17 +226,17 @@ func TestUnmarshalRawIncompatible(t *testing.T) {
226226
// assert.True(t, reflect.DeepEqual(want, m), "expected: %v, got: %v", want, m)
227227
// }
228228

229-
// func TestUnmarshalNonNilInterface(t *testing.T) {
230-
// data, err := bson.MarshalWithRegistry(mgoRegistry, bson.M{"b": 2})
231-
// assert.Nil(t, err, "expected nil error, got: %v", err)
232-
// m := bson.M{"a": 1}
233-
// var i interface{}
234-
// i = m
235-
// err = bson.UnmarshalWithRegistry(mgoRegistry, data, &i)
236-
// assert.Nil(t, err, "expected nil error, got: %v", err)
237-
// assert.True(t, reflect.DeepEqual(bson.M{"b": 2}, i), "expected: %v, got: %v", bson.M{"b": 2}, i)
238-
// assert.True(t, reflect.DeepEqual(bson.M{"a": 1}, i), "expected: %v, got: %v", bson.M{"a": 1}, i)
239-
// }
229+
func TestUnmarshalNonNilInterface(t *testing.T) {
230+
data, err := bson.MarshalWithRegistry(mgoRegistry, bson.M{"b": 2})
231+
assert.Nil(t, err, "expected nil error, got: %v", err)
232+
m := bson.M{"a": 1}
233+
var i interface{}
234+
i = m
235+
err = bson.UnmarshalWithRegistry(mgoRegistry, data, &i)
236+
assert.Nil(t, err, "expected nil error, got: %v", err)
237+
assert.True(t, reflect.DeepEqual(bson.M{"b": 2}, i), "expected: %v, got: %v", bson.M{"b": 2}, i)
238+
assert.True(t, reflect.DeepEqual(bson.M{"a": 1}, m), "expected: %v, got: %v", bson.M{"a": 1}, m)
239+
}
240240

241241
func TestPtrInline(t *testing.T) {
242242
cases := []struct {
@@ -1354,12 +1354,12 @@ var twoWayCrossItems = []crossTypeItem{
13541354
// {&struct{ S string }{"ghi"}, &struct{ S primitive.Symbol }{"ghi"}},
13551355

13561356
// map <=> struct
1357-
// {&struct {
1358-
// A struct {
1359-
// B, C int
1360-
// }
1361-
// }{struct{ B, C int }{1, 2}},
1362-
// map[string]map[string]int{"a": {"b": 1, "c": 2}}},
1357+
{&struct {
1358+
A struct {
1359+
B, C int
1360+
}
1361+
}{struct{ B, C int }{1, 2}},
1362+
&map[string]map[string]int{"a": {"b": 1, "c": 2}}},
13631363

13641364
// {&struct{ A primitive.Symbol }{"abc"}, &map[string]string{"a": "abc"}},
13651365
// {&struct{ A primitive.Symbol }{"abc"}, &map[string][]byte{"a": []byte("abc")}},
@@ -1414,7 +1414,7 @@ var twoWayCrossItems = []crossTypeItem{
14141414
{&condTime{}, &map[string]string{}},
14151415

14161416
{&condStruct{struct{ A []int }{[]int{1}}}, &bson.M{"v": bson.M{"a": []interface{}{1}}}},
1417-
// {&condStruct{struct{ A []int }{}}, &bson.M{}},
1417+
{&condStruct{struct{ A []int }{}}, &bson.M{}},
14181418

14191419
// {&condRaw{bson.RawValue{Type: 0x0A, Value: []byte{}}},&bson.M{"v": nil}},
14201420
// {&condRaw{bson.RawValue{Type: 0x00}}, &bson.M{}},
@@ -1508,9 +1508,9 @@ var oneWayCrossItems = []crossTypeItem{
15081508
{&shortIface{int64(1) << 30}, &map[string]interface{}{"v": 1 << 30}},
15091509

15101510
// Ensure omitempty on struct with private fields works properly.
1511-
// {&struct {
1512-
// V struct{ v time.Time } `bson:",omitempty"`
1513-
// }{}, &map[string]interface{}{}},
1511+
{&struct {
1512+
V struct{ v time.Time } `bson:",omitempty"`
1513+
}{}, &map[string]interface{}{}},
15141514

15151515
// Attempt to marshal slice into RawD (issue #120).
15161516
// {bson.M{"x": []int{1, 2, 3}}, &struct{ X bson.Raw }{}},

bson/mgocompat/registry.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,12 @@ func newRegistryBuilder() *bsoncodec.RegistryBuilder {
3838
bsoncodec.DefaultValueDecoders{}.RegisterDefaultDecoders(rb)
3939
bson.PrimitiveCodecs{}.RegisterPrimitiveCodecs(rb)
4040

41+
structcodec, _ := bsoncodec.NewStructCodec(bsoncodec.DefaultStructTagParser,
42+
bsonoptions.StructCodec().SetDecodeZeroStruct(true).SetDecodeDeepZeroInline(true).SetEncodeOmitDefaultStruct(true))
43+
4144
rb.RegisterDefaultDecoder(reflect.String, bsoncodec.NewStringCodec(bsonoptions.StringCodec().SetDecodeObjectIDAsHex(false))).
45+
RegisterDefaultEncoder(reflect.Struct, structcodec).
46+
RegisterDefaultDecoder(reflect.Struct, structcodec).
4247
RegisterTypeMapEntry(bsontype.Int32, tInt).
4348
RegisterTypeMapEntry(bsontype.Type(0), tM).
4449
RegisterTypeMapEntry(bsontype.Binary, tByteSlice).

0 commit comments

Comments
 (0)