diff --git a/bson/bsoncodec.go b/bson/bsoncodec.go index bacc99fbb7..7c874c157f 100644 --- a/bson/bsoncodec.go +++ b/bson/bsoncodec.go @@ -138,6 +138,15 @@ func (fn ValueEncoderFunc) EncodeValue(ec EncodeContext, vw ValueWriter, val ref return fn(ec, vw, val) } +// defaultVauleEncoderFunc is used to wrap the default encoders for determining +// if a registry contains custom data. +type defaultVauleEncoderFunc func(EncodeContext, ValueWriter, reflect.Value) error + +// EncodeValue implements the ValueEncoder interface. +func (fn defaultVauleEncoderFunc) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return fn(ec, vw, val) +} + // ValueDecoder is the interface implemented by types that can decode BSON to a provided Go type. // Implementations should ensure that the value they receive is settable. Similar to ValueEncoderFunc, // ValueDecoderFunc is provided to allow the use of a function with the correct signature as a diff --git a/bson/default_value_encoders.go b/bson/default_value_encoders.go index bd5a20f2f9..3fec5d772d 100644 --- a/bson/default_value_encoders.go +++ b/bson/default_value_encoders.go @@ -59,28 +59,28 @@ func registerDefaultEncoders(reg *Registry) { mapEncoder := &mapCodec{} uintCodec := &uintCodec{} - reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{}) - reg.RegisterTypeEncoder(tTime, &timeCodec{}) - reg.RegisterTypeEncoder(tEmpty, &emptyInterfaceCodec{}) - reg.RegisterTypeEncoder(tCoreArray, &arrayCodec{}) - reg.RegisterTypeEncoder(tOID, ValueEncoderFunc(objectIDEncodeValue)) - reg.RegisterTypeEncoder(tDecimal, ValueEncoderFunc(decimal128EncodeValue)) - reg.RegisterTypeEncoder(tJSONNumber, ValueEncoderFunc(jsonNumberEncodeValue)) - reg.RegisterTypeEncoder(tURL, ValueEncoderFunc(urlEncodeValue)) - reg.RegisterTypeEncoder(tJavaScript, ValueEncoderFunc(javaScriptEncodeValue)) - reg.RegisterTypeEncoder(tSymbol, ValueEncoderFunc(symbolEncodeValue)) - reg.RegisterTypeEncoder(tBinary, ValueEncoderFunc(binaryEncodeValue)) - reg.RegisterTypeEncoder(tVector, ValueEncoderFunc(vectorEncodeValue)) - reg.RegisterTypeEncoder(tUndefined, ValueEncoderFunc(undefinedEncodeValue)) - reg.RegisterTypeEncoder(tDateTime, ValueEncoderFunc(dateTimeEncodeValue)) - reg.RegisterTypeEncoder(tNull, ValueEncoderFunc(nullEncodeValue)) - reg.RegisterTypeEncoder(tRegex, ValueEncoderFunc(regexEncodeValue)) - reg.RegisterTypeEncoder(tDBPointer, ValueEncoderFunc(dbPointerEncodeValue)) - reg.RegisterTypeEncoder(tTimestamp, ValueEncoderFunc(timestampEncodeValue)) - reg.RegisterTypeEncoder(tMinKey, ValueEncoderFunc(minKeyEncodeValue)) - reg.RegisterTypeEncoder(tMaxKey, ValueEncoderFunc(maxKeyEncodeValue)) - reg.RegisterTypeEncoder(tCoreDocument, ValueEncoderFunc(coreDocumentEncodeValue)) - reg.RegisterTypeEncoder(tCodeWithScope, ValueEncoderFunc(codeWithScopeEncodeValue)) + reg.RegisterTypeEncoder(tByteSlice, defaultVauleEncoderFunc((&byteSliceCodec{}).EncodeValue)) + reg.RegisterTypeEncoder(tTime, defaultVauleEncoderFunc((&timeCodec{}).EncodeValue)) + reg.RegisterTypeEncoder(tEmpty, defaultVauleEncoderFunc((&emptyInterfaceCodec{}).EncodeValue)) + reg.RegisterTypeEncoder(tCoreArray, defaultVauleEncoderFunc((&arrayCodec{}).EncodeValue)) + reg.RegisterTypeEncoder(tOID, defaultVauleEncoderFunc(objectIDEncodeValue)) + reg.RegisterTypeEncoder(tDecimal, defaultVauleEncoderFunc(decimal128EncodeValue)) + reg.RegisterTypeEncoder(tJSONNumber, defaultVauleEncoderFunc(jsonNumberEncodeValue)) + reg.RegisterTypeEncoder(tURL, defaultVauleEncoderFunc(urlEncodeValue)) + reg.RegisterTypeEncoder(tJavaScript, defaultVauleEncoderFunc(javaScriptEncodeValue)) + reg.RegisterTypeEncoder(tSymbol, defaultVauleEncoderFunc(symbolEncodeValue)) + reg.RegisterTypeEncoder(tBinary, defaultVauleEncoderFunc(binaryEncodeValue)) + reg.RegisterTypeEncoder(tVector, defaultVauleEncoderFunc(vectorEncodeValue)) + reg.RegisterTypeEncoder(tUndefined, defaultVauleEncoderFunc(undefinedEncodeValue)) + reg.RegisterTypeEncoder(tDateTime, defaultVauleEncoderFunc(dateTimeEncodeValue)) + reg.RegisterTypeEncoder(tNull, defaultVauleEncoderFunc(nullEncodeValue)) + reg.RegisterTypeEncoder(tRegex, defaultVauleEncoderFunc(regexEncodeValue)) + reg.RegisterTypeEncoder(tDBPointer, defaultVauleEncoderFunc(dbPointerEncodeValue)) + reg.RegisterTypeEncoder(tTimestamp, defaultVauleEncoderFunc(timestampEncodeValue)) + reg.RegisterTypeEncoder(tMinKey, defaultVauleEncoderFunc(minKeyEncodeValue)) + reg.RegisterTypeEncoder(tMaxKey, defaultVauleEncoderFunc(maxKeyEncodeValue)) + reg.RegisterTypeEncoder(tCoreDocument, defaultVauleEncoderFunc(coreDocumentEncodeValue)) + reg.RegisterTypeEncoder(tCodeWithScope, defaultVauleEncoderFunc(codeWithScopeEncodeValue)) reg.RegisterKindEncoder(reflect.Bool, ValueEncoderFunc(booleanEncodeValue)) reg.RegisterKindEncoder(reflect.Int, ValueEncoderFunc(intEncodeValue)) reg.RegisterKindEncoder(reflect.Int8, ValueEncoderFunc(intEncodeValue)) diff --git a/bson/registry.go b/bson/registry.go index d8f65ddc0d..3d07c3afe4 100644 --- a/bson/registry.go +++ b/bson/registry.go @@ -81,22 +81,24 @@ func (entme errNoTypeMapEntry) Error() string { // // Read [Registry.LookupDecoder] and [Registry.LookupEncoder] for Registry lookup procedure. type Registry struct { - interfaceEncoders []interfaceValueEncoder - interfaceDecoders []interfaceValueDecoder - typeEncoders *typeEncoderCache - typeDecoders *typeDecoderCache - kindEncoders *kindEncoderCache - kindDecoders *kindDecoderCache - typeMap sync.Map // map[Type]reflect.Type + interfaceEncoders []interfaceValueEncoder + interfaceDecoders []interfaceValueDecoder + typeEncoders *typeEncoderCache + typeDecoders *typeDecoderCache + kindEncoders *kindEncoderCache + kindDecoders *kindDecoderCache + typeMap sync.Map // map[Type]reflect.Type + defaultTypeEncoders bool } // NewRegistry creates a new empty Registry. func NewRegistry() *Registry { reg := &Registry{ - typeEncoders: new(typeEncoderCache), - typeDecoders: new(typeDecoderCache), - kindEncoders: new(kindEncoderCache), - kindDecoders: new(kindDecoderCache), + typeEncoders: new(typeEncoderCache), + typeDecoders: new(typeDecoderCache), + kindEncoders: new(kindEncoderCache), + kindDecoders: new(kindDecoderCache), + defaultTypeEncoders: true, } registerDefaultEncoders(reg) registerDefaultDecoders(reg) @@ -115,6 +117,9 @@ func NewRegistry() *Registry { // // RegisterTypeEncoder should not be called concurrently with any other Registry method. func (r *Registry) RegisterTypeEncoder(valueType reflect.Type, enc ValueEncoder) { + if _, ok := enc.(defaultVauleEncoderFunc); !ok { + r.defaultTypeEncoders = false + } r.typeEncoders.Store(valueType, enc) } @@ -268,6 +273,58 @@ func (r *Registry) storeTypeEncoder(rt reflect.Type, enc ValueEncoder) ValueEnco } func (r *Registry) lookupTypeEncoder(rt reflect.Type) (ValueEncoder, bool) { + // Check if this is the default registry and handle specific cases + if r.defaultTypeEncoders { + switch rt { + case tFloat64: + return defaultVauleEncoderFunc(floatEncodeValue), true + case tByteSlice: + return defaultVauleEncoderFunc((&byteSliceCodec{}).EncodeValue), true + case tTime: + return defaultVauleEncoderFunc((&timeCodec{}).EncodeValue), true + case tEmpty: + return defaultVauleEncoderFunc((&emptyInterfaceCodec{}).EncodeValue), true + case tCoreArray: + return defaultVauleEncoderFunc((&arrayCodec{}).EncodeValue), true + case tOID: + return defaultVauleEncoderFunc(objectIDEncodeValue), true + case tDecimal: + return defaultVauleEncoderFunc(decimal128EncodeValue), true + case tJSONNumber: + return defaultVauleEncoderFunc(jsonNumberEncodeValue), true + case tURL: + return defaultVauleEncoderFunc(urlEncodeValue), true + case tJavaScript: + return defaultVauleEncoderFunc(javaScriptEncodeValue), true + case tSymbol: + return defaultVauleEncoderFunc(symbolEncodeValue), true + case tBinary: + return defaultVauleEncoderFunc(binaryEncodeValue), true + case tVector: + return defaultVauleEncoderFunc(vectorEncodeValue), true + case tUndefined: + return defaultVauleEncoderFunc(undefinedEncodeValue), true + case tDateTime: + return defaultVauleEncoderFunc(dateTimeEncodeValue), true + case tNull: + return defaultVauleEncoderFunc(nullEncodeValue), true + case tRegex: + return defaultVauleEncoderFunc(regexEncodeValue), true + case tDBPointer: + return defaultVauleEncoderFunc(dbPointerEncodeValue), true + case tTimestamp: + return defaultVauleEncoderFunc(timestampEncodeValue), true + case tMinKey: + return defaultVauleEncoderFunc(minKeyEncodeValue), true + case tMaxKey: + return defaultVauleEncoderFunc(maxKeyEncodeValue), true + case tCoreDocument: + return defaultVauleEncoderFunc(coreDocumentEncodeValue), true + case tCodeWithScope: + return defaultVauleEncoderFunc(codeWithScopeEncodeValue), true + } + } + return r.typeEncoders.Load(rt) } diff --git a/bson/registry_test.go b/bson/registry_test.go index b3a94e6195..77db88d019 100644 --- a/bson/registry_test.go +++ b/bson/registry_test.go @@ -7,12 +7,16 @@ package bson import ( + "encoding/json" "errors" + "net/url" "reflect" "testing" + "time" "github.com/google/go-cmp/cmp" "go.mongodb.org/mongo-driver/v2/internal/assert" + "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" ) // newTestRegistry creates a new Registry. @@ -575,3 +579,80 @@ var _ testInterface3 = (*testInterface3Impl)(nil) func (*testInterface3Impl) test3() {} func typeComparer(i1, i2 reflect.Type) bool { return i1 == i2 } + +func BenchmarkLookupTypeEncoder(b *testing.B) { + typesToTest := []reflect.Type{ + reflect.TypeOf(false), // tBool + reflect.TypeOf(float32(0)), // tFloat32 + reflect.TypeOf(float64(0)), // tFloat64 + reflect.TypeOf(int32(0)), // tInt32 + reflect.TypeOf(int64(0)), // tInt64 + reflect.TypeOf(""), // tString + reflect.TypeOf(time.Time{}), // tTime + reflect.TypeOf((*interface{})(nil)).Elem(), // tEmpty + reflect.TypeOf([]byte{}), // tByteSlice + reflect.TypeOf(byte(0x00)), // tByte + reflect.TypeOf(url.URL{}), // tURL + reflect.TypeOf(json.Number("")), // tJSONNumber + reflect.TypeOf(Binary{}), // tBinary + reflect.TypeOf(Undefined{}), // tUndefined + reflect.TypeOf(ObjectID{}), // tOID + reflect.TypeOf(DateTime(0)), // tDateTime + reflect.TypeOf(Null{}), // tNull + reflect.TypeOf(Regex{}), // tRegex + reflect.TypeOf(CodeWithScope{}), // tCodeWithScope + reflect.TypeOf(DBPointer{}), // tDBPointer + reflect.TypeOf(JavaScript("")), // tJavaScript + reflect.TypeOf(Symbol("")), // tSymbol + reflect.TypeOf(Timestamp{}), // tTimestamp + reflect.TypeOf(Decimal128{}), // tDecimal + reflect.TypeOf(Vector{}), // tVector + reflect.TypeOf(MinKey{}), // tMinKey + reflect.TypeOf(MaxKey{}), // tMaxKey + reflect.TypeOf(D{}), // tD + reflect.TypeOf(A{}), // tA + reflect.TypeOf(E{}), // tE + reflect.TypeOf(bsoncore.Document{}), // tCoreDocument + reflect.TypeOf(bsoncore.Array{}), // tCoreArray + } + + // Helper function for running benchmarks with the specified configuration + runBenchmark := func(b *testing.B, name string, defaultEncoders bool) { + b.Run(name, func(b *testing.B) { + reg := NewRegistry() + reg.defaultTypeEncoders = defaultEncoders + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, t := range typesToTest { + reg.lookupTypeEncoder(t) + } + } + }) + } + + // Helper function for running benchmarks concurrently + runBenchmarkAsync := func(b *testing.B, name string, defaultEncoders bool) { + b.Run(name, func(b *testing.B) { + reg := NewRegistry() + reg.defaultTypeEncoders = defaultEncoders + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + for _, t := range typesToTest { + reg.lookupTypeEncoder(t) + } + } + }) + }) + } + + // Sequential benchmarks + runBenchmark(b, "DefaultTypeEncodersTrueSequential", true) + runBenchmark(b, "DefaultTypeEncodersFalseSequential", false) + + // Concurrent benchmarks + runBenchmarkAsync(b, "DefaultTypeEncodersTrueConcurrent", true) + runBenchmarkAsync(b, "DefaultTypeEncodersFalseConcurrent", false) +}