Skip to content

Commit 8857a04

Browse files
charlieviethqingyang-hu
authored andcommitted
GODRIVER-2887 Remove use of reflect.Value.MethodByName in bson (#1308)
1 parent 9318bc2 commit 8857a04

File tree

4 files changed

+76
-49
lines changed

4 files changed

+76
-49
lines changed

bson/bsoncodec/default_value_decoders.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1540,12 +1540,12 @@ func (dvd DefaultValueDecoders) ValueUnmarshalerDecodeValue(_ DecodeContext, vr
15401540
return err
15411541
}
15421542

1543-
fn := val.Convert(tValueUnmarshaler).MethodByName("UnmarshalBSONValue")
1544-
errVal := fn.Call([]reflect.Value{reflect.ValueOf(t), reflect.ValueOf(src)})[0]
1545-
if !errVal.IsNil() {
1546-
return errVal.Interface().(error)
1543+
m, ok := val.Interface().(ValueUnmarshaler)
1544+
if !ok {
1545+
// NB: this error should be unreachable due to the above checks
1546+
return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val}
15471547
}
1548-
return nil
1548+
return m.UnmarshalBSONValue(t, src)
15491549
}
15501550

15511551
// UnmarshalerDecodeValue is the ValueDecoderFunc for Unmarshaler implementations.
@@ -1588,12 +1588,12 @@ func (dvd DefaultValueDecoders) UnmarshalerDecodeValue(_ DecodeContext, vr bsonr
15881588
val = val.Addr() // If the type doesn't implement the interface, a pointer to it must.
15891589
}
15901590

1591-
fn := val.Convert(tUnmarshaler).MethodByName("UnmarshalBSON")
1592-
errVal := fn.Call([]reflect.Value{reflect.ValueOf(src)})[0]
1593-
if !errVal.IsNil() {
1594-
return errVal.Interface().(error)
1591+
m, ok := val.Interface().(Unmarshaler)
1592+
if !ok {
1593+
// NB: this error should be unreachable due to the above checks
1594+
return ValueDecoderError{Name: "UnmarshalerDecodeValue", Types: []reflect.Type{tUnmarshaler}, Received: val}
15951595
}
1596-
return nil
1596+
return m.UnmarshalBSON(src)
15971597
}
15981598

15991599
// EmptyInterfaceDecodeValue is the ValueDecoderFunc for interface{}.

bson/bsoncodec/default_value_decoders_test.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1530,13 +1530,22 @@ func TestDefaultValueDecoders(t *testing.T) {
15301530
errors.New("copy error"),
15311531
},
15321532
{
1533-
"Unmarshaler",
1533+
// Only the pointer form of testUnmarshaler implements Unmarshaler
1534+
"value does not implement Unmarshaler",
15341535
testUnmarshaler{Val: bsoncore.AppendDouble(nil, 3.14159)},
15351536
nil,
15361537
&bsonrwtest.ValueReaderWriter{BSONType: bsontype.Double, Return: float64(3.14159)},
15371538
bsonrwtest.ReadDouble,
15381539
nil,
15391540
},
1541+
{
1542+
"Unmarshaler",
1543+
&testUnmarshaler{Val: bsoncore.AppendDouble(nil, 3.14159)},
1544+
nil,
1545+
&bsonrwtest.ValueReaderWriter{BSONType: bsontype.Double, Return: float64(3.14159)},
1546+
bsonrwtest.ReadDouble,
1547+
nil,
1548+
},
15401549
},
15411550
},
15421551
{

bson/bsoncodec/default_value_encoders.go

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -564,12 +564,14 @@ func (dve DefaultValueEncoders) ValueMarshalerEncodeValue(_ EncodeContext, vw bs
564564
return ValueEncoderError{Name: "ValueMarshalerEncodeValue", Types: []reflect.Type{tValueMarshaler}, Received: val}
565565
}
566566

567-
fn := val.Convert(tValueMarshaler).MethodByName("MarshalBSONValue")
568-
returns := fn.Call(nil)
569-
if !returns[2].IsNil() {
570-
return returns[2].Interface().(error)
567+
m, ok := val.Interface().(ValueMarshaler)
568+
if !ok {
569+
return vw.WriteNull()
570+
}
571+
t, data, err := m.MarshalBSONValue()
572+
if err != nil {
573+
return err
571574
}
572-
t, data := returns[0].Interface().(bsontype.Type), returns[1].Interface().([]byte)
573575
return bsonrw.Copier{}.CopyValueFromBytes(vw, t, data)
574576
}
575577

@@ -593,12 +595,14 @@ func (dve DefaultValueEncoders) MarshalerEncodeValue(_ EncodeContext, vw bsonrw.
593595
return ValueEncoderError{Name: "MarshalerEncodeValue", Types: []reflect.Type{tMarshaler}, Received: val}
594596
}
595597

596-
fn := val.Convert(tMarshaler).MethodByName("MarshalBSON")
597-
returns := fn.Call(nil)
598-
if !returns[1].IsNil() {
599-
return returns[1].Interface().(error)
598+
m, ok := val.Interface().(Marshaler)
599+
if !ok {
600+
return vw.WriteNull()
601+
}
602+
data, err := m.MarshalBSON()
603+
if err != nil {
604+
return err
600605
}
601-
data := returns[0].Interface().([]byte)
602606
return bsonrw.Copier{}.CopyValueFromBytes(vw, bsontype.EmbeddedDocument, data)
603607
}
604608

@@ -622,23 +626,31 @@ func (dve DefaultValueEncoders) ProxyEncodeValue(ec EncodeContext, vw bsonrw.Val
622626
return ValueEncoderError{Name: "ProxyEncodeValue", Types: []reflect.Type{tProxy}, Received: val}
623627
}
624628

625-
fn := val.Convert(tProxy).MethodByName("ProxyBSON")
626-
returns := fn.Call(nil)
627-
if !returns[1].IsNil() {
628-
return returns[1].Interface().(error)
629+
m, ok := val.Interface().(Proxy)
630+
if !ok {
631+
return vw.WriteNull()
632+
}
633+
v, err := m.ProxyBSON()
634+
if err != nil {
635+
return err
636+
}
637+
if v == nil {
638+
encoder, err := ec.LookupEncoder(nil)
639+
if err != nil {
640+
return err
641+
}
642+
return encoder.EncodeValue(ec, vw, reflect.ValueOf(nil))
629643
}
630-
data := returns[0]
631-
var encoder ValueEncoder
632-
var err error
633-
if data.Elem().IsValid() {
634-
encoder, err = ec.LookupEncoder(data.Elem().Type())
635-
} else {
636-
encoder, err = ec.LookupEncoder(nil)
644+
vv := reflect.ValueOf(v)
645+
switch vv.Kind() {
646+
case reflect.Ptr, reflect.Interface:
647+
vv = vv.Elem()
637648
}
649+
encoder, err := ec.LookupEncoder(vv.Type())
638650
if err != nil {
639651
return err
640652
}
641-
return encoder.EncodeValue(ec, vw, data.Elem())
653+
return encoder.EncodeValue(ec, vw, vv)
642654
}
643655

644656
// JavaScriptEncodeValue is the ValueEncoderFunc for the primitive.JavaScript type.

bson/mgocompat/setter_getter.go

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
package mgocompat
88

99
import (
10+
"errors"
1011
"reflect"
1112

1213
"go.mongodb.org/mongo-driver/bson"
@@ -73,16 +74,15 @@ func SetterDecodeValue(_ bsoncodec.DecodeContext, vr bsonrw.ValueReader, val ref
7374
return err
7475
}
7576

76-
fn := val.Convert(tSetter).MethodByName("SetBSON")
77-
78-
errVal := fn.Call([]reflect.Value{reflect.ValueOf(bson.RawValue{Type: t, Value: src})})[0]
79-
if !errVal.IsNil() {
80-
err = errVal.Interface().(error)
81-
if err == ErrSetZero {
82-
val.Set(reflect.Zero(val.Type()))
83-
return nil
77+
m, ok := val.Interface().(Setter)
78+
if !ok {
79+
return bsoncodec.ValueDecoderError{Name: "SetterDecodeValue", Types: []reflect.Type{tSetter}, Received: val}
80+
}
81+
if err := m.SetBSON(bson.RawValue{Type: t, Value: src}); err != nil {
82+
if !errors.Is(err, ErrSetZero) {
83+
return err
8484
}
85-
return err
85+
val.Set(reflect.Zero(val.Type()))
8686
}
8787
return nil
8888
}
@@ -104,17 +104,23 @@ func GetterEncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val re
104104
return bsoncodec.ValueEncoderError{Name: "GetterEncodeValue", Types: []reflect.Type{tGetter}, Received: val}
105105
}
106106

107-
fn := val.Convert(tGetter).MethodByName("GetBSON")
108-
returns := fn.Call(nil)
109-
if !returns[1].IsNil() {
110-
return returns[1].Interface().(error)
107+
m, ok := val.Interface().(Getter)
108+
if !ok {
109+
return vw.WriteNull()
110+
}
111+
x, err := m.GetBSON()
112+
if err != nil {
113+
return err
114+
}
115+
if x == nil {
116+
return vw.WriteNull()
111117
}
112-
intermediate := returns[0]
113-
encoder, err := ec.Registry.LookupEncoder(intermediate.Type())
118+
vv := reflect.ValueOf(x)
119+
encoder, err := ec.Registry.LookupEncoder(vv.Type())
114120
if err != nil {
115121
return err
116122
}
117-
return encoder.EncodeValue(ec, vw, intermediate)
123+
return encoder.EncodeValue(ec, vw, vv)
118124
}
119125

120126
// isImplementationNil returns if val is a nil pointer and inter is implemented on a concrete type

0 commit comments

Comments
 (0)