diff --git a/bson/array_codec.go b/bson/array_codec.go index 4642fb6ea2..aa65803959 100644 --- a/bson/array_codec.go +++ b/bson/array_codec.go @@ -16,12 +16,12 @@ import ( type arrayCodec struct{} // EncodeValue is the ValueEncoder for bsoncore.Array values. -func (ac *arrayCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tCoreArray { - return ValueEncoderError{Name: "CoreArrayEncodeValue", Types: []reflect.Type{tCoreArray}, Received: val} +func (ac *arrayCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val any) error { + arr, ok := val.(bsoncore.Array) + if !ok { + return ValueEncoderError{Name: "CoreArrayEncodeValue", Types: []reflect.Type{tCoreArray}, Received: reflect.ValueOf(val)} } - arr := val.Interface().(bsoncore.Array) return copyArrayFromBytes(vw, arr) } diff --git a/bson/bsoncodec.go b/bson/bsoncodec.go index 80e13e7d81..1a8ca03104 100644 --- a/bson/bsoncodec.go +++ b/bson/bsoncodec.go @@ -139,6 +139,25 @@ func (fn ValueEncoderFunc) EncodeValue(ec EncodeContext, vw ValueWriter, val ref return fn(ec, vw, val) } +// defaultValueEncoderFunc is an adapter function that allows a function with +// the correct signature to be used as a ValueEncoder. +type defaultValueEncoderFunc func(EncodeContext, ValueWriter, reflect.Value) error + +// EncodeValue implements the ValueEncoder interface. +func (fn defaultValueEncoderFunc) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return fn(ec, vw, val) +} + +type reflectFreeValueEncoder interface { + EncodeValue(ec EncodeContext, vw ValueWriter, val any) error +} + +type reflectFreeValueEncoderFunc func(ec EncodeContext, vw ValueWriter, val any) error + +func (fn reflectFreeValueEncoderFunc) EncodeValue(ec EncodeContext, vw ValueWriter, val any) error { + return fn(ec, vw, val) +} + // ValueDecoder is the interface implemented by types that can decode BSON to a provided Go type. // Implementations should ensure that the value they receive is settable. Similar to ValueEncoderFunc, // ValueDecoderFunc is provided to allow the use of a function with the correct signature as a diff --git a/bson/byte_slice_codec.go b/bson/byte_slice_codec.go index bd44cf9a89..d6d27fcc86 100644 --- a/bson/byte_slice_codec.go +++ b/bson/byte_slice_codec.go @@ -12,28 +12,13 @@ import ( ) // byteSliceCodec is the Codec used for []byte values. -type byteSliceCodec struct { - // encodeNilAsEmpty causes EncodeValue to marshal nil Go byte slices as empty BSON binary values - // instead of BSON null. - encodeNilAsEmpty bool -} +type byteSliceCodec struct{} // Assert that byteSliceCodec satisfies the typeDecoder interface, which allows it to be // used by collection type decoders (e.g. map, slice, etc) to set individual values in a // collection. var _ typeDecoder = &byteSliceCodec{} -// EncodeValue is the ValueEncoder for []byte. -func (bsc *byteSliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tByteSlice { - return ValueEncoderError{Name: "ByteSliceEncodeValue", Types: []reflect.Type{tByteSlice}, Received: val} - } - if val.IsNil() && !bsc.encodeNilAsEmpty && !ec.nilByteSliceAsEmpty { - return vw.WriteNull() - } - return vw.WriteBinary(val.Interface().([]byte)) -} - func (bsc *byteSliceCodec) decodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tByteSlice { return emptyValue, ValueDecoderError{ diff --git a/bson/codec_cache.go b/bson/codec_cache.go index b4042822e6..4fffe0eff1 100644 --- a/bson/codec_cache.go +++ b/bson/codec_cache.go @@ -58,6 +58,39 @@ func (c *typeEncoderCache) Clone() *typeEncoderCache { return cc } +type typeReflectFreeEncoderCache struct { + cache sync.Map // map[reflect.Type]typeReflectFreeEncoderCache +} + +func (c *typeReflectFreeEncoderCache) Store(rt reflect.Type, enc reflectFreeValueEncoder) { + c.cache.Store(rt, enc) +} + +func (c *typeReflectFreeEncoderCache) Load(rt reflect.Type) (reflectFreeValueEncoder, bool) { + if v, _ := c.cache.Load(rt); v != nil { + return v.(reflectFreeValueEncoder), true + } + return nil, false +} + +func (c *typeReflectFreeEncoderCache) LoadOrStore(rt reflect.Type, enc reflectFreeValueEncoder) reflectFreeValueEncoder { + if v, loaded := c.cache.LoadOrStore(rt, enc); loaded { + enc = v.(reflectFreeValueEncoder) + } + return enc +} + +func (c *typeReflectFreeEncoderCache) Clone() *typeReflectFreeEncoderCache { + cc := new(typeReflectFreeEncoderCache) + c.cache.Range(func(k, v interface{}) bool { + if k != nil && v != nil { + cc.cache.Store(k, v) + } + return true + }) + return cc +} + type typeDecoderCache struct { cache sync.Map // map[reflect.Type]ValueDecoder } diff --git a/bson/default_value_decoders_test.go b/bson/default_value_decoders_test.go index 4dad538a26..bb44a4f894 100644 --- a/bson/default_value_decoders_test.go +++ b/bson/default_value_decoders_test.go @@ -3414,20 +3414,22 @@ func TestDefaultValueDecoders(t *testing.T) { // the top-level to decode to registered type when unmarshalling to interface{} topLevelReg := &Registry{ - typeEncoders: new(typeEncoderCache), - typeDecoders: new(typeDecoderCache), - kindEncoders: new(kindEncoderCache), - kindDecoders: new(kindDecoderCache), + typeEncoders: new(typeEncoderCache), + typeDecoders: new(typeDecoderCache), + kindEncoders: new(kindEncoderCache), + kindDecoders: new(kindDecoderCache), + reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), } registerDefaultEncoders(topLevelReg) registerDefaultDecoders(topLevelReg) topLevelReg.RegisterTypeMapEntry(Type(0), reflect.TypeOf(M{})) embeddedReg := &Registry{ - typeEncoders: new(typeEncoderCache), - typeDecoders: new(typeDecoderCache), - kindEncoders: new(kindEncoderCache), - kindDecoders: new(kindDecoderCache), + typeEncoders: new(typeEncoderCache), + typeDecoders: new(typeDecoderCache), + kindEncoders: new(kindEncoderCache), + kindDecoders: new(kindDecoderCache), + reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), } registerDefaultEncoders(embeddedReg) registerDefaultDecoders(embeddedReg) @@ -3470,10 +3472,11 @@ func TestDefaultValueDecoders(t *testing.T) { // type information is not available. reg := &Registry{ - typeEncoders: new(typeEncoderCache), - typeDecoders: new(typeDecoderCache), - kindEncoders: new(kindEncoderCache), - kindDecoders: new(kindDecoderCache), + typeEncoders: new(typeEncoderCache), + typeDecoders: new(typeDecoderCache), + kindEncoders: new(kindEncoderCache), + kindDecoders: new(kindDecoderCache), + reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), } registerDefaultEncoders(reg) registerDefaultDecoders(reg) @@ -3564,10 +3567,11 @@ func TestDefaultValueDecoders(t *testing.T) { // Use a registry that has all default decoders with the custom interface{} decoder that always errors. nestedRegistry := &Registry{ - typeEncoders: new(typeEncoderCache), - typeDecoders: new(typeDecoderCache), - kindEncoders: new(kindEncoderCache), - kindDecoders: new(kindDecoderCache), + typeEncoders: new(typeEncoderCache), + typeDecoders: new(typeDecoderCache), + kindEncoders: new(kindEncoderCache), + kindDecoders: new(kindDecoderCache), + reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), } registerDefaultDecoders(nestedRegistry) nestedRegistry.RegisterTypeDecoder(tEmpty, ValueDecoderFunc(emptyInterfaceErrorDecode)) @@ -3721,10 +3725,11 @@ func TestDefaultValueDecoders(t *testing.T) { ) reg := &Registry{ - typeEncoders: new(typeEncoderCache), - typeDecoders: new(typeDecoderCache), - kindEncoders: new(kindEncoderCache), - kindDecoders: new(kindDecoderCache), + typeEncoders: new(typeEncoderCache), + typeDecoders: new(typeDecoderCache), + kindEncoders: new(kindEncoderCache), + kindDecoders: new(kindDecoderCache), + reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), } registerDefaultDecoders(reg) reg.RegisterTypeMapEntry(TypeBoolean, reflect.TypeOf(mybool(true))) @@ -3795,10 +3800,11 @@ func buildDocument(elems []byte) []byte { func buildDefaultRegistry() *Registry { reg := &Registry{ - typeEncoders: new(typeEncoderCache), - typeDecoders: new(typeDecoderCache), - kindEncoders: new(kindEncoderCache), - kindDecoders: new(kindDecoderCache), + typeEncoders: new(typeEncoderCache), + typeDecoders: new(typeDecoderCache), + kindEncoders: new(kindEncoderCache), + kindDecoders: new(kindDecoderCache), + reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), } registerDefaultEncoders(reg) registerDefaultDecoders(reg) diff --git a/bson/default_value_encoders.go b/bson/default_value_encoders.go index bd5a20f2f9..ccd37e7413 100644 --- a/bson/default_value_encoders.go +++ b/bson/default_value_encoders.go @@ -13,6 +13,7 @@ import ( "net/url" "reflect" "sync" + "time" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" ) @@ -59,28 +60,58 @@ func registerDefaultEncoders(reg *Registry) { mapEncoder := &mapCodec{} uintCodec := &uintCodec{} - reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{}) - reg.RegisterTypeEncoder(tTime, &timeCodec{}) - reg.RegisterTypeEncoder(tEmpty, &emptyInterfaceCodec{}) - reg.RegisterTypeEncoder(tCoreArray, &arrayCodec{}) - reg.RegisterTypeEncoder(tOID, ValueEncoderFunc(objectIDEncodeValue)) - reg.RegisterTypeEncoder(tDecimal, ValueEncoderFunc(decimal128EncodeValue)) - reg.RegisterTypeEncoder(tJSONNumber, ValueEncoderFunc(jsonNumberEncodeValue)) - reg.RegisterTypeEncoder(tURL, ValueEncoderFunc(urlEncodeValue)) - reg.RegisterTypeEncoder(tJavaScript, ValueEncoderFunc(javaScriptEncodeValue)) - reg.RegisterTypeEncoder(tSymbol, ValueEncoderFunc(symbolEncodeValue)) - reg.RegisterTypeEncoder(tBinary, ValueEncoderFunc(binaryEncodeValue)) - reg.RegisterTypeEncoder(tVector, ValueEncoderFunc(vectorEncodeValue)) - reg.RegisterTypeEncoder(tUndefined, ValueEncoderFunc(undefinedEncodeValue)) - reg.RegisterTypeEncoder(tDateTime, ValueEncoderFunc(dateTimeEncodeValue)) - reg.RegisterTypeEncoder(tNull, ValueEncoderFunc(nullEncodeValue)) - reg.RegisterTypeEncoder(tRegex, ValueEncoderFunc(regexEncodeValue)) - reg.RegisterTypeEncoder(tDBPointer, ValueEncoderFunc(dbPointerEncodeValue)) - reg.RegisterTypeEncoder(tTimestamp, ValueEncoderFunc(timestampEncodeValue)) - reg.RegisterTypeEncoder(tMinKey, ValueEncoderFunc(minKeyEncodeValue)) - reg.RegisterTypeEncoder(tMaxKey, ValueEncoderFunc(maxKeyEncodeValue)) - reg.RegisterTypeEncoder(tCoreDocument, ValueEncoderFunc(coreDocumentEncodeValue)) - reg.RegisterTypeEncoder(tCodeWithScope, ValueEncoderFunc(codeWithScopeEncodeValue)) + // Register the reflect-free default type encoders. + reg.registerReflectFreeTypeEncoder(tByteSlice, byteSliceEncodeValueRF(false)) + reg.registerReflectFreeTypeEncoder(tTime, reflectFreeValueEncoderFunc(timeEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tCoreArray, reflectFreeValueEncoderFunc(coreArrayEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tNull, reflectFreeValueEncoderFunc(nullEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tOID, reflectFreeValueEncoderFunc(objectIDEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tDecimal, reflectFreeValueEncoderFunc(decimal128EncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tJSONNumber, reflectFreeValueEncoderFunc(jsonNumberEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tURL, reflectFreeValueEncoderFunc(urlEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tJavaScript, reflectFreeValueEncoderFunc(javaScriptEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tSymbol, reflectFreeValueEncoderFunc(symbolEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tBinary, reflectFreeValueEncoderFunc(binaryEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tVector, reflectFreeValueEncoderFunc(vectorEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tUndefined, reflectFreeValueEncoderFunc(undefinedEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tDateTime, reflectFreeValueEncoderFunc(dateTimeEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tRegex, reflectFreeValueEncoderFunc(regexEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tDBPointer, reflectFreeValueEncoderFunc(dbPointerEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tTimestamp, reflectFreeValueEncoderFunc(timestampEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tMinKey, reflectFreeValueEncoderFunc(minKeyEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tMaxKey, reflectFreeValueEncoderFunc(maxKeyEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tCoreDocument, reflectFreeValueEncoderFunc(coreDocumentEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tCodeWithScope, reflectFreeValueEncoderFunc(codeWithScopeEncodeValueRF)) + + // Register the reflect-based default encoders. These are required since + // removing them would break Registry.LookupEncoder. However, these will + // never be used internally. + // + reg.RegisterTypeEncoder(tByteSlice, byteSliceEncodeValue(false)) + reg.RegisterTypeEncoder(tTime, defaultValueEncoderFunc(timeEncodeValue)) + reg.RegisterTypeEncoder(tEmpty, ValueEncoderFunc(emptyInterfaceValue)) + reg.RegisterTypeEncoder(tCoreArray, defaultValueEncoderFunc(coreArrayEncodeValue)) + reg.RegisterTypeEncoder(tOID, defaultValueEncoderFunc(objectIDEncodeValue)) + reg.RegisterTypeEncoder(tDecimal, defaultValueEncoderFunc(decimal128EncodeValue)) + reg.RegisterTypeEncoder(tJSONNumber, defaultValueEncoderFunc(jsonNumberEncodeValue)) + reg.RegisterTypeEncoder(tURL, defaultValueEncoderFunc(urlEncodeValue)) + reg.RegisterTypeEncoder(tJavaScript, defaultValueEncoderFunc(javaScriptEncodeValue)) + reg.RegisterTypeEncoder(tSymbol, defaultValueEncoderFunc(symbolEncodeValue)) + reg.RegisterTypeEncoder(tBinary, defaultValueEncoderFunc(binaryEncodeValue)) + reg.RegisterTypeEncoder(tVector, defaultValueEncoderFunc(vectorEncodeValue)) + reg.RegisterTypeEncoder(tUndefined, defaultValueEncoderFunc(undefinedEncodeValue)) + reg.RegisterTypeEncoder(tDateTime, defaultValueEncoderFunc(dateTimeEncodeValue)) + reg.RegisterTypeEncoder(tNull, defaultValueEncoderFunc(nullEncodeValue)) + reg.RegisterTypeEncoder(tRegex, defaultValueEncoderFunc(regexEncodeValue)) + reg.RegisterTypeEncoder(tDBPointer, defaultValueEncoderFunc(dbPointerEncodeValue)) + reg.RegisterTypeEncoder(tTimestamp, defaultValueEncoderFunc(timestampEncodeValue)) + reg.RegisterTypeEncoder(tMinKey, defaultValueEncoderFunc(minKeyEncodeValue)) + reg.RegisterTypeEncoder(tMaxKey, defaultValueEncoderFunc(maxKeyEncodeValue)) + reg.RegisterTypeEncoder(tCoreDocument, defaultValueEncoderFunc(coreDocumentEncodeValue)) + reg.RegisterTypeEncoder(tCodeWithScope, defaultValueEncoderFunc(codeWithScopeEncodeValue)) + + // Register the kind-based default encoders. These must continue using + // reflection since they account for custom types that cannot be anticipated. reg.RegisterKindEncoder(reflect.Bool, ValueEncoderFunc(booleanEncodeValue)) reg.RegisterKindEncoder(reflect.Int, ValueEncoderFunc(intEncodeValue)) reg.RegisterKindEncoder(reflect.Int8, ValueEncoderFunc(intEncodeValue)) @@ -100,6 +131,8 @@ func registerDefaultEncoders(reg *Registry) { reg.RegisterKindEncoder(reflect.String, &stringCodec{}) reg.RegisterKindEncoder(reflect.Struct, newStructCodec(mapEncoder)) reg.RegisterKindEncoder(reflect.Ptr, &pointerCodec{}) + + // Register the interface-based default encoders. reg.RegisterInterfaceEncoder(tValueMarshaler, ValueEncoderFunc(valueMarshalerEncodeValue)) reg.RegisterInterfaceEncoder(tMarshaler, ValueEncoderFunc(marshalerEncodeValue)) } @@ -142,7 +175,21 @@ func intEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { } } -// floatEncodeValue is the ValueEncoderFunc for float types. +func floatEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + if f32, ok := val.(float32); ok { + return vw.WriteDouble(float64(f32)) + } + + if f64, ok := val.(float64); ok { + return vw.WriteDouble(f64) + } + + return ValueEncoderError{Name: "FloatEncodeValue", Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, Received: reflect.ValueOf(val)} +} + +// floatEncodeValue is the ValueEncoderFunc for float types. this function is +// used to decode "types" and "kinds" and therefore cannot be a wrapper for +// reflection-free decoding in the default "type" case. func floatEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { switch val.Kind() { case reflect.Float32, reflect.Float64: @@ -153,27 +200,43 @@ func floatEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error } // objectIDEncodeValue is the ValueEncoderFunc for ObjectID. -func objectIDEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tOID { - return ValueEncoderError{Name: "ObjectIDEncodeValue", Types: []reflect.Type{tOID}, Received: val} +func objectIDEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + objID, ok := val.(ObjectID) + if !ok { + return ValueEncoderError{ + Name: "ObjectIDEncodeValue", + Types: []reflect.Type{tOID}, + Received: reflect.ValueOf(val), + } } - return vw.WriteObjectID(val.Interface().(ObjectID)) + + return vw.WriteObjectID(objID) } -// decimal128EncodeValue is the ValueEncoderFunc for Decimal128. -func decimal128EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tDecimal { - return ValueEncoderError{Name: "Decimal128EncodeValue", Types: []reflect.Type{tDecimal}, Received: val} +// objectIDEncodeValue is the ValueEncoderFunc for ObjectID. +func objectIDEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return objectIDEncodeValueRF(ec, vw, val.Interface()) +} + +func decimal128EncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + d128, ok := val.(Decimal128) + if !ok { + return ValueEncoderError{Name: "Decimal128EncodeValue", Types: []reflect.Type{tDecimal}, Received: reflect.ValueOf(val)} } - return vw.WriteDecimal128(val.Interface().(Decimal128)) + + return vw.WriteDecimal128(d128) } -// jsonNumberEncodeValue is the ValueEncoderFunc for json.Number. -func jsonNumberEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tJSONNumber { - return ValueEncoderError{Name: "JSONNumberEncodeValue", Types: []reflect.Type{tJSONNumber}, Received: val} +// decimal128EncodeValue is the ValueEncoderFunc for Decimal128. +func decimal128EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return decimal128EncodeValueRF(ec, vw, val.Interface()) +} + +func jsonNumberEncodeValueRF(ec EncodeContext, vw ValueWriter, val any) error { + jsnum, ok := val.(json.Number) + if !ok { + return ValueEncoderError{Name: "JSONNumberEncodeValue", Types: []reflect.Type{tJSONNumber}, Received: reflect.ValueOf(val)} } - jsnum := val.Interface().(json.Number) // Attempt int first, then float64 if i64, err := jsnum.Int64(); err == nil { @@ -185,18 +248,28 @@ func jsonNumberEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) return err } - return floatEncodeValue(ec, vw, reflect.ValueOf(f64)) + return floatEncodeValueRF(ec, vw, f64) } -// urlEncodeValue is the ValueEncoderFunc for url.URL. -func urlEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tURL { - return ValueEncoderError{Name: "URLEncodeValue", Types: []reflect.Type{tURL}, Received: val} +// jsonNumberEncodeValue is the ValueEncoderFunc for json.Number. +func jsonNumberEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return jsonNumberEncodeValueRF(ec, vw, val.Interface()) +} + +func urlEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + u, ok := val.(url.URL) + if !ok { + return ValueEncoderError{Name: "URLEncodeValue", Types: []reflect.Type{tURL}, Received: reflect.ValueOf(val)} } - u := val.Interface().(url.URL) + return vw.WriteString(u.String()) } +// urlEncodeValue is the ValueEncoderFunc for url.URL. +func urlEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return urlEncodeValueRF(ec, vw, val.Interface()) +} + // arrayEncodeValue is the ValueEncoderFunc for array types. func arrayEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Array { @@ -337,144 +410,190 @@ func marshalerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) er return copyValueFromBytes(vw, TypeEmbeddedDocument, data) } +func javaScriptEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + jsString, ok := val.(JavaScript) + if !ok { + return ValueEncoderError{Name: "JavaScriptEncodeValue", Types: []reflect.Type{tJavaScript}, Received: reflect.ValueOf(val)} + } + + return vw.WriteJavascript(string(jsString)) +} + // javaScriptEncodeValue is the ValueEncoderFunc for the JavaScript type. func javaScriptEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tJavaScript { - return ValueEncoderError{Name: "JavaScriptEncodeValue", Types: []reflect.Type{tJavaScript}, Received: val} + return javaScriptEncodeValueRF(EncodeContext{}, vw, val.Interface()) +} + +func symbolEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + symbol, ok := val.(Symbol) + if !ok { + return ValueEncoderError{Name: "SymbolEncodeValue", Types: []reflect.Type{tSymbol}, Received: reflect.ValueOf(val)} } - return vw.WriteJavascript(val.String()) + return vw.WriteSymbol(string(symbol)) } // symbolEncodeValue is the ValueEncoderFunc for the Symbol type. func symbolEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tSymbol { - return ValueEncoderError{Name: "SymbolEncodeValue", Types: []reflect.Type{tSymbol}, Received: val} + return symbolEncodeValueRF(EncodeContext{}, vw, val.Interface()) +} + +func binaryEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + b, ok := val.(Binary) + if !ok { + return ValueEncoderError{Name: "BinaryEncodeValue", Types: []reflect.Type{tBinary}, Received: reflect.ValueOf(val)} } - return vw.WriteSymbol(val.String()) + return vw.WriteBinaryWithSubtype(b.Data, b.Subtype) } // binaryEncodeValue is the ValueEncoderFunc for Binary. func binaryEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tBinary { - return ValueEncoderError{Name: "BinaryEncodeValue", Types: []reflect.Type{tBinary}, Received: val} + return binaryEncodeValueRF(EncodeContext{}, vw, val.Interface()) +} + +func vectorEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + v, ok := val.(Vector) + if !ok { + return ValueEncoderError{Name: "VectorEncodeValue", Types: []reflect.Type{tVector}, Received: reflect.ValueOf(val)} } - b := val.Interface().(Binary) + b := v.Binary() return vw.WriteBinaryWithSubtype(b.Data, b.Subtype) } // vectorEncodeValue is the ValueEncoderFunc for Vector. func vectorEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - t := val.Type() - if !val.IsValid() || t != tVector { - return ValueEncoderError{Name: "VectorEncodeValue", - Types: []reflect.Type{tVector}, - Received: val, - } + return vectorEncodeValueRF(EncodeContext{}, vw, val.Interface()) +} + +func undefinedEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + if _, ok := val.(Undefined); !ok { + return ValueEncoderError{Name: "UndefinedEncodeValue", Types: []reflect.Type{tUndefined}, Received: reflect.ValueOf(val)} } - v := val.Interface().(Vector) - b := v.Binary() - return vw.WriteBinaryWithSubtype(b.Data, b.Subtype) + + return vw.WriteUndefined() } // undefinedEncodeValue is the ValueEncoderFunc for Undefined. func undefinedEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tUndefined { - return ValueEncoderError{Name: "UndefinedEncodeValue", Types: []reflect.Type{tUndefined}, Received: val} + return undefinedEncodeValueRF(EncodeContext{}, vw, val.Interface()) +} + +func dateTimeEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + dateTime, ok := val.(DateTime) + if !ok { + return ValueEncoderError{Name: "DateTimeEncodeValue", Types: []reflect.Type{tDateTime}, Received: reflect.ValueOf(val)} } - return vw.WriteUndefined() + return vw.WriteDateTime(int64(dateTime)) } // dateTimeEncodeValue is the ValueEncoderFunc for DateTime. func dateTimeEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tDateTime { - return ValueEncoderError{Name: "DateTimeEncodeValue", Types: []reflect.Type{tDateTime}, Received: val} + return dateTimeEncodeValueRF(EncodeContext{}, vw, val.Interface()) +} + +func nullEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + if _, ok := val.(Null); !ok { + return ValueEncoderError{Name: "NullEncodeValue", Types: []reflect.Type{tNull}, Received: reflect.ValueOf(val)} } - return vw.WriteDateTime(val.Int()) + return vw.WriteNull() } // nullEncodeValue is the ValueEncoderFunc for Null. func nullEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tNull { - return ValueEncoderError{Name: "NullEncodeValue", Types: []reflect.Type{tNull}, Received: val} - } - - return vw.WriteNull() + return nullEncodeValueRF(EncodeContext{}, vw, val.Interface()) } -// regexEncodeValue is the ValueEncoderFunc for Regex. -func regexEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tRegex { - return ValueEncoderError{Name: "RegexEncodeValue", Types: []reflect.Type{tRegex}, Received: val} +func regexEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + regex, ok := val.(Regex) + if !ok { + return ValueEncoderError{Name: "RegexEncodeValue", Types: []reflect.Type{tRegex}, Received: reflect.ValueOf(val)} } - regex := val.Interface().(Regex) - return vw.WriteRegex(regex.Pattern, regex.Options) } -// dbPointerEncodeValue is the ValueEncoderFunc for DBPointer. -func dbPointerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tDBPointer { - return ValueEncoderError{Name: "DBPointerEncodeValue", Types: []reflect.Type{tDBPointer}, Received: val} - } +// regexEncodeValue is the ValueEncoderFunc for Regex. +func regexEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return regexEncodeValueRF(ec, vw, val.Interface()) +} - dbp := val.Interface().(DBPointer) +func dbPointerEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + dbp, ok := val.(DBPointer) + if !ok { + return ValueEncoderError{Name: "DBPointerEncodeValue", Types: []reflect.Type{tDBPointer}, Received: reflect.ValueOf(val)} + } return vw.WriteDBPointer(dbp.DB, dbp.Pointer) } -// timestampEncodeValue is the ValueEncoderFunc for Timestamp. -func timestampEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tTimestamp { - return ValueEncoderError{Name: "TimestampEncodeValue", Types: []reflect.Type{tTimestamp}, Received: val} - } +// dbPointerEncodeValue is the ValueEncoderFunc for DBPointer. +func dbPointerEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return dbPointerEncodeValueRF(ec, vw, val.Interface()) +} - ts := val.Interface().(Timestamp) +func timestampEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + ts, ok := val.(Timestamp) + if !ok { + return ValueEncoderError{Name: "TimestampEncodeValue", Types: []reflect.Type{tTimestamp}, Received: reflect.ValueOf(val)} + } return vw.WriteTimestamp(ts.T, ts.I) } -// minKeyEncodeValue is the ValueEncoderFunc for MinKey. -func minKeyEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tMinKey { - return ValueEncoderError{Name: "MinKeyEncodeValue", Types: []reflect.Type{tMinKey}, Received: val} +// timestampEncodeValue is the ValueEncoderFunc for Timestamp. +func timestampEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return timestampEncodeValueRF(ec, vw, val.Interface()) +} + +func minKeyEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + if _, ok := val.(MinKey); !ok { + return ValueEncoderError{Name: "MinKeyEncodeValue", Types: []reflect.Type{tMinKey}, Received: reflect.ValueOf(val)} } return vw.WriteMinKey() } -// maxKeyEncodeValue is the ValueEncoderFunc for MaxKey. -func maxKeyEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tMaxKey { - return ValueEncoderError{Name: "MaxKeyEncodeValue", Types: []reflect.Type{tMaxKey}, Received: val} +// minKeyEncodeValue is the ValueEncoderFunc for MinKey. +func minKeyEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return minKeyEncodeValueRF(ec, vw, val.Interface()) +} + +func maxKeyEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + if _, ok := val.(MaxKey); !ok { + return ValueEncoderError{Name: "MaxKeyEncodeValue", Types: []reflect.Type{tMaxKey}, Received: reflect.ValueOf(val)} } return vw.WriteMaxKey() } -// coreDocumentEncodeValue is the ValueEncoderFunc for bsoncore.Document. -func coreDocumentEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tCoreDocument { - return ValueEncoderError{Name: "CoreDocumentEncodeValue", Types: []reflect.Type{tCoreDocument}, Received: val} - } +// maxKeyEncodeValue is the ValueEncoderFunc for MaxKey. +func maxKeyEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return maxKeyEncodeValueRF(ec, vw, val.Interface()) +} - cdoc := val.Interface().(bsoncore.Document) +func coreDocumentEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + cdoc, ok := val.(bsoncore.Document) + if !ok { + return ValueEncoderError{Name: "CoreDocumentEncodeValue", Types: []reflect.Type{tCoreDocument}, Received: reflect.ValueOf(val)} + } return copyDocumentFromBytes(vw, cdoc) } -// codeWithScopeEncodeValue is the ValueEncoderFunc for CodeWithScope. -func codeWithScopeEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tCodeWithScope { - return ValueEncoderError{Name: "CodeWithScopeEncodeValue", Types: []reflect.Type{tCodeWithScope}, Received: val} - } +// coreDocumentEncodeValue is the ValueEncoderFunc for bsoncore.Document. +func coreDocumentEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return coreDocumentEncodeValueRF(ec, vw, val.Interface()) +} - cws := val.Interface().(CodeWithScope) +func codeWithScopeEncodeValueRF(ec EncodeContext, vw ValueWriter, val any) error { + cws, ok := val.(CodeWithScope) + if !ok { + return ValueEncoderError{Name: "CodeWithScopeEncodeValue", Types: []reflect.Type{tCodeWithScope}, Received: reflect.ValueOf(val)} + } dw, err := vw.WriteCodeWithScope(string(cws.Code)) if err != nil { @@ -489,7 +608,6 @@ func codeWithScopeEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Valu scopeVW.reset(scopeVW.buf[:0]) scopeVW.w = sw defer bvwPool.Put(scopeVW) - encoder, err := ec.LookupEncoder(reflect.TypeOf(cws.Scope)) if err != nil { return err @@ -507,6 +625,11 @@ func codeWithScopeEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Valu return dw.WriteDocumentEnd() } +// codeWithScopeEncodeValue is the ValueEncoderFunc for CodeWithScope. +func codeWithScopeEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return codeWithScopeEncodeValueRF(ec, vw, val.Interface()) +} + // isImplementationNil returns if val is a nil pointer and inter is implemented on a concrete type func isImplementationNil(val reflect.Value, inter reflect.Type) bool { vt := val.Type() @@ -515,3 +638,72 @@ func isImplementationNil(val reflect.Value, inter reflect.Type) bool { } return vt.Implements(inter) && val.Kind() == reflect.Ptr && val.IsNil() } + +func byteSliceEncodeValueRF(encodeNilAsEmpty bool) reflectFreeValueEncoderFunc { + return reflectFreeValueEncoderFunc(func(ec EncodeContext, vw ValueWriter, val any) error { + byteSlice, ok := val.([]byte) + if !ok { + return ValueEncoderError{ + Name: "ByteSliceEncodeValue", + Types: []reflect.Type{tByteSlice}, + Received: reflect.ValueOf(val), + } + } + + if byteSlice == nil && !encodeNilAsEmpty && !ec.nilByteSliceAsEmpty { + return vw.WriteNull() + } + + return vw.WriteBinary(byteSlice) + }) +} + +func byteSliceEncodeValue(encodeNilAsEmpty bool) defaultValueEncoderFunc { + return defaultValueEncoderFunc(func(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return byteSliceEncodeValueRF(encodeNilAsEmpty)(ec, vw, val.Interface()) + }) +} + +func timeEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + tt, ok := val.(time.Time) + if !ok { + return ValueEncoderError{Name: "TimeEncodeValue", Types: []reflect.Type{tTime}, Received: reflect.ValueOf(val)} + } + + dt := NewDateTimeFromTime(tt) + return vw.WriteDateTime(int64(dt)) +} + +func timeEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return timeEncodeValueRF(ec, vw, val.Interface()) +} + +func coreArrayEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + arr, ok := val.(bsoncore.Array) + if !ok { + return ValueEncoderError{Name: "CoreArrayEncodeValue", Types: []reflect.Type{tCoreArray}, Received: reflect.ValueOf(val)} + } + + return copyArrayFromBytes(vw, arr) +} + +func coreArrayEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return coreArrayEncodeValueRF(ec, vw, val.Interface()) +} + +func emptyInterfaceValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + if !val.IsValid() || val.Type() != tEmpty { + return ValueEncoderError{Name: "EmptyInterfaceEncodeValue", Types: []reflect.Type{tEmpty}, Received: val} + } + + if val.IsNil() { + return vw.WriteNull() + } + + encoder, err := ec.LookupEncoder(val.Elem().Type()) + if err != nil { + return err + } + + return encoder.EncodeValue(ec, vw, val.Elem()) +} diff --git a/bson/default_value_encoders_test.go b/bson/default_value_encoders_test.go index e15019785d..70aae59f3f 100644 --- a/bson/default_value_encoders_test.go +++ b/bson/default_value_encoders_test.go @@ -73,11 +73,13 @@ func TestDefaultValueEncoders(t *testing.T) { testCases := []struct { name string ve ValueEncoder + rfve reflectFreeValueEncoder subtests []subtest }{ { "BooleanEncodeValue", ValueEncoderFunc(booleanEncodeValue), + nil, []subtest{ { "wrong type", @@ -94,6 +96,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "IntEncodeValue", ValueEncoderFunc(intEncodeValue), + nil, []subtest{ { "wrong type", @@ -134,6 +137,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "UintEncodeValue", &uintCodec{}, + nil, []subtest{ { "wrong type", @@ -175,6 +179,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "FloatEncodeValue", ValueEncoderFunc(floatEncodeValue), + nil, []subtest{ { "wrong type", @@ -194,9 +199,34 @@ func TestDefaultValueEncoders(t *testing.T) { {"float64/reflection path", myfloat64(3.14159), nil, nil, writeDouble, nil}, }, }, + { + "reflection free FloatEncodeValue", + nil, + reflectFreeValueEncoderFunc(floatEncodeValueRF), + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + nothing, + ValueEncoderError{ + Name: "FloatEncodeValue", + Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, + Received: reflect.ValueOf(wrong), + }, + }, + // the reflection free encoder function should only be used for + // encoding "types", not "kinds". So the reflection path tests are not + // valid. + {"float32/fast path", float32(3.14159), nil, nil, writeDouble, nil}, + {"float64/fast path", float64(3.14159), nil, nil, writeDouble, nil}, + }, + }, { "TimeEncodeValue", - &timeCodec{}, + ValueEncoderFunc(timeEncodeValue), + reflectFreeValueEncoderFunc(timeEncodeValueRF), []subtest{ { "wrong type", @@ -212,6 +242,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "MapEncodeValue", &mapCodec{}, + nil, []subtest{ { "wrong kind", @@ -292,6 +323,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "ArrayEncodeValue", ValueEncoderFunc(arrayEncodeValue), + nil, []subtest{ { "wrong kind", @@ -370,6 +402,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "SliceEncodeValue", &sliceCodec{}, + nil, []subtest{ { "wrong kind", @@ -456,6 +489,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "ObjectIDEncodeValue", ValueEncoderFunc(objectIDEncodeValue), + reflectFreeValueEncoderFunc(objectIDEncodeValueRF), []subtest{ { "wrong type", @@ -475,6 +509,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "Decimal128EncodeValue", ValueEncoderFunc(decimal128EncodeValue), + reflectFreeValueEncoderFunc(decimal128EncodeValueRF), []subtest{ { "wrong type", @@ -490,6 +525,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "JSONNumberEncodeValue", ValueEncoderFunc(jsonNumberEncodeValue), + reflectFreeValueEncoderFunc(jsonNumberEncodeValueRF), []subtest{ { "wrong type", @@ -519,6 +555,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "URLEncodeValue", ValueEncoderFunc(urlEncodeValue), + reflectFreeValueEncoderFunc(urlEncodeValueRF), []subtest{ { "wrong type", @@ -533,7 +570,8 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "ByteSliceEncodeValue", - &byteSliceCodec{}, + ValueEncoderFunc(byteSliceEncodeValue(false)), + byteSliceEncodeValueRF(false), []subtest{ { "wrong type", @@ -549,7 +587,8 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "EmptyInterfaceEncodeValue", - &emptyInterfaceCodec{}, + ValueEncoderFunc(emptyInterfaceValue), + nil, []subtest{ { "wrong type", @@ -564,6 +603,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "ValueMarshalerEncodeValue", ValueEncoderFunc(valueMarshalerEncodeValue), + nil, []subtest{ { "wrong type", @@ -642,6 +682,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "MarshalerEncodeValue", ValueEncoderFunc(marshalerEncodeValue), + nil, []subtest{ { "wrong type", @@ -704,6 +745,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "PointerCodec.EncodeValue", &pointerCodec{}, + nil, []subtest{ { "nil", @@ -742,6 +784,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "pointer implementation addressable interface", &pointerCodec{}, + nil, []subtest{ { "ValueMarshaler", @@ -764,6 +807,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "JavaScriptEncodeValue", ValueEncoderFunc(javaScriptEncodeValue), + reflectFreeValueEncoderFunc(javaScriptEncodeValueRF), []subtest{ { "wrong type", @@ -779,6 +823,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "SymbolEncodeValue", ValueEncoderFunc(symbolEncodeValue), + reflectFreeValueEncoderFunc(symbolEncodeValueRF), []subtest{ { "wrong type", @@ -794,6 +839,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "BinaryEncodeValue", ValueEncoderFunc(binaryEncodeValue), + reflectFreeValueEncoderFunc(binaryEncodeValueRF), []subtest{ { "wrong type", @@ -809,6 +855,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "UndefinedEncodeValue", ValueEncoderFunc(undefinedEncodeValue), + reflectFreeValueEncoderFunc(undefinedEncodeValueRF), []subtest{ { "wrong type", @@ -824,6 +871,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "DateTimeEncodeValue", ValueEncoderFunc(dateTimeEncodeValue), + reflectFreeValueEncoderFunc(dateTimeEncodeValueRF), []subtest{ { "wrong type", @@ -839,6 +887,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "NullEncodeValue", ValueEncoderFunc(nullEncodeValue), + reflectFreeValueEncoderFunc(nullEncodeValueRF), []subtest{ { "wrong type", @@ -854,6 +903,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "RegexEncodeValue", ValueEncoderFunc(regexEncodeValue), + reflectFreeValueEncoderFunc(regexEncodeValueRF), []subtest{ { "wrong type", @@ -869,6 +919,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "DBPointerEncodeValue", ValueEncoderFunc(dbPointerEncodeValue), + reflectFreeValueEncoderFunc(dbPointerEncodeValueRF), []subtest{ { "wrong type", @@ -891,6 +942,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "TimestampEncodeValue", ValueEncoderFunc(timestampEncodeValue), + reflectFreeValueEncoderFunc(timestampEncodeValueRF), []subtest{ { "wrong type", @@ -906,6 +958,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "MinKeyEncodeValue", ValueEncoderFunc(minKeyEncodeValue), + reflectFreeValueEncoderFunc(minKeyEncodeValueRF), []subtest{ { "wrong type", @@ -921,6 +974,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "MaxKeyEncodeValue", ValueEncoderFunc(maxKeyEncodeValue), + reflectFreeValueEncoderFunc(maxKeyEncodeValueRF), []subtest{ { "wrong type", @@ -936,6 +990,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "CoreDocumentEncodeValue", ValueEncoderFunc(coreDocumentEncodeValue), + reflectFreeValueEncoderFunc(coreDocumentEncodeValueRF), []subtest{ { "wrong type", @@ -994,6 +1049,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "StructEncodeValue", newStructCodec(&mapCodec{}), + nil, []subtest{ { "interface value", @@ -1016,6 +1072,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "CodeWithScopeEncodeValue", ValueEncoderFunc(codeWithScopeEncodeValue), + reflectFreeValueEncoderFunc(codeWithScopeEncodeValueRF), []subtest{ { "wrong type", @@ -1050,7 +1107,8 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "CoreArrayEncodeValue", - &arrayCodec{}, + ValueEncoderFunc(coreArrayEncodeValue), + reflectFreeValueEncoderFunc(coreArrayEncodeValueRF), []subtest{ { "wrong type", @@ -1110,13 +1168,27 @@ func TestDefaultValueEncoders(t *testing.T) { llvrw = subtest.llvrw } llvrw.T = t - err := tc.ve.EncodeValue(ec, llvrw, reflect.ValueOf(subtest.val)) - if !assert.CompareErrors(err, subtest.err) { - t.Errorf("Errors do not match. got %v; want %v", err, subtest.err) + + if tc.ve != nil { + err := tc.ve.EncodeValue(ec, llvrw, reflect.ValueOf(subtest.val)) + if !assert.CompareErrors(err, subtest.err) { + t.Errorf("Errors do not match. got %v; want %v", err, subtest.err) + } + invoked := llvrw.invoked + if !cmp.Equal(invoked, subtest.invoke) { + t.Errorf("Incorrect method invoked. got %v; want %v", invoked, subtest.invoke) + } } - invoked := llvrw.invoked - if !cmp.Equal(invoked, subtest.invoke) { - t.Errorf("Incorrect method invoked. got %v; want %v", invoked, subtest.invoke) + + if tc.rfve != nil { + err := tc.rfve.EncodeValue(ec, llvrw, subtest.val) + if !assert.CompareErrors(err, subtest.err) { + t.Errorf("Errors do not match. got %v; want %v", err, subtest.err) + } + invoked := llvrw.invoked + if !cmp.Equal(invoked, subtest.invoke) { + t.Errorf("Incorrect method invoked. got %v; want %v", invoked, subtest.invoke) + } } }) } diff --git a/bson/mgoregistry.go b/bson/mgoregistry.go index f42935e5d8..6e11353168 100644 --- a/bson/mgoregistry.go +++ b/bson/mgoregistry.go @@ -38,11 +38,12 @@ func NewMgoRegistry() *Registry { uintCodec := &uintCodec{encodeToMinSize: true} reg := NewRegistry() + reg.registerReflectFreeTypeEncoder(tByteSlice, byteSliceEncodeValueRF(true)) + reg.RegisterTypeDecoder(tEmpty, &emptyInterfaceCodec{decodeBinaryAsSlice: true}) reg.RegisterKindDecoder(reflect.String, ValueDecoderFunc(mgoStringDecodeValue)) reg.RegisterKindDecoder(reflect.Struct, structCodec) reg.RegisterKindDecoder(reflect.Map, mapCodec) - reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{encodeNilAsEmpty: true}) reg.RegisterKindEncoder(reflect.Struct, structCodec) reg.RegisterKindEncoder(reflect.Slice, &sliceCodec{encodeNilAsEmpty: true}) reg.RegisterKindEncoder(reflect.Map, mapCodec) @@ -69,8 +70,9 @@ func NewRespectNilValuesMgoRegistry() *Registry { } reg := NewMgoRegistry() + reg.registerReflectFreeTypeEncoder(tByteSlice, byteSliceEncodeValueRF(false)) + reg.RegisterKindDecoder(reflect.Map, mapCodec) - reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{encodeNilAsEmpty: false}) reg.RegisterKindEncoder(reflect.Slice, &sliceCodec{}) reg.RegisterKindEncoder(reflect.Map, mapCodec) return reg diff --git a/bson/registry.go b/bson/registry.go index d8f65ddc0d..88c94ff828 100644 --- a/bson/registry.go +++ b/bson/registry.go @@ -88,6 +88,8 @@ type Registry struct { kindEncoders *kindEncoderCache kindDecoders *kindDecoderCache typeMap sync.Map // map[Type]reflect.Type + + reflectFreeTypeEncoders *typeReflectFreeEncoderCache } // NewRegistry creates a new empty Registry. @@ -97,6 +99,8 @@ func NewRegistry() *Registry { typeDecoders: new(typeDecoderCache), kindEncoders: new(kindEncoderCache), kindDecoders: new(kindDecoderCache), + + reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), } registerDefaultEncoders(reg) registerDefaultDecoders(reg) @@ -118,6 +122,10 @@ func (r *Registry) RegisterTypeEncoder(valueType reflect.Type, enc ValueEncoder) r.typeEncoders.Store(valueType, enc) } +func (r *Registry) registerReflectFreeTypeEncoder(valueType reflect.Type, enc reflectFreeValueEncoder) { + r.reflectFreeTypeEncoders.Store(valueType, enc) +} + // RegisterTypeDecoder registers the provided ValueDecoder for the provided type. // // The type will be used as provided, so a decoder can be registered for a type and a different @@ -244,16 +252,28 @@ func (r *Registry) LookupEncoder(valueType reflect.Type) (ValueEncoder, error) { if valueType == nil { return nil, errNoEncoder{Type: valueType} } - enc, found := r.lookupTypeEncoder(valueType) - if found { + + // First attempt to get a user-defined type encoder. + if enc, found := r.lookupTypeEncoder(valueType); found { if enc == nil { return nil, errNoEncoder{Type: valueType} } - return enc, nil + + if _, ok := enc.(defaultValueEncoderFunc); !ok { + return enc, nil + } } - enc, found = r.lookupInterfaceEncoder(valueType, true) - if found { + // Next try to get a reflection-free encoder. + if rfeEnc, found := r.reflectFreeTypeEncoders.Load(valueType); found && rfeEnc != nil { + wrapper := func(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return rfeEnc.EncodeValue(ec, vw, val.Interface()) + } + + return ValueEncoderFunc(wrapper), nil + } + + if enc, found := r.lookupInterfaceEncoder(valueType, true); found { return r.typeEncoders.LoadOrStore(valueType, enc), nil } diff --git a/bson/registry_test.go b/bson/registry_test.go index ea7b2b2ef7..f826ab5e78 100644 --- a/bson/registry_test.go +++ b/bson/registry_test.go @@ -22,6 +22,8 @@ func newTestRegistry() *Registry { typeDecoders: new(typeDecoderCache), kindEncoders: new(kindEncoderCache), kindDecoders: new(kindDecoderCache), + + reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), } } diff --git a/bson/time_codec.go b/bson/time_codec.go index 1c00374c19..85d37496fc 100644 --- a/bson/time_codec.go +++ b/bson/time_codec.go @@ -99,11 +99,12 @@ func (tc *timeCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.V } // EncodeValue is the ValueEncoderFunc for time.TIme. -func (tc *timeCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tTime { - return ValueEncoderError{Name: "TimeEncodeValue", Types: []reflect.Type{tTime}, Received: val} +func (tc *timeCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val any) error { + timeVal, ok := val.(time.Time) + if !ok { + return ValueEncoderError{Name: "TimeEncodeValue", Types: []reflect.Type{tTime}, Received: reflect.ValueOf(val)} } - tt := val.Interface().(time.Time) - dt := NewDateTimeFromTime(tt) + + dt := NewDateTimeFromTime(timeVal) return vw.WriteDateTime(int64(dt)) }