Skip to content

Commit 237abe9

Browse files
author
iwysiu
committed
GODRIVER-1576 fix non-addr type with pointer implementation for bson (#379)
1 parent a426bf4 commit 237abe9

File tree

4 files changed

+216
-18
lines changed

4 files changed

+216
-18
lines changed

bson/bsoncodec/cond_addr_codec.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
// Copyright (C) MongoDB, Inc. 2017-present.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"); you may
4+
// not use this file except in compliance with the License. You may obtain
5+
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6+
7+
package bsoncodec
8+
9+
import (
10+
"reflect"
11+
12+
"go.mongodb.org/mongo-driver/bson/bsonrw"
13+
)
14+
15+
// condAddrEncoder is the encoder used when a pointer to the encoding value has an encoder.
16+
type condAddrEncoder struct {
17+
canAddrEnc ValueEncoder
18+
elseEnc ValueEncoder
19+
}
20+
21+
var _ ValueEncoder = (*condAddrEncoder)(nil)
22+
23+
// newCondAddrEncoder returns an condAddrEncoder.
24+
func newCondAddrEncoder(canAddrEnc, elseEnc ValueEncoder) *condAddrEncoder {
25+
encoder := condAddrEncoder{canAddrEnc: canAddrEnc, elseEnc: elseEnc}
26+
return &encoder
27+
}
28+
29+
// EncodeValue is the ValueEncoderFunc for a value that may be addressable.
30+
func (cae *condAddrEncoder) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
31+
if val.CanAddr() {
32+
return cae.canAddrEnc.EncodeValue(ec, vw, val)
33+
}
34+
if cae.elseEnc != nil {
35+
return cae.elseEnc.EncodeValue(ec, vw, val)
36+
}
37+
return ErrNoEncoder{Type: val.Type()}
38+
}
39+
40+
// condAddrDecoder is the decoder used when a pointer to the value has a decoder.
41+
type condAddrDecoder struct {
42+
canAddrDec ValueDecoder
43+
elseDec ValueDecoder
44+
}
45+
46+
var _ ValueDecoder = (*condAddrDecoder)(nil)
47+
48+
// newCondAddrDecoder returns an CondAddrDecoder.
49+
func newCondAddrDecoder(canAddrDec, elseDec ValueDecoder) *condAddrDecoder {
50+
decoder := condAddrDecoder{canAddrDec: canAddrDec, elseDec: elseDec}
51+
return &decoder
52+
}
53+
54+
// DecodeValue is the ValueDecoderFunc for a value that may be addressable.
55+
func (cad *condAddrDecoder) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
56+
if val.CanAddr() {
57+
return cad.canAddrDec.DecodeValue(dc, vr, val)
58+
}
59+
if cad.elseDec != nil {
60+
return cad.elseDec.DecodeValue(dc, vr, val)
61+
}
62+
return ErrNoDecoder{Type: val.Type()}
63+
}
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
// Copyright (C) MongoDB, Inc. 2017-present.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"); you may
4+
// not use this file except in compliance with the License. You may obtain
5+
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6+
7+
package bsoncodec
8+
9+
import (
10+
"reflect"
11+
"testing"
12+
13+
"go.mongodb.org/mongo-driver/bson/bsonrw"
14+
"go.mongodb.org/mongo-driver/bson/bsonrw/bsonrwtest"
15+
"go.mongodb.org/mongo-driver/internal/testutil/assert"
16+
)
17+
18+
func TestCondAddrCodec(t *testing.T) {
19+
var inner int
20+
canAddrVal := reflect.ValueOf(&inner)
21+
addressable := canAddrVal.Elem()
22+
unaddressable := reflect.ValueOf(inner)
23+
rw := &bsonrwtest.ValueReaderWriter{}
24+
25+
t.Run("addressEncode", func(t *testing.T) {
26+
invoked := 0
27+
encode1 := func(EncodeContext, bsonrw.ValueWriter, reflect.Value) error {
28+
invoked = 1
29+
return nil
30+
}
31+
encode2 := func(EncodeContext, bsonrw.ValueWriter, reflect.Value) error {
32+
invoked = 2
33+
return nil
34+
}
35+
condEncoder := newCondAddrEncoder(ValueEncoderFunc(encode1), ValueEncoderFunc(encode2))
36+
37+
testCases := []struct {
38+
name string
39+
val reflect.Value
40+
invoked int
41+
}{
42+
{"canAddr", addressable, 1},
43+
{"else", unaddressable, 2},
44+
}
45+
for _, tc := range testCases {
46+
t.Run(tc.name, func(t *testing.T) {
47+
err := condEncoder.EncodeValue(EncodeContext{}, rw, tc.val)
48+
assert.Nil(t, err, "CondAddrEncoder error: %v", err)
49+
50+
assert.Equal(t, invoked, tc.invoked, "Expected function %v to be called, called %v", tc.invoked, invoked)
51+
})
52+
}
53+
54+
t.Run("error", func(t *testing.T) {
55+
errEncoder := newCondAddrEncoder(ValueEncoderFunc(encode1), nil)
56+
err := errEncoder.EncodeValue(EncodeContext{}, rw, unaddressable)
57+
want := ErrNoEncoder{Type: unaddressable.Type()}
58+
assert.Equal(t, err, want, "expected error %v, got %v", want, err)
59+
})
60+
})
61+
t.Run("addressDecode", func(t *testing.T) {
62+
invoked := 0
63+
decode1 := func(dctx DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
64+
invoked = 1
65+
return nil
66+
}
67+
decode2 := func(dctx DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
68+
invoked = 2
69+
return nil
70+
}
71+
condDecoder := newCondAddrDecoder(ValueDecoderFunc(decode1), ValueDecoderFunc(decode2))
72+
73+
testCases := []struct {
74+
name string
75+
val reflect.Value
76+
invoked int
77+
}{
78+
{"canAddr", addressable, 1},
79+
{"else", unaddressable, 2},
80+
}
81+
for _, tc := range testCases {
82+
t.Run(tc.name, func(t *testing.T) {
83+
err := condDecoder.DecodeValue(DecodeContext{}, rw, tc.val)
84+
assert.Nil(t, err, "CondAddrDecoder error: %v", err)
85+
86+
assert.Equal(t, invoked, tc.invoked, "Expected function %v to be called, called %v", tc.invoked, invoked)
87+
})
88+
}
89+
90+
t.Run("error", func(t *testing.T) {
91+
errDecoder := newCondAddrDecoder(ValueDecoderFunc(decode1), nil)
92+
err := errDecoder.DecodeValue(DecodeContext{}, rw, unaddressable)
93+
want := ErrNoDecoder{Type: unaddressable.Type()}
94+
assert.Equal(t, err, want, "expected error %v, got %v", want, err)
95+
})
96+
})
97+
}

bson/bsoncodec/registry.go

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ func (r *Registry) LookupEncoder(t reflect.Type) (ValueEncoder, error) {
325325
return enc, nil
326326
}
327327

328-
enc, found = r.lookupInterfaceEncoder(t)
328+
enc, found = r.lookupInterfaceEncoder(t, true)
329329
if found {
330330
r.mu.Lock()
331331
r.typeEncoders[t] = enc
@@ -359,14 +359,23 @@ func (r *Registry) lookupTypeEncoder(t reflect.Type) (ValueEncoder, bool) {
359359
return enc, found
360360
}
361361

362-
func (r *Registry) lookupInterfaceEncoder(t reflect.Type) (ValueEncoder, bool) {
362+
func (r *Registry) lookupInterfaceEncoder(t reflect.Type, allowAddr bool) (ValueEncoder, bool) {
363363
if t == nil {
364364
return nil, false
365365
}
366366
for _, ienc := range r.interfaceEncoders {
367-
if t.Implements(ienc.i) || reflect.PtrTo(t).Implements(ienc.i) {
367+
if t.Implements(ienc.i) {
368368
return ienc.ve, true
369369
}
370+
if allowAddr && t.Kind() != reflect.Ptr && reflect.PtrTo(t).Implements(ienc.i) {
371+
// if *t implements an interface, this will catch if t implements an interface further ahead
372+
// in interfaceEncoders
373+
defaultEnc, found := r.lookupInterfaceEncoder(t, false)
374+
if !found {
375+
defaultEnc, _ = r.kindEncoders[t.Kind()]
376+
}
377+
return newCondAddrEncoder(ienc.ve, defaultEnc), true
378+
}
370379
}
371380
return nil, false
372381
}
@@ -397,7 +406,7 @@ func (r *Registry) LookupDecoder(t reflect.Type) (ValueDecoder, error) {
397406
return dec, nil
398407
}
399408

400-
dec, found = r.lookupInterfaceDecoder(t)
409+
dec, found = r.lookupInterfaceDecoder(t, true)
401410
if found {
402411
r.mu.Lock()
403412
r.typeDecoders[t] = dec
@@ -424,13 +433,20 @@ func (r *Registry) lookupTypeDecoder(t reflect.Type) (ValueDecoder, bool) {
424433
return dec, found
425434
}
426435

427-
func (r *Registry) lookupInterfaceDecoder(t reflect.Type) (ValueDecoder, bool) {
436+
func (r *Registry) lookupInterfaceDecoder(t reflect.Type, allowAddr bool) (ValueDecoder, bool) {
428437
for _, idec := range r.interfaceDecoders {
429-
if !t.Implements(idec.i) && !reflect.PtrTo(t).Implements(idec.i) {
430-
continue
438+
if t.Implements(idec.i) {
439+
return idec.vd, true
440+
}
441+
if allowAddr && t.Kind() != reflect.Ptr && reflect.PtrTo(t).Implements(idec.i) {
442+
// if *t implements an interface, this will catch if t implements an interface further ahead
443+
// in interfaceDecoders
444+
defaultDec, found := r.lookupInterfaceDecoder(t, false)
445+
if !found {
446+
defaultDec, _ = r.kindDecoders[t.Kind()]
447+
}
448+
return newCondAddrDecoder(idec.vd, defaultDec), true
431449
}
432-
433-
return idec.vd, true
434450
}
435451
return nil, false
436452
}

bson/bsoncodec/registry_test.go

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"github.com/google/go-cmp/cmp"
1414
"go.mongodb.org/mongo-driver/bson/bsonrw"
1515
"go.mongodb.org/mongo-driver/bson/bsontype"
16+
"go.mongodb.org/mongo-driver/internal/testutil/assert"
1617
)
1718

1819
func TestRegistry(t *testing.T) {
@@ -251,15 +252,6 @@ func TestRegistry(t *testing.T) {
251252
nil,
252253
false,
253254
},
254-
{
255-
// lookup a type whose pointer implements an interface and expect that the registered hook is
256-
// returned
257-
"interface implementation with hook (pointer)",
258-
ti3Impl,
259-
fc3,
260-
nil,
261-
false,
262-
},
263255
{
264256
// lookup a pointer to a type where the pointer implements an interface and expect that the
265257
// registered hook is returned
@@ -351,6 +343,36 @@ func TestRegistry(t *testing.T) {
351343
})
352344
})
353345
}
346+
// lookup a type whose pointer implements an interface and expect that the registered hook is
347+
// returned
348+
t.Run("interface implementation with hook (pointer)", func(t *testing.T) {
349+
t.Run("Encoder", func(t *testing.T) {
350+
gotEnc, err := reg.LookupEncoder(ti3Impl)
351+
assert.Nil(t, err, "LookupEncoder error: %v", err)
352+
353+
cae, ok := gotEnc.(*condAddrEncoder)
354+
assert.True(t, ok, "Expected CondAddrEncoder, got %T", gotEnc)
355+
if !cmp.Equal(cae.canAddrEnc, fc3, allowunexported, cmp.Comparer(comparepc)) {
356+
t.Errorf("expected canAddrEnc %v, got %v", cae.canAddrEnc, fc3)
357+
}
358+
if !cmp.Equal(cae.elseEnc, fsc, allowunexported, cmp.Comparer(comparepc)) {
359+
t.Errorf("expected elseEnc %v, got %v", cae.elseEnc, fsc)
360+
}
361+
})
362+
t.Run("Decoder", func(t *testing.T) {
363+
gotDec, err := reg.LookupDecoder(ti3Impl)
364+
assert.Nil(t, err, "LookupDecoder error: %v", err)
365+
366+
cad, ok := gotDec.(*condAddrDecoder)
367+
assert.True(t, ok, "Expected CondAddrDecoder, got %T", gotDec)
368+
if !cmp.Equal(cad.canAddrDec, fc3, allowunexported, cmp.Comparer(comparepc)) {
369+
t.Errorf("expected canAddrDec %v, got %v", cad.canAddrDec, fc3)
370+
}
371+
if !cmp.Equal(cad.elseDec, fsc, allowunexported, cmp.Comparer(comparepc)) {
372+
t.Errorf("expected elseDec %v, got %v", cad.elseDec, fsc)
373+
}
374+
})
375+
})
354376
})
355377
})
356378
t.Run("Type Map", func(t *testing.T) {

0 commit comments

Comments
 (0)