diff --git a/bson/default_value_decoders.go b/bson/default_value_decoders.go index 7dab1fa5a5..2f195329ca 100644 --- a/bson/default_value_decoders.go +++ b/bson/default_value_decoders.go @@ -1166,6 +1166,12 @@ func valueUnmarshalerDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Va return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val} } + if vr.Type() == TypeNull { + val.Set(reflect.Zero(val.Type())) + + return vr.ReadNull() + } + if val.Kind() == reflect.Ptr && val.IsNil() { if !val.CanSet() { return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val} diff --git a/bson/unmarshaling_cases_test.go b/bson/unmarshaling_cases_test.go index 90e44a31be..71d22f32d6 100644 --- a/bson/unmarshaling_cases_test.go +++ b/bson/unmarshaling_cases_test.go @@ -197,6 +197,25 @@ type unmarshalerNonPtrStruct struct { type myInt64 int64 +var _ ValueUnmarshaler = (*myInt64)(nil) + +func (mi *myInt64) UnmarshalBSONValue(t byte, b []byte) error { + if len(b) == 0 { + return nil + } + + if Type(t) == TypeInt64 { + i, err := newValueReader(TypeInt64, bytes.NewReader(b)).ReadInt64() + if err != nil { + return err + } + + *mi = myInt64(i) + } + + return nil +} + func (mi *myInt64) UnmarshalBSON(b []byte) error { if len(b) == 0 { return nil