Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions bson/bsoncodec.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,15 @@ func (fn ValueEncoderFunc) EncodeValue(ec EncodeContext, vw ValueWriter, val ref
return fn(ec, vw, val)
}

// defaultVauleEncoderFunc is used to wrap the default encoders for determining
// if a registry contains custom data.
type defaultVauleEncoderFunc func(EncodeContext, ValueWriter, reflect.Value) error

// EncodeValue implements the ValueEncoder interface.
func (fn defaultVauleEncoderFunc) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
return fn(ec, vw, val)
}

// ValueDecoder is the interface implemented by types that can decode BSON to a provided Go type.
// Implementations should ensure that the value they receive is settable. Similar to ValueEncoderFunc,
// ValueDecoderFunc is provided to allow the use of a function with the correct signature as a
Expand Down
44 changes: 22 additions & 22 deletions bson/default_value_encoders.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,28 +59,28 @@ func registerDefaultEncoders(reg *Registry) {
mapEncoder := &mapCodec{}
uintCodec := &uintCodec{}

reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{})
reg.RegisterTypeEncoder(tTime, &timeCodec{})
reg.RegisterTypeEncoder(tEmpty, &emptyInterfaceCodec{})
reg.RegisterTypeEncoder(tCoreArray, &arrayCodec{})
reg.RegisterTypeEncoder(tOID, ValueEncoderFunc(objectIDEncodeValue))
reg.RegisterTypeEncoder(tDecimal, ValueEncoderFunc(decimal128EncodeValue))
reg.RegisterTypeEncoder(tJSONNumber, ValueEncoderFunc(jsonNumberEncodeValue))
reg.RegisterTypeEncoder(tURL, ValueEncoderFunc(urlEncodeValue))
reg.RegisterTypeEncoder(tJavaScript, ValueEncoderFunc(javaScriptEncodeValue))
reg.RegisterTypeEncoder(tSymbol, ValueEncoderFunc(symbolEncodeValue))
reg.RegisterTypeEncoder(tBinary, ValueEncoderFunc(binaryEncodeValue))
reg.RegisterTypeEncoder(tVector, ValueEncoderFunc(vectorEncodeValue))
reg.RegisterTypeEncoder(tUndefined, ValueEncoderFunc(undefinedEncodeValue))
reg.RegisterTypeEncoder(tDateTime, ValueEncoderFunc(dateTimeEncodeValue))
reg.RegisterTypeEncoder(tNull, ValueEncoderFunc(nullEncodeValue))
reg.RegisterTypeEncoder(tRegex, ValueEncoderFunc(regexEncodeValue))
reg.RegisterTypeEncoder(tDBPointer, ValueEncoderFunc(dbPointerEncodeValue))
reg.RegisterTypeEncoder(tTimestamp, ValueEncoderFunc(timestampEncodeValue))
reg.RegisterTypeEncoder(tMinKey, ValueEncoderFunc(minKeyEncodeValue))
reg.RegisterTypeEncoder(tMaxKey, ValueEncoderFunc(maxKeyEncodeValue))
reg.RegisterTypeEncoder(tCoreDocument, ValueEncoderFunc(coreDocumentEncodeValue))
reg.RegisterTypeEncoder(tCodeWithScope, ValueEncoderFunc(codeWithScopeEncodeValue))
reg.RegisterTypeEncoder(tByteSlice, defaultVauleEncoderFunc((&byteSliceCodec{}).EncodeValue))
reg.RegisterTypeEncoder(tTime, defaultVauleEncoderFunc((&timeCodec{}).EncodeValue))
reg.RegisterTypeEncoder(tEmpty, defaultVauleEncoderFunc((&emptyInterfaceCodec{}).EncodeValue))
reg.RegisterTypeEncoder(tCoreArray, defaultVauleEncoderFunc((&arrayCodec{}).EncodeValue))
reg.RegisterTypeEncoder(tOID, defaultVauleEncoderFunc(objectIDEncodeValue))
reg.RegisterTypeEncoder(tDecimal, defaultVauleEncoderFunc(decimal128EncodeValue))
reg.RegisterTypeEncoder(tJSONNumber, defaultVauleEncoderFunc(jsonNumberEncodeValue))
reg.RegisterTypeEncoder(tURL, defaultVauleEncoderFunc(urlEncodeValue))
reg.RegisterTypeEncoder(tJavaScript, defaultVauleEncoderFunc(javaScriptEncodeValue))
reg.RegisterTypeEncoder(tSymbol, defaultVauleEncoderFunc(symbolEncodeValue))
reg.RegisterTypeEncoder(tBinary, defaultVauleEncoderFunc(binaryEncodeValue))
reg.RegisterTypeEncoder(tVector, defaultVauleEncoderFunc(vectorEncodeValue))
reg.RegisterTypeEncoder(tUndefined, defaultVauleEncoderFunc(undefinedEncodeValue))
reg.RegisterTypeEncoder(tDateTime, defaultVauleEncoderFunc(dateTimeEncodeValue))
reg.RegisterTypeEncoder(tNull, defaultVauleEncoderFunc(nullEncodeValue))
reg.RegisterTypeEncoder(tRegex, defaultVauleEncoderFunc(regexEncodeValue))
reg.RegisterTypeEncoder(tDBPointer, defaultVauleEncoderFunc(dbPointerEncodeValue))
reg.RegisterTypeEncoder(tTimestamp, defaultVauleEncoderFunc(timestampEncodeValue))
reg.RegisterTypeEncoder(tMinKey, defaultVauleEncoderFunc(minKeyEncodeValue))
reg.RegisterTypeEncoder(tMaxKey, defaultVauleEncoderFunc(maxKeyEncodeValue))
reg.RegisterTypeEncoder(tCoreDocument, defaultVauleEncoderFunc(coreDocumentEncodeValue))
reg.RegisterTypeEncoder(tCodeWithScope, defaultVauleEncoderFunc(codeWithScopeEncodeValue))
reg.RegisterKindEncoder(reflect.Bool, ValueEncoderFunc(booleanEncodeValue))
reg.RegisterKindEncoder(reflect.Int, ValueEncoderFunc(intEncodeValue))
reg.RegisterKindEncoder(reflect.Int8, ValueEncoderFunc(intEncodeValue))
Expand Down
79 changes: 68 additions & 11 deletions bson/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,22 +81,24 @@ func (entme errNoTypeMapEntry) Error() string {
//
// Read [Registry.LookupDecoder] and [Registry.LookupEncoder] for Registry lookup procedure.
type Registry struct {
interfaceEncoders []interfaceValueEncoder
interfaceDecoders []interfaceValueDecoder
typeEncoders *typeEncoderCache
typeDecoders *typeDecoderCache
kindEncoders *kindEncoderCache
kindDecoders *kindDecoderCache
typeMap sync.Map // map[Type]reflect.Type
interfaceEncoders []interfaceValueEncoder
interfaceDecoders []interfaceValueDecoder
typeEncoders *typeEncoderCache
typeDecoders *typeDecoderCache
kindEncoders *kindEncoderCache
kindDecoders *kindDecoderCache
typeMap sync.Map // map[Type]reflect.Type
defaultTypeEncoders bool
}

// NewRegistry creates a new empty Registry.
func NewRegistry() *Registry {
reg := &Registry{
typeEncoders: new(typeEncoderCache),
typeDecoders: new(typeDecoderCache),
kindEncoders: new(kindEncoderCache),
kindDecoders: new(kindDecoderCache),
typeEncoders: new(typeEncoderCache),
typeDecoders: new(typeDecoderCache),
kindEncoders: new(kindEncoderCache),
kindDecoders: new(kindDecoderCache),
defaultTypeEncoders: true,
}
registerDefaultEncoders(reg)
registerDefaultDecoders(reg)
Expand All @@ -115,6 +117,9 @@ func NewRegistry() *Registry {
//
// RegisterTypeEncoder should not be called concurrently with any other Registry method.
func (r *Registry) RegisterTypeEncoder(valueType reflect.Type, enc ValueEncoder) {
if _, ok := enc.(defaultVauleEncoderFunc); !ok {
r.defaultTypeEncoders = false
}
r.typeEncoders.Store(valueType, enc)
}

Expand Down Expand Up @@ -268,6 +273,58 @@ func (r *Registry) storeTypeEncoder(rt reflect.Type, enc ValueEncoder) ValueEnco
}

func (r *Registry) lookupTypeEncoder(rt reflect.Type) (ValueEncoder, bool) {
// Check if this is the default registry and handle specific cases
if r.defaultTypeEncoders {
switch rt {
case tFloat64:
return defaultVauleEncoderFunc(floatEncodeValue), true
case tByteSlice:
return defaultVauleEncoderFunc((&byteSliceCodec{}).EncodeValue), true
case tTime:
return defaultVauleEncoderFunc((&timeCodec{}).EncodeValue), true
case tEmpty:
return defaultVauleEncoderFunc((&emptyInterfaceCodec{}).EncodeValue), true
case tCoreArray:
return defaultVauleEncoderFunc((&arrayCodec{}).EncodeValue), true
case tOID:
return defaultVauleEncoderFunc(objectIDEncodeValue), true
case tDecimal:
return defaultVauleEncoderFunc(decimal128EncodeValue), true
case tJSONNumber:
return defaultVauleEncoderFunc(jsonNumberEncodeValue), true
case tURL:
return defaultVauleEncoderFunc(urlEncodeValue), true
case tJavaScript:
return defaultVauleEncoderFunc(javaScriptEncodeValue), true
case tSymbol:
return defaultVauleEncoderFunc(symbolEncodeValue), true
case tBinary:
return defaultVauleEncoderFunc(binaryEncodeValue), true
case tVector:
return defaultVauleEncoderFunc(vectorEncodeValue), true
case tUndefined:
return defaultVauleEncoderFunc(undefinedEncodeValue), true
case tDateTime:
return defaultVauleEncoderFunc(dateTimeEncodeValue), true
case tNull:
return defaultVauleEncoderFunc(nullEncodeValue), true
case tRegex:
return defaultVauleEncoderFunc(regexEncodeValue), true
case tDBPointer:
return defaultVauleEncoderFunc(dbPointerEncodeValue), true
case tTimestamp:
return defaultVauleEncoderFunc(timestampEncodeValue), true
case tMinKey:
return defaultVauleEncoderFunc(minKeyEncodeValue), true
case tMaxKey:
return defaultVauleEncoderFunc(maxKeyEncodeValue), true
case tCoreDocument:
return defaultVauleEncoderFunc(coreDocumentEncodeValue), true
case tCodeWithScope:
return defaultVauleEncoderFunc(codeWithScopeEncodeValue), true
}
}

return r.typeEncoders.Load(rt)
}

Expand Down
81 changes: 81 additions & 0 deletions bson/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@
package bson

import (
"encoding/json"
"errors"
"net/url"
"reflect"
"testing"
"time"

"github.com/google/go-cmp/cmp"
"go.mongodb.org/mongo-driver/v2/internal/assert"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
)

// newTestRegistry creates a new Registry.
Expand Down Expand Up @@ -575,3 +579,80 @@ var _ testInterface3 = (*testInterface3Impl)(nil)
func (*testInterface3Impl) test3() {}

func typeComparer(i1, i2 reflect.Type) bool { return i1 == i2 }

func BenchmarkLookupTypeEncoder(b *testing.B) {
typesToTest := []reflect.Type{
reflect.TypeOf(false), // tBool
reflect.TypeOf(float32(0)), // tFloat32
reflect.TypeOf(float64(0)), // tFloat64
reflect.TypeOf(int32(0)), // tInt32
reflect.TypeOf(int64(0)), // tInt64
reflect.TypeOf(""), // tString
reflect.TypeOf(time.Time{}), // tTime
reflect.TypeOf((*interface{})(nil)).Elem(), // tEmpty
reflect.TypeOf([]byte{}), // tByteSlice
reflect.TypeOf(byte(0x00)), // tByte
reflect.TypeOf(url.URL{}), // tURL
reflect.TypeOf(json.Number("")), // tJSONNumber
reflect.TypeOf(Binary{}), // tBinary
reflect.TypeOf(Undefined{}), // tUndefined
reflect.TypeOf(ObjectID{}), // tOID
reflect.TypeOf(DateTime(0)), // tDateTime
reflect.TypeOf(Null{}), // tNull
reflect.TypeOf(Regex{}), // tRegex
reflect.TypeOf(CodeWithScope{}), // tCodeWithScope
reflect.TypeOf(DBPointer{}), // tDBPointer
reflect.TypeOf(JavaScript("")), // tJavaScript
reflect.TypeOf(Symbol("")), // tSymbol
reflect.TypeOf(Timestamp{}), // tTimestamp
reflect.TypeOf(Decimal128{}), // tDecimal
reflect.TypeOf(Vector{}), // tVector
reflect.TypeOf(MinKey{}), // tMinKey
reflect.TypeOf(MaxKey{}), // tMaxKey
reflect.TypeOf(D{}), // tD
reflect.TypeOf(A{}), // tA
reflect.TypeOf(E{}), // tE
reflect.TypeOf(bsoncore.Document{}), // tCoreDocument
reflect.TypeOf(bsoncore.Array{}), // tCoreArray
}

// Helper function for running benchmarks with the specified configuration
runBenchmark := func(b *testing.B, name string, defaultEncoders bool) {
b.Run(name, func(b *testing.B) {
reg := NewRegistry()
reg.defaultTypeEncoders = defaultEncoders
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
for _, t := range typesToTest {
reg.lookupTypeEncoder(t)
}
}
})
}

// Helper function for running benchmarks concurrently
runBenchmarkAsync := func(b *testing.B, name string, defaultEncoders bool) {
b.Run(name, func(b *testing.B) {
reg := NewRegistry()
reg.defaultTypeEncoders = defaultEncoders
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
for _, t := range typesToTest {
reg.lookupTypeEncoder(t)
}
}
})
})
}

// Sequential benchmarks
runBenchmark(b, "DefaultTypeEncodersTrueSequential", true)
runBenchmark(b, "DefaultTypeEncodersFalseSequential", false)

// Concurrent benchmarks
runBenchmarkAsync(b, "DefaultTypeEncodersTrueConcurrent", true)
runBenchmarkAsync(b, "DefaultTypeEncodersFalseConcurrent", false)
}
Loading