Skip to content

Commit 788afa3

Browse files
author
Isabella Siu
committed
GODRIVER-650 Add MarshalBSON and UnmarshalBSON to DefaultValueEncoders and DefaultValueDecoders
Change-Id: Ie989f65078a48062a8109ee118d389e1a70c490a
1 parent c091973 commit 788afa3

File tree

5 files changed

+134
-0
lines changed

5 files changed

+134
-0
lines changed

bson/bsoncodec/default_value_decoders.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ func (dvd DefaultValueDecoders) RegisterDefaultDecoders(rb *RegistryBuilder) {
4848
RegisterDecoder(reflect.PtrTo(tJSONNumber), ValueDecoderFunc(dvd.JSONNumberDecodeValue)).
4949
RegisterDecoder(reflect.PtrTo(tURL), ValueDecoderFunc(dvd.URLDecodeValue)).
5050
RegisterDecoder(tValueUnmarshaler, ValueDecoderFunc(dvd.ValueUnmarshalerDecodeValue)).
51+
RegisterDecoder(tUnmarshaler, ValueDecoderFunc(dvd.UnmarshalerDecodeValue)).
5152
RegisterDefaultDecoder(reflect.Bool, ValueDecoderFunc(dvd.BooleanDecodeValue)).
5253
RegisterDefaultDecoder(reflect.Int, ValueDecoderFunc(dvd.IntDecodeValue)).
5354
RegisterDefaultDecoder(reflect.Int8, ValueDecoderFunc(dvd.IntDecodeValue)).
@@ -735,6 +736,32 @@ func (dvd DefaultValueDecoders) ValueUnmarshalerDecodeValue(dc DecodeContext, vr
735736
return valueUnmarshaler.UnmarshalBSONValue(t, src)
736737
}
737738

739+
// UnmarshalerDecodeValue is the ValueDecoderFunc for Unmarshaler implementations.
740+
func (dvd DefaultValueDecoders) UnmarshalerDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, i interface{}) error {
741+
val := reflect.ValueOf(i)
742+
var unmarshaler Unmarshaler
743+
if val.Kind() == reflect.Ptr && val.IsNil() {
744+
return fmt.Errorf("UnmarshalerDecodeValue can only unmarshal into non-nil Unmarshaler values, got %T", i)
745+
}
746+
if val.Type().Implements(tUnmarshaler) {
747+
unmarshaler = val.Interface().(Unmarshaler)
748+
} else if val.Type().Kind() == reflect.Ptr && val.Elem().Type().Implements(tUnmarshaler) {
749+
if val.Elem().Kind() == reflect.Ptr && val.Elem().IsNil() {
750+
val.Elem().Set(reflect.New(val.Type().Elem().Elem()))
751+
}
752+
unmarshaler = val.Elem().Interface().(Unmarshaler)
753+
} else {
754+
return fmt.Errorf("UnmarshalerDecodeValue can only handle types or pointers to types that are a Unmarshaler, got %T", i)
755+
}
756+
757+
_, src, err := bsonrw.Copier{}.CopyValueToBytes(vr)
758+
if err != nil {
759+
return err
760+
}
761+
762+
return unmarshaler.UnmarshalBSON(src)
763+
}
764+
738765
// EmptyInterfaceDecodeValue is the ValueDecoderFunc for interface{}.
739766
func (dvd DefaultValueDecoders) EmptyInterfaceDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, i interface{}) error {
740767
target, ok := i.(*interface{})

bson/bsoncodec/default_value_decoders_test.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ func TestDefaultValueDecoders(t *testing.T) {
5252
now := time.Now().Truncate(time.Millisecond)
5353
d128 := decimal.NewDecimal128(12345, 67890)
5454
var ptrPtrValueUnmarshaler **testValueUnmarshaler
55+
var ptrPtrUnmarshaler **testUnmarshaler
5556

5657
type subtest struct {
5758
name string
@@ -1099,6 +1100,36 @@ func TestDefaultValueDecoders(t *testing.T) {
10991100
},
11001101
},
11011102
},
1103+
{
1104+
"UnmarshalerDecodeValue",
1105+
ValueDecoderFunc(dvd.UnmarshalerDecodeValue),
1106+
[]subtest{
1107+
{
1108+
"wrong type",
1109+
wrong,
1110+
nil,
1111+
nil,
1112+
bsonrwtest.Nothing,
1113+
fmt.Errorf("UnmarshalerDecodeValue can only handle types or pointers to types that are a Unmarshaler, got %T", &wrong),
1114+
},
1115+
{
1116+
"Unmarshaler",
1117+
testUnmarshaler{Val: bsoncore.AppendDouble(nil, 3.14159)},
1118+
nil,
1119+
&bsonrwtest.ValueReaderWriter{BSONType: bsontype.Double, Return: float64(3.14159)},
1120+
bsonrwtest.ReadDouble,
1121+
nil,
1122+
},
1123+
{
1124+
"nil pointer to Unmarshaler",
1125+
ptrPtrUnmarshaler,
1126+
nil,
1127+
&bsonrwtest.ValueReaderWriter{BSONType: bsontype.Double, Return: float64(3.14159)},
1128+
bsonrwtest.Nothing,
1129+
fmt.Errorf("UnmarshalerDecodeValue can only unmarshal into non-nil Unmarshaler values, got %T", ptrPtrUnmarshaler),
1130+
},
1131+
},
1132+
},
11021133
}
11031134

11041135
for _, tc := range testCases {
@@ -1698,6 +1729,17 @@ func (tvu *testValueUnmarshaler) UnmarshalBSONValue(t bsontype.Type, val []byte)
16981729
tvu.t, tvu.val = t, val
16991730
return tvu.err
17001731
}
1732+
1733+
type testUnmarshaler struct {
1734+
Val []byte
1735+
Err error
1736+
}
1737+
1738+
func (tvu *testUnmarshaler) UnmarshalBSON(val []byte) error {
1739+
tvu.Val = val
1740+
return tvu.Err
1741+
}
1742+
17011743
func (tvu testValueUnmarshaler) Equal(tvu2 testValueUnmarshaler) bool {
17021744
return tvu.t == tvu2.t && bytes.Equal(tvu.val, tvu2.val)
17031745
}

bson/bsoncodec/default_value_encoders.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"time"
1717

1818
"github.com/mongodb/mongo-go-driver/bson/bsonrw"
19+
"github.com/mongodb/mongo-go-driver/bson/bsontype"
1920
"github.com/mongodb/mongo-go-driver/bson/decimal"
2021
"github.com/mongodb/mongo-go-driver/bson/objectid"
2122
)
@@ -41,6 +42,7 @@ func (dve DefaultValueEncoders) RegisterDefaultEncoders(rb *RegistryBuilder) {
4142
RegisterEncoder(tJSONNumber, ValueEncoderFunc(dve.JSONNumberEncodeValue)).
4243
RegisterEncoder(tURL, ValueEncoderFunc(dve.URLEncodeValue)).
4344
RegisterEncoder(tValueMarshaler, ValueEncoderFunc(dve.ValueMarshalerEncodeValue)).
45+
RegisterEncoder(tMarshaler, ValueEncoderFunc(dve.MarshalerEncodeValue)).
4446
RegisterEncoder(tProxy, ValueEncoderFunc(dve.ProxyEncodeValue)).
4547
RegisterDefaultEncoder(reflect.Bool, ValueEncoderFunc(dve.BooleanEncodeValue)).
4648
RegisterDefaultEncoder(reflect.Int, ValueEncoderFunc(dve.IntEncodeValue)).
@@ -468,6 +470,24 @@ func (dve DefaultValueEncoders) ValueMarshalerEncodeValue(ec EncodeContext, vw b
468470
return bsonrw.Copier{}.CopyValueFromBytes(vw, t, val)
469471
}
470472

473+
// MarshalerEncodeValue is the ValueEncoderFunc for Marshaler implementations.
474+
func (dve DefaultValueEncoders) MarshalerEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, i interface{}) error {
475+
vm, ok := i.(Marshaler)
476+
if !ok {
477+
return ValueEncoderError{
478+
Name: "MarshalerEncodeValue",
479+
Types: []interface{}{(Marshaler)(nil)},
480+
Received: i,
481+
}
482+
}
483+
484+
val, err := vm.MarshalBSON()
485+
if err != nil {
486+
return err
487+
}
488+
return bsonrw.Copier{}.CopyValueFromBytes(vw, bsontype.EmbeddedDocument, val)
489+
}
490+
471491
// ProxyEncodeValue is the ValueEncoderFunc for Proxy implementations.
472492
func (dve DefaultValueEncoders) ProxyEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, i interface{}) error {
473493
proxy, ok := i.(Proxy)

bson/bsoncodec/default_value_encoders_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,40 @@ func TestDefaultValueEncoders(t *testing.T) {
472472
},
473473
},
474474
},
475+
{
476+
"MarshalerEncodeValue",
477+
ValueEncoderFunc(dve.MarshalerEncodeValue),
478+
[]subtest{
479+
{
480+
"wrong type",
481+
wrong,
482+
nil,
483+
nil,
484+
bsonrwtest.Nothing,
485+
ValueEncoderError{
486+
Name: "MarshalerEncodeValue",
487+
Types: []interface{}{(ValueMarshaler)(nil)},
488+
Received: wrong,
489+
},
490+
},
491+
{
492+
"MarshalBSON error",
493+
testMarshaler{err: errors.New("mbson error")},
494+
nil,
495+
nil,
496+
bsonrwtest.Nothing,
497+
errors.New("mbson error"),
498+
},
499+
{
500+
"success",
501+
testMarshaler{buf: bsoncore.BuildDocument(nil, bsoncore.AppendDoubleElement(nil, "pi", 3.14159))},
502+
nil,
503+
nil,
504+
bsonrwtest.WriteDocumentEnd,
505+
nil,
506+
},
507+
},
508+
},
475509
{
476510
"ProxyEncodeValue",
477511
ValueEncoderFunc(dve.ProxyEncodeValue),
@@ -998,6 +1032,15 @@ func (tvm testValueMarshaler) MarshalBSONValue() (bsontype.Type, []byte, error)
9981032
return tvm.t, tvm.buf, tvm.err
9991033
}
10001034

1035+
type testMarshaler struct {
1036+
buf []byte
1037+
err error
1038+
}
1039+
1040+
func (tvm testMarshaler) MarshalBSON() ([]byte, error) {
1041+
return tvm.buf, tvm.err
1042+
}
1043+
10011044
type testProxy struct {
10021045
ret interface{}
10031046
err error

bson/bsoncodec/types.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,4 +57,6 @@ var tJSONNumber = reflect.TypeOf(json.Number(""))
5757

5858
var tValueMarshaler = reflect.TypeOf((*ValueMarshaler)(nil)).Elem()
5959
var tValueUnmarshaler = reflect.TypeOf((*ValueUnmarshaler)(nil)).Elem()
60+
var tMarshaler = reflect.TypeOf((*Marshaler)(nil)).Elem()
61+
var tUnmarshaler = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
6062
var tProxy = reflect.TypeOf((*Proxy)(nil)).Elem()

0 commit comments

Comments
 (0)