Skip to content

Commit 85a8e36

Browse files
author
iwysiu
authored
GODRIVER-1294 support marshaling interfaces to bson (#206)
1 parent e82d777 commit 85a8e36

File tree

3 files changed

+116
-16
lines changed

3 files changed

+116
-16
lines changed

bson/bsoncodec/default_value_encoders.go

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -289,8 +289,9 @@ func (dve DefaultValueEncoders) MapEncodeValue(ec EncodeContext, vw bsonrw.Value
289289
// struct codec.
290290
func (dve DefaultValueEncoders) mapEncodeValue(ec EncodeContext, dw bsonrw.DocumentWriter, val reflect.Value, collisionFn func(string) bool) error {
291291

292-
encoder, err := ec.LookupEncoder(val.Type().Elem())
293-
if err != nil {
292+
elemType := val.Type().Elem()
293+
encoder, err := ec.LookupEncoder(elemType)
294+
if err != nil && elemType.Kind() != reflect.Interface {
294295
return err
295296
}
296297

@@ -299,19 +300,25 @@ func (dve DefaultValueEncoders) mapEncodeValue(ec EncodeContext, dw bsonrw.Docum
299300
if collisionFn != nil && collisionFn(key.String()) {
300301
return fmt.Errorf("Key %s of inlined map conflicts with a struct field name", key)
301302
}
303+
304+
currEncoder, currVal, err := dve.lookupElementEncoder(ec, encoder, val.MapIndex(key))
305+
if err != nil {
306+
return err
307+
}
308+
302309
vw, err := dw.WriteDocumentElement(key.String())
303310
if err != nil {
304311
return err
305312
}
306313

307-
if enc, ok := encoder.(ValueEncoder); ok {
308-
err = enc.EncodeValue(ec, vw, val.MapIndex(key))
314+
if enc, ok := currEncoder.(ValueEncoder); ok {
315+
err = enc.EncodeValue(ec, vw, currVal)
309316
if err != nil {
310317
return err
311318
}
312319
continue
313320
}
314-
err = encoder.EncodeValue(ec, vw, val.MapIndex(key))
321+
err = encoder.EncodeValue(ec, vw, currVal)
315322
if err != nil {
316323
return err
317324
}
@@ -349,18 +356,24 @@ func (dve DefaultValueEncoders) ArrayEncodeValue(ec EncodeContext, vw bsonrw.Val
349356
return err
350357
}
351358

352-
encoder, err := ec.LookupEncoder(val.Type().Elem())
353-
if err != nil {
359+
elemType := val.Type().Elem()
360+
encoder, err := ec.LookupEncoder(elemType)
361+
if err != nil && elemType.Kind() != reflect.Interface {
354362
return err
355363
}
356364

357365
for idx := 0; idx < val.Len(); idx++ {
366+
currEncoder, currVal, err := dve.lookupElementEncoder(ec, encoder, val.Index(idx))
367+
if err != nil {
368+
return err
369+
}
370+
358371
vw, err := aw.WriteArrayElement()
359372
if err != nil {
360373
return err
361374
}
362375

363-
err = encoder.EncodeValue(ec, vw, val.Index(idx))
376+
err = currEncoder.EncodeValue(ec, vw, currVal)
364377
if err != nil {
365378
return err
366379
}
@@ -402,25 +415,44 @@ func (dve DefaultValueEncoders) SliceEncodeValue(ec EncodeContext, vw bsonrw.Val
402415
return err
403416
}
404417

405-
encoder, err := ec.LookupEncoder(val.Type().Elem())
406-
if err != nil {
418+
elemType := val.Type().Elem()
419+
encoder, err := ec.LookupEncoder(elemType)
420+
if err != nil && elemType.Kind() != reflect.Interface {
407421
return err
408422
}
409423

410424
for idx := 0; idx < val.Len(); idx++ {
425+
currEncoder, currVal, err := dve.lookupElementEncoder(ec, encoder, val.Index(idx))
426+
if err != nil {
427+
return err
428+
}
429+
411430
vw, err := aw.WriteArrayElement()
412431
if err != nil {
413432
return err
414433
}
415434

416-
err = encoder.EncodeValue(ec, vw, val.Index(idx))
435+
err = currEncoder.EncodeValue(ec, vw, currVal)
417436
if err != nil {
418437
return err
419438
}
420439
}
421440
return aw.WriteArrayEnd()
422441
}
423442

443+
func (dve DefaultValueEncoders) lookupElementEncoder(ec EncodeContext, origEncoder ValueEncoder, currVal reflect.Value) (ValueEncoder, reflect.Value, error) {
444+
if origEncoder != nil || (currVal.Kind() != reflect.Interface) {
445+
return origEncoder, currVal, nil
446+
}
447+
currVal = currVal.Elem()
448+
if !currVal.IsValid() {
449+
return nil, currVal, fmt.Errorf("cannot encode invalid element")
450+
}
451+
currEncoder, err := ec.LookupEncoder(currVal.Type())
452+
453+
return currEncoder, currVal, err
454+
}
455+
424456
// EmptyInterfaceEncodeValue is the ValueEncoderFunc for interface{}.
425457
func (dve DefaultValueEncoders) EmptyInterfaceEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
426458
if !val.IsValid() || val.Type() != tEmpty {

bson/bsoncodec/default_value_encoders_test.go

Lines changed: 71 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,18 @@ import (
2424
"math"
2525
)
2626

27+
type myInterface interface {
28+
Foo() int
29+
}
30+
31+
type myStruct struct {
32+
Val int
33+
}
34+
35+
func (ms myStruct) Foo() int {
36+
return ms.Val
37+
}
38+
2739
func TestDefaultValueEncoders(t *testing.T) {
2840
var dve DefaultValueEncoders
2941
var wrong = func(string, string) string { return "wrong" }
@@ -237,11 +249,11 @@ func TestDefaultValueEncoders(t *testing.T) {
237249
},
238250
{
239251
"Lookup Error",
240-
map[string]interface{}{},
252+
map[string]interface{}{"foo": nil},
241253
&EncodeContext{Registry: NewRegistryBuilder().Build()},
242254
&bsonrwtest.ValueReaderWriter{},
243255
bsonrwtest.WriteDocument,
244-
ErrNoEncoder{Type: reflect.TypeOf((*interface{})(nil)).Elem()},
256+
fmt.Errorf("cannot encode invalid element"),
245257
},
246258
{
247259
"WriteDocumentElement Error",
@@ -259,6 +271,22 @@ func TestDefaultValueEncoders(t *testing.T) {
259271
bsonrwtest.WriteString,
260272
errors.New("ev error"),
261273
},
274+
{
275+
"empty map/success",
276+
map[string]interface{}{},
277+
&EncodeContext{Registry: NewRegistryBuilder().Build()},
278+
&bsonrwtest.ValueReaderWriter{},
279+
bsonrwtest.WriteDocumentEnd,
280+
nil,
281+
},
282+
{
283+
"with interface/success",
284+
map[string]myInterface{"foo": myStruct{1}},
285+
&EncodeContext{Registry: buildDefaultRegistry()},
286+
nil,
287+
bsonrwtest.WriteDocumentEnd,
288+
nil,
289+
},
262290
},
263291
},
264292
{
@@ -287,7 +315,7 @@ func TestDefaultValueEncoders(t *testing.T) {
287315
&EncodeContext{Registry: NewRegistryBuilder().Build()},
288316
&bsonrwtest.ValueReaderWriter{},
289317
bsonrwtest.WriteArray,
290-
ErrNoEncoder{Type: reflect.TypeOf((*interface{})(nil)).Elem()},
318+
fmt.Errorf("cannot encode invalid element"),
291319
},
292320
{
293321
"WriteArrayElement Error",
@@ -321,6 +349,14 @@ func TestDefaultValueEncoders(t *testing.T) {
321349
bsonrwtest.WriteDocumentEnd,
322350
nil,
323351
},
352+
{
353+
"[1]interface/success",
354+
[1]myInterface{myStruct{1}},
355+
&EncodeContext{Registry: buildDefaultRegistry()},
356+
nil,
357+
bsonrwtest.WriteArrayEnd,
358+
nil,
359+
},
324360
},
325361
},
326362
{
@@ -345,11 +381,11 @@ func TestDefaultValueEncoders(t *testing.T) {
345381
},
346382
{
347383
"Lookup Error",
348-
[]interface{}{},
384+
[]interface{}{nil},
349385
&EncodeContext{Registry: NewRegistryBuilder().Build()},
350386
&bsonrwtest.ValueReaderWriter{},
351387
bsonrwtest.WriteArray,
352-
ErrNoEncoder{Type: reflect.TypeOf((*interface{})(nil)).Elem()},
388+
fmt.Errorf("cannot encode invalid element"),
353389
},
354390
{
355391
"WriteArrayElement Error",
@@ -383,6 +419,22 @@ func TestDefaultValueEncoders(t *testing.T) {
383419
bsonrwtest.WriteDocumentEnd,
384420
nil,
385421
},
422+
{
423+
"empty slice/success",
424+
[]interface{}{},
425+
&EncodeContext{Registry: NewRegistryBuilder().Build()},
426+
&bsonrwtest.ValueReaderWriter{},
427+
bsonrwtest.WriteArrayEnd,
428+
nil,
429+
},
430+
{
431+
"interface/success",
432+
[]myInterface{myStruct{1}},
433+
&EncodeContext{Registry: buildDefaultRegistry()},
434+
nil,
435+
bsonrwtest.WriteArrayEnd,
436+
nil,
437+
},
386438
},
387439
},
388440
{
@@ -871,6 +923,20 @@ func TestDefaultValueEncoders(t *testing.T) {
871923
},
872924
},
873925
},
926+
{
927+
"StructEncodeValue",
928+
defaultStructCodec,
929+
[]subtest{
930+
{
931+
"interface value",
932+
struct{ Foo myInterface }{Foo: myStruct{1}},
933+
&EncodeContext{Registry: buildDefaultRegistry()},
934+
nil,
935+
bsonrwtest.WriteDocumentEnd,
936+
nil,
937+
},
938+
},
939+
},
874940
{
875941
"CodeWithScopeEncodeValue",
876942
ValueEncoderFunc(dve.CodeWithScopeEncodeValue),

bson/bsoncodec/struct_codec.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ func (sc *StructCodec) EncodeValue(r EncodeContext, vw bsonrw.ValueWriter, val r
7474
rv = val.FieldByIndex(desc.inline)
7575
}
7676

77+
desc.encoder, rv, err = defaultValueEncoders.lookupElementEncoder(r, desc.encoder, rv)
78+
7779
if desc.encoder == nil {
7880
return ErrNoEncoder{Type: rv.Type()}
7981
}

0 commit comments

Comments
 (0)