Skip to content

Commit c548ef5

Browse files
author
iwysiu
authored
GODRIVER-1407 add MapCodec for mgocompat (#242)
1 parent 3d3d895 commit c548ef5

File tree

10 files changed

+391
-124
lines changed

10 files changed

+391
-124
lines changed

bson/bsoncodec/default_value_decoders.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ func (dvd DefaultValueDecoders) RegisterDefaultDecoders(rb *RegistryBuilder) {
7676
RegisterDefaultDecoder(reflect.Float32, ValueDecoderFunc(dvd.FloatDecodeValue)).
7777
RegisterDefaultDecoder(reflect.Float64, ValueDecoderFunc(dvd.FloatDecodeValue)).
7878
RegisterDefaultDecoder(reflect.Array, ValueDecoderFunc(dvd.ArrayDecodeValue)).
79-
RegisterDefaultDecoder(reflect.Map, ValueDecoderFunc(dvd.MapDecodeValue)).
79+
RegisterDefaultDecoder(reflect.Map, defaultMapCodec).
8080
RegisterDefaultDecoder(reflect.Slice, ValueDecoderFunc(dvd.SliceDecodeValue)).
8181
RegisterDefaultDecoder(reflect.String, defaultStringCodec).
8282
RegisterDefaultDecoder(reflect.Struct, defaultStructCodec).

bson/bsoncodec/default_value_encoders.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ func (dve DefaultValueEncoders) RegisterDefaultEncoders(rb *RegistryBuilder) {
104104
RegisterDefaultEncoder(reflect.Float32, ValueEncoderFunc(dve.FloatEncodeValue)).
105105
RegisterDefaultEncoder(reflect.Float64, ValueEncoderFunc(dve.FloatEncodeValue)).
106106
RegisterDefaultEncoder(reflect.Array, ValueEncoderFunc(dve.ArrayEncodeValue)).
107-
RegisterDefaultEncoder(reflect.Map, ValueEncoderFunc(dve.MapEncodeValue)).
107+
RegisterDefaultEncoder(reflect.Map, defaultMapCodec).
108108
RegisterDefaultEncoder(reflect.Slice, ValueEncoderFunc(dve.SliceEncodeValue)).
109109
RegisterDefaultEncoder(reflect.String, defaultStringCodec).
110110
RegisterDefaultEncoder(reflect.Struct, defaultStructCodec).

bson/bsoncodec/map_codec.go

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
// Copyright (C) MongoDB, Inc. 2017-present.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"); you may
4+
// not use this file except in compliance with the License. You may obtain
5+
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6+
7+
package bsoncodec
8+
9+
import (
10+
"fmt"
11+
"reflect"
12+
"strconv"
13+
14+
"go.mongodb.org/mongo-driver/bson/bsonoptions"
15+
"go.mongodb.org/mongo-driver/bson/bsonrw"
16+
"go.mongodb.org/mongo-driver/bson/bsontype"
17+
)
18+
19+
var defaultMapCodec = NewMapCodec()
20+
21+
// MapCodec is the Codec used for map values.
22+
type MapCodec struct {
23+
DecodeZerosMap bool
24+
}
25+
26+
var _ ValueCodec = &MapCodec{}
27+
28+
// NewMapCodec returns a MapCodec with options opts.
29+
func NewMapCodec(opts ...*bsonoptions.MapCodecOptions) *MapCodec {
30+
mapOpt := bsonoptions.MergeMapCodecOptions(opts...)
31+
32+
codec := MapCodec{}
33+
if mapOpt.DecodeZerosMap != nil {
34+
codec.DecodeZerosMap = *mapOpt.DecodeZerosMap
35+
}
36+
return &codec
37+
}
38+
39+
// EncodeValue is the ValueEncoder for map[*]* types.
40+
func (mc *MapCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
41+
if !val.IsValid() || val.Kind() != reflect.Map {
42+
return ValueEncoderError{Name: "MapEncodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val}
43+
}
44+
45+
if val.IsNil() {
46+
// If we have a nil map but we can't WriteNull, that means we're probably trying to encode
47+
// to a TopLevel document. We can't currently tell if this is what actually happened, but if
48+
// there's a deeper underlying problem, the error will also be returned from WriteDocument,
49+
// so just continue. The operations on a map reflection value are valid, so we can call
50+
// MapKeys within mapEncodeValue without a problem.
51+
err := vw.WriteNull()
52+
if err == nil {
53+
return nil
54+
}
55+
}
56+
57+
dw, err := vw.WriteDocument()
58+
if err != nil {
59+
return err
60+
}
61+
62+
return mc.mapEncodeValue(ec, dw, val, nil)
63+
}
64+
65+
// mapEncodeValue handles encoding of the values of a map. The collisionFn returns
66+
// true if the provided key exists, this is mainly used for inline maps in the
67+
// struct codec.
68+
func (mc *MapCodec) mapEncodeValue(ec EncodeContext, dw bsonrw.DocumentWriter, val reflect.Value, collisionFn func(string) bool) error {
69+
70+
elemType := val.Type().Elem()
71+
encoder, err := ec.LookupEncoder(elemType)
72+
if err != nil && elemType.Kind() != reflect.Interface {
73+
return err
74+
}
75+
76+
keys := val.MapKeys()
77+
for _, key := range keys {
78+
keyStr := fmt.Sprint(key)
79+
if collisionFn != nil && collisionFn(keyStr) {
80+
return fmt.Errorf("Key %s of inlined map conflicts with a struct field name", key)
81+
}
82+
83+
currEncoder, currVal, lookupErr := defaultValueEncoders.lookupElementEncoder(ec, encoder, val.MapIndex(key))
84+
if lookupErr != nil && lookupErr != errInvalidValue {
85+
return lookupErr
86+
}
87+
88+
vw, err := dw.WriteDocumentElement(key.String())
89+
if err != nil {
90+
return err
91+
}
92+
93+
if lookupErr == errInvalidValue {
94+
err = vw.WriteNull()
95+
if err != nil {
96+
return err
97+
}
98+
continue
99+
}
100+
101+
if enc, ok := currEncoder.(ValueEncoder); ok {
102+
err = enc.EncodeValue(ec, vw, currVal)
103+
if err != nil {
104+
return err
105+
}
106+
continue
107+
}
108+
err = encoder.EncodeValue(ec, vw, currVal)
109+
if err != nil {
110+
return err
111+
}
112+
}
113+
114+
return dw.WriteDocumentEnd()
115+
}
116+
117+
// DecodeValue is the ValueDecoder for map[string/decimal]* types.
118+
func (mc *MapCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
119+
if (!val.CanSet() && val.IsNil()) || val.Kind() != reflect.Map {
120+
return ValueDecoderError{Name: "MapDecodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val}
121+
}
122+
123+
switch vr.Type() {
124+
case bsontype.Type(0), bsontype.EmbeddedDocument:
125+
case bsontype.Null:
126+
val.Set(reflect.Zero(val.Type()))
127+
return vr.ReadNull()
128+
default:
129+
return fmt.Errorf("cannot decode %v into a %s", vr.Type(), val.Type())
130+
}
131+
132+
dr, err := vr.ReadDocument()
133+
if err != nil {
134+
return err
135+
}
136+
137+
if val.IsNil() {
138+
val.Set(reflect.MakeMap(val.Type()))
139+
}
140+
141+
if val.Len() > 0 && mc.DecodeZerosMap {
142+
clearMap(val)
143+
}
144+
145+
eType := val.Type().Elem()
146+
decoder, err := dc.LookupDecoder(eType)
147+
if err != nil {
148+
return err
149+
}
150+
151+
if eType == tEmpty {
152+
dc.Ancestor = val.Type()
153+
}
154+
155+
keyType := val.Type().Key()
156+
keyKind := keyType.Kind()
157+
158+
for {
159+
key, vr, err := dr.ReadElement()
160+
if err == bsonrw.ErrEOD {
161+
break
162+
}
163+
if err != nil {
164+
return err
165+
}
166+
167+
elem := reflect.New(eType).Elem()
168+
169+
err = decoder.DecodeValue(dc, vr, elem)
170+
if err != nil {
171+
return err
172+
}
173+
174+
k := reflect.ValueOf(key)
175+
if keyType != tString {
176+
switch keyKind {
177+
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
178+
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
179+
reflect.Float32, reflect.Float64:
180+
parsed, err := strconv.ParseFloat(k.String(), 64)
181+
if err != nil {
182+
return fmt.Errorf("Map key is defined to be a decimal type (%v) but got error %v", keyKind, err)
183+
}
184+
k = reflect.ValueOf(parsed)
185+
case reflect.String: // if keyType wraps string
186+
default:
187+
return fmt.Errorf("BSON map must have string or decimal keys. Got:%v", val)
188+
}
189+
190+
k = k.Convert(keyType)
191+
}
192+
193+
val.SetMapIndex(k, elem)
194+
}
195+
return nil
196+
}
197+
198+
func clearMap(m reflect.Value) {
199+
var none reflect.Value
200+
for _, k := range m.MapKeys() {
201+
m.SetMapIndex(k, none)
202+
}
203+
}

bson/bsoncodec/struct_codec.go

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ type StructCodec struct {
3939
DecodeZeroStruct bool
4040
DecodeDeepZeroInline bool
4141
EncodeOmitDefaultStruct bool
42+
AllowUnexportedFields bool
4243
}
4344

4445
var _ ValueEncoder = &StructCodec{}
@@ -66,6 +67,9 @@ func NewStructCodec(p StructTagParser, opts ...*bsonoptions.StructCodecOptions)
6667
if structOpt.EncodeOmitDefaultStruct != nil {
6768
codec.EncodeOmitDefaultStruct = *structOpt.EncodeOmitDefaultStruct
6869
}
70+
if structOpt.AllowUnexportedFields != nil {
71+
codec.AllowUnexportedFields = *structOpt.AllowUnexportedFields
72+
}
6973

7074
return codec, nil
7175
}
@@ -151,7 +155,7 @@ func (sc *StructCodec) EncodeValue(r EncodeContext, vw bsonrw.ValueWriter, val r
151155
return exists
152156
}
153157

154-
return defaultValueEncoders.mapEncodeValue(r, dw, rv, collisionFn)
158+
return defaultMapCodec.mapEncodeValue(r, dw, rv, collisionFn)
155159
}
156160

157161
return dw.WriteDocumentEnd()
@@ -187,9 +191,6 @@ func (sc *StructCodec) DecodeValue(r DecodeContext, vr bsonrw.ValueReader, val r
187191
var inlineMap reflect.Value
188192
if sd.inlineMap >= 0 {
189193
inlineMap = val.Field(sd.inlineMap)
190-
if inlineMap.IsNil() {
191-
inlineMap.Set(reflect.MakeMap(inlineMap.Type()))
192-
}
193194
decoder, err = r.LookupDecoder(inlineMap.Type().Elem())
194195
if err != nil {
195196
return err
@@ -229,6 +230,10 @@ func (sc *StructCodec) DecodeValue(r DecodeContext, vr bsonrw.ValueReader, val r
229230
continue
230231
}
231232

233+
if inlineMap.IsNil() {
234+
inlineMap.Set(reflect.MakeMap(inlineMap.Type()))
235+
}
236+
232237
elem := reflect.New(inlineMap.Type().Elem()).Elem()
233238
err = decoder.DecodeValue(r, vr, elem)
234239
if err != nil {
@@ -361,8 +366,8 @@ func (sc *StructCodec) describeStruct(r *Registry, t reflect.Type) (*structDescr
361366

362367
for i := 0; i < numFields; i++ {
363368
sf := t.Field(i)
364-
if sf.PkgPath != "" {
365-
// unexported, ignore
369+
if sf.PkgPath != "" && (!sc.AllowUnexportedFields || !sf.Anonymous) {
370+
// field is private or unexported fields aren't allowed, ignore
366371
continue
367372
}
368373

bson/bsonoptions/map_codec_options.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Copyright (C) MongoDB, Inc. 2017-present.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"); you may
4+
// not use this file except in compliance with the License. You may obtain
5+
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6+
7+
package bsonoptions
8+
9+
// MapCodecOptions represents all possible options for map encoding and decoding.
10+
type MapCodecOptions struct {
11+
DecodeZerosMap *bool // Specifies if the map should be zeroed before decoding into it. Defaults to false.
12+
}
13+
14+
// MapCodec creates a new *MapCodecOptions
15+
func MapCodec() *MapCodecOptions {
16+
return &MapCodecOptions{}
17+
}
18+
19+
// SetDecodeZerosMap specifies if the map should be zeroed before decoding into it. Defaults to false.
20+
func (t *MapCodecOptions) SetDecodeZerosMap(b bool) *MapCodecOptions {
21+
t.DecodeZerosMap = &b
22+
return t
23+
}
24+
25+
// MergeMapCodecOptions combines the given *MapCodecOptions into a single *MapCodecOptions in a last one wins fashion.
26+
func MergeMapCodecOptions(opts ...*MapCodecOptions) *MapCodecOptions {
27+
s := MapCodec()
28+
for _, opt := range opts {
29+
if opt == nil {
30+
continue
31+
}
32+
if opt.DecodeZerosMap != nil {
33+
s.DecodeZerosMap = opt.DecodeZerosMap
34+
}
35+
}
36+
37+
return s
38+
}

bson/bsonoptions/string_codec_options.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ type StringCodecOptions struct {
1515

1616
// StringCodec creates a new *StringCodecOptions
1717
func StringCodec() *StringCodecOptions {
18-
return &StringCodecOptions{&defaultDecodeOIDAsHex}
18+
return &StringCodecOptions{}
1919
}
2020

2121
// SetDecodeObjectIDAsHex specifies if object IDs should be decoded as their hex representation. If false, a string made
@@ -27,7 +27,7 @@ func (t *StringCodecOptions) SetDecodeObjectIDAsHex(b bool) *StringCodecOptions
2727

2828
// MergeStringCodecOptions combines the given *StringCodecOptions into a single *StringCodecOptions in a last one wins fashion.
2929
func MergeStringCodecOptions(opts ...*StringCodecOptions) *StringCodecOptions {
30-
s := StringCodec()
30+
s := &StringCodecOptions{&defaultDecodeOIDAsHex}
3131
for _, opt := range opts {
3232
if opt == nil {
3333
continue

bson/bsonoptions/struct_codec_options.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ type StructCodecOptions struct {
1111
DecodeZeroStruct *bool // Specifies if structs should be zeroed before decoding into them. Defaults to false.
1212
DecodeDeepZeroInline *bool // Specifies if structs should be recursively zeroed when a inline value is decoded. Defaults to false.
1313
EncodeOmitDefaultStruct *bool // Specifies if default structs should be considered empty by omitempty. Defaults to false.
14+
AllowUnexportedFields *bool // Specifies if unexported fields should be marshaled/unmarshaled. Defaults to false.
1415
}
1516

1617
// StructCodec creates a new *StructCodecOptions
@@ -37,6 +38,12 @@ func (t *StructCodecOptions) SetEncodeOmitDefaultStruct(b bool) *StructCodecOpti
3738
return t
3839
}
3940

41+
// SetAllowUnexportedFields specifies if unexported fields should be marshaled/unmarshaled. Defaults to false.
42+
func (t *StructCodecOptions) SetAllowUnexportedFields(b bool) *StructCodecOptions {
43+
t.AllowUnexportedFields = &b
44+
return t
45+
}
46+
4047
// MergeStructCodecOptions combines the given *StructCodecOptions into a single *StructCodecOptions in a last one wins fashion.
4148
func MergeStructCodecOptions(opts ...*StructCodecOptions) *StructCodecOptions {
4249
s := StructCodec()
@@ -54,6 +61,9 @@ func MergeStructCodecOptions(opts ...*StructCodecOptions) *StructCodecOptions {
5461
if opt.EncodeOmitDefaultStruct != nil {
5562
s.EncodeOmitDefaultStruct = opt.EncodeOmitDefaultStruct
5663
}
64+
if opt.AllowUnexportedFields != nil {
65+
s.AllowUnexportedFields = opt.AllowUnexportedFields
66+
}
5767
}
5868

5969
return s

bson/decoder.go

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,19 @@ func (d *Decoder) Decode(val interface{}) error {
7878
}
7979

8080
rval := reflect.ValueOf(val)
81-
if rval.Kind() != reflect.Ptr {
82-
return fmt.Errorf("argument to Decode must be a pointer to a type, but got %v", rval)
83-
}
84-
if rval.IsNil() {
85-
return ErrDecodeToNil
81+
switch rval.Kind() {
82+
case reflect.Ptr:
83+
if rval.IsNil() {
84+
return ErrDecodeToNil
85+
}
86+
rval = rval.Elem()
87+
case reflect.Map:
88+
if rval.IsNil() {
89+
return ErrDecodeToNil
90+
}
91+
default:
92+
return fmt.Errorf("argument to Decode must be a pointer or a map, but got %v", rval)
8693
}
87-
rval = rval.Elem()
8894
decoder, err := d.dc.LookupDecoder(rval.Type())
8995
if err != nil {
9096
return err

0 commit comments

Comments
 (0)