Skip to content

Commit 70b5987

Browse files
Use switch condition on default registry to load encoder
1 parent ee212da commit 70b5987

File tree

4 files changed

+180
-33
lines changed

4 files changed

+180
-33
lines changed

bson/bsoncodec.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,15 @@ func (fn ValueEncoderFunc) EncodeValue(ec EncodeContext, vw ValueWriter, val ref
138138
return fn(ec, vw, val)
139139
}
140140

141+
// defaultVauleEncoderFunc is used to wrap the default encoders for determining
142+
// if a registry contains custom data.
143+
type defaultVauleEncoderFunc func(EncodeContext, ValueWriter, reflect.Value) error
144+
145+
// EncodeValue implements the ValueEncoder interface.
146+
func (fn defaultVauleEncoderFunc) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
147+
return fn(ec, vw, val)
148+
}
149+
141150
// ValueDecoder is the interface implemented by types that can decode BSON to a provided Go type.
142151
// Implementations should ensure that the value they receive is settable. Similar to ValueEncoderFunc,
143152
// ValueDecoderFunc is provided to allow the use of a function with the correct signature as a

bson/default_value_encoders.go

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -59,28 +59,28 @@ func registerDefaultEncoders(reg *Registry) {
5959
mapEncoder := &mapCodec{}
6060
uintCodec := &uintCodec{}
6161

62-
reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{})
63-
reg.RegisterTypeEncoder(tTime, &timeCodec{})
64-
reg.RegisterTypeEncoder(tEmpty, &emptyInterfaceCodec{})
65-
reg.RegisterTypeEncoder(tCoreArray, &arrayCodec{})
66-
reg.RegisterTypeEncoder(tOID, ValueEncoderFunc(objectIDEncodeValue))
67-
reg.RegisterTypeEncoder(tDecimal, ValueEncoderFunc(decimal128EncodeValue))
68-
reg.RegisterTypeEncoder(tJSONNumber, ValueEncoderFunc(jsonNumberEncodeValue))
69-
reg.RegisterTypeEncoder(tURL, ValueEncoderFunc(urlEncodeValue))
70-
reg.RegisterTypeEncoder(tJavaScript, ValueEncoderFunc(javaScriptEncodeValue))
71-
reg.RegisterTypeEncoder(tSymbol, ValueEncoderFunc(symbolEncodeValue))
72-
reg.RegisterTypeEncoder(tBinary, ValueEncoderFunc(binaryEncodeValue))
73-
reg.RegisterTypeEncoder(tVector, ValueEncoderFunc(vectorEncodeValue))
74-
reg.RegisterTypeEncoder(tUndefined, ValueEncoderFunc(undefinedEncodeValue))
75-
reg.RegisterTypeEncoder(tDateTime, ValueEncoderFunc(dateTimeEncodeValue))
76-
reg.RegisterTypeEncoder(tNull, ValueEncoderFunc(nullEncodeValue))
77-
reg.RegisterTypeEncoder(tRegex, ValueEncoderFunc(regexEncodeValue))
78-
reg.RegisterTypeEncoder(tDBPointer, ValueEncoderFunc(dbPointerEncodeValue))
79-
reg.RegisterTypeEncoder(tTimestamp, ValueEncoderFunc(timestampEncodeValue))
80-
reg.RegisterTypeEncoder(tMinKey, ValueEncoderFunc(minKeyEncodeValue))
81-
reg.RegisterTypeEncoder(tMaxKey, ValueEncoderFunc(maxKeyEncodeValue))
82-
reg.RegisterTypeEncoder(tCoreDocument, ValueEncoderFunc(coreDocumentEncodeValue))
83-
reg.RegisterTypeEncoder(tCodeWithScope, ValueEncoderFunc(codeWithScopeEncodeValue))
62+
reg.RegisterTypeEncoder(tByteSlice, defaultVauleEncoderFunc((&byteSliceCodec{}).EncodeValue))
63+
reg.RegisterTypeEncoder(tTime, defaultVauleEncoderFunc((&timeCodec{}).EncodeValue))
64+
reg.RegisterTypeEncoder(tEmpty, defaultVauleEncoderFunc((&emptyInterfaceCodec{}).EncodeValue))
65+
reg.RegisterTypeEncoder(tCoreArray, defaultVauleEncoderFunc((&arrayCodec{}).EncodeValue))
66+
reg.RegisterTypeEncoder(tOID, defaultVauleEncoderFunc(objectIDEncodeValue))
67+
reg.RegisterTypeEncoder(tDecimal, defaultVauleEncoderFunc(decimal128EncodeValue))
68+
reg.RegisterTypeEncoder(tJSONNumber, defaultVauleEncoderFunc(jsonNumberEncodeValue))
69+
reg.RegisterTypeEncoder(tURL, defaultVauleEncoderFunc(urlEncodeValue))
70+
reg.RegisterTypeEncoder(tJavaScript, defaultVauleEncoderFunc(javaScriptEncodeValue))
71+
reg.RegisterTypeEncoder(tSymbol, defaultVauleEncoderFunc(symbolEncodeValue))
72+
reg.RegisterTypeEncoder(tBinary, defaultVauleEncoderFunc(binaryEncodeValue))
73+
reg.RegisterTypeEncoder(tVector, defaultVauleEncoderFunc(vectorEncodeValue))
74+
reg.RegisterTypeEncoder(tUndefined, defaultVauleEncoderFunc(undefinedEncodeValue))
75+
reg.RegisterTypeEncoder(tDateTime, defaultVauleEncoderFunc(dateTimeEncodeValue))
76+
reg.RegisterTypeEncoder(tNull, defaultVauleEncoderFunc(nullEncodeValue))
77+
reg.RegisterTypeEncoder(tRegex, defaultVauleEncoderFunc(regexEncodeValue))
78+
reg.RegisterTypeEncoder(tDBPointer, defaultVauleEncoderFunc(dbPointerEncodeValue))
79+
reg.RegisterTypeEncoder(tTimestamp, defaultVauleEncoderFunc(timestampEncodeValue))
80+
reg.RegisterTypeEncoder(tMinKey, defaultVauleEncoderFunc(minKeyEncodeValue))
81+
reg.RegisterTypeEncoder(tMaxKey, defaultVauleEncoderFunc(maxKeyEncodeValue))
82+
reg.RegisterTypeEncoder(tCoreDocument, defaultVauleEncoderFunc(coreDocumentEncodeValue))
83+
reg.RegisterTypeEncoder(tCodeWithScope, defaultVauleEncoderFunc(codeWithScopeEncodeValue))
8484
reg.RegisterKindEncoder(reflect.Bool, ValueEncoderFunc(booleanEncodeValue))
8585
reg.RegisterKindEncoder(reflect.Int, ValueEncoderFunc(intEncodeValue))
8686
reg.RegisterKindEncoder(reflect.Int8, ValueEncoderFunc(intEncodeValue))

bson/registry.go

Lines changed: 68 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -81,22 +81,24 @@ func (entme errNoTypeMapEntry) Error() string {
8181
//
8282
// Read [Registry.LookupDecoder] and [Registry.LookupEncoder] for Registry lookup procedure.
8383
type Registry struct {
84-
interfaceEncoders []interfaceValueEncoder
85-
interfaceDecoders []interfaceValueDecoder
86-
typeEncoders *typeEncoderCache
87-
typeDecoders *typeDecoderCache
88-
kindEncoders *kindEncoderCache
89-
kindDecoders *kindDecoderCache
90-
typeMap sync.Map // map[Type]reflect.Type
84+
interfaceEncoders []interfaceValueEncoder
85+
interfaceDecoders []interfaceValueDecoder
86+
typeEncoders *typeEncoderCache
87+
typeDecoders *typeDecoderCache
88+
kindEncoders *kindEncoderCache
89+
kindDecoders *kindDecoderCache
90+
typeMap sync.Map // map[Type]reflect.Type
91+
defaultTypeEncoders bool
9192
}
9293

9394
// NewRegistry creates a new empty Registry.
9495
func NewRegistry() *Registry {
9596
reg := &Registry{
96-
typeEncoders: new(typeEncoderCache),
97-
typeDecoders: new(typeDecoderCache),
98-
kindEncoders: new(kindEncoderCache),
99-
kindDecoders: new(kindDecoderCache),
97+
typeEncoders: new(typeEncoderCache),
98+
typeDecoders: new(typeDecoderCache),
99+
kindEncoders: new(kindEncoderCache),
100+
kindDecoders: new(kindDecoderCache),
101+
defaultTypeEncoders: true,
100102
}
101103
registerDefaultEncoders(reg)
102104
registerDefaultDecoders(reg)
@@ -115,6 +117,9 @@ func NewRegistry() *Registry {
115117
//
116118
// RegisterTypeEncoder should not be called concurrently with any other Registry method.
117119
func (r *Registry) RegisterTypeEncoder(valueType reflect.Type, enc ValueEncoder) {
120+
if _, ok := enc.(defaultVauleEncoderFunc); !ok {
121+
r.defaultTypeEncoders = false
122+
}
118123
r.typeEncoders.Store(valueType, enc)
119124
}
120125

@@ -268,6 +273,58 @@ func (r *Registry) storeTypeEncoder(rt reflect.Type, enc ValueEncoder) ValueEnco
268273
}
269274

270275
func (r *Registry) lookupTypeEncoder(rt reflect.Type) (ValueEncoder, bool) {
276+
// Check if this is the default registry and handle specific cases
277+
if r.defaultTypeEncoders {
278+
switch rt {
279+
case tFloat64:
280+
return defaultVauleEncoderFunc(floatEncodeValue), true
281+
case tByteSlice:
282+
return defaultVauleEncoderFunc((&byteSliceCodec{}).EncodeValue), true
283+
case tTime:
284+
return defaultVauleEncoderFunc((&timeCodec{}).EncodeValue), true
285+
case tEmpty:
286+
return defaultVauleEncoderFunc((&emptyInterfaceCodec{}).EncodeValue), true
287+
case tCoreArray:
288+
return defaultVauleEncoderFunc((&arrayCodec{}).EncodeValue), true
289+
case tOID:
290+
return defaultVauleEncoderFunc(objectIDEncodeValue), true
291+
case tDecimal:
292+
return defaultVauleEncoderFunc(decimal128EncodeValue), true
293+
case tJSONNumber:
294+
return defaultVauleEncoderFunc(jsonNumberEncodeValue), true
295+
case tURL:
296+
return defaultVauleEncoderFunc(urlEncodeValue), true
297+
case tJavaScript:
298+
return defaultVauleEncoderFunc(javaScriptEncodeValue), true
299+
case tSymbol:
300+
return defaultVauleEncoderFunc(symbolEncodeValue), true
301+
case tBinary:
302+
return defaultVauleEncoderFunc(binaryEncodeValue), true
303+
case tVector:
304+
return defaultVauleEncoderFunc(vectorEncodeValue), true
305+
case tUndefined:
306+
return defaultVauleEncoderFunc(undefinedEncodeValue), true
307+
case tDateTime:
308+
return defaultVauleEncoderFunc(dateTimeEncodeValue), true
309+
case tNull:
310+
return defaultVauleEncoderFunc(nullEncodeValue), true
311+
case tRegex:
312+
return defaultVauleEncoderFunc(regexEncodeValue), true
313+
case tDBPointer:
314+
return defaultVauleEncoderFunc(dbPointerEncodeValue), true
315+
case tTimestamp:
316+
return defaultVauleEncoderFunc(timestampEncodeValue), true
317+
case tMinKey:
318+
return defaultVauleEncoderFunc(minKeyEncodeValue), true
319+
case tMaxKey:
320+
return defaultVauleEncoderFunc(maxKeyEncodeValue), true
321+
case tCoreDocument:
322+
return defaultVauleEncoderFunc(coreDocumentEncodeValue), true
323+
case tCodeWithScope:
324+
return defaultVauleEncoderFunc(codeWithScopeEncodeValue), true
325+
}
326+
}
327+
271328
return r.typeEncoders.Load(rt)
272329
}
273330

bson/registry_test.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,16 @@
77
package bson
88

99
import (
10+
"encoding/json"
1011
"errors"
12+
"net/url"
1113
"reflect"
1214
"testing"
15+
"time"
1316

1417
"github.com/google/go-cmp/cmp"
1518
"go.mongodb.org/mongo-driver/v2/internal/assert"
19+
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
1620
)
1721

1822
// newTestRegistry creates a new Registry.
@@ -575,3 +579,80 @@ var _ testInterface3 = (*testInterface3Impl)(nil)
575579
func (*testInterface3Impl) test3() {}
576580

577581
func typeComparer(i1, i2 reflect.Type) bool { return i1 == i2 }
582+
583+
func BenchmarkLookupTypeEncoder(b *testing.B) {
584+
typesToTest := []reflect.Type{
585+
reflect.TypeOf(false), // tBool
586+
reflect.TypeOf(float32(0)), // tFloat32
587+
reflect.TypeOf(float64(0)), // tFloat64
588+
reflect.TypeOf(int32(0)), // tInt32
589+
reflect.TypeOf(int64(0)), // tInt64
590+
reflect.TypeOf(""), // tString
591+
reflect.TypeOf(time.Time{}), // tTime
592+
reflect.TypeOf((*interface{})(nil)).Elem(), // tEmpty
593+
reflect.TypeOf([]byte{}), // tByteSlice
594+
reflect.TypeOf(byte(0x00)), // tByte
595+
reflect.TypeOf(url.URL{}), // tURL
596+
reflect.TypeOf(json.Number("")), // tJSONNumber
597+
reflect.TypeOf(Binary{}), // tBinary
598+
reflect.TypeOf(Undefined{}), // tUndefined
599+
reflect.TypeOf(ObjectID{}), // tOID
600+
reflect.TypeOf(DateTime(0)), // tDateTime
601+
reflect.TypeOf(Null{}), // tNull
602+
reflect.TypeOf(Regex{}), // tRegex
603+
reflect.TypeOf(CodeWithScope{}), // tCodeWithScope
604+
reflect.TypeOf(DBPointer{}), // tDBPointer
605+
reflect.TypeOf(JavaScript("")), // tJavaScript
606+
reflect.TypeOf(Symbol("")), // tSymbol
607+
reflect.TypeOf(Timestamp{}), // tTimestamp
608+
reflect.TypeOf(Decimal128{}), // tDecimal
609+
reflect.TypeOf(Vector{}), // tVector
610+
reflect.TypeOf(MinKey{}), // tMinKey
611+
reflect.TypeOf(MaxKey{}), // tMaxKey
612+
reflect.TypeOf(D{}), // tD
613+
reflect.TypeOf(A{}), // tA
614+
reflect.TypeOf(E{}), // tE
615+
reflect.TypeOf(bsoncore.Document{}), // tCoreDocument
616+
reflect.TypeOf(bsoncore.Array{}), // tCoreArray
617+
}
618+
619+
// Helper function for running benchmarks with the specified configuration
620+
runBenchmark := func(b *testing.B, name string, defaultEncoders bool) {
621+
b.Run(name, func(b *testing.B) {
622+
reg := NewRegistry()
623+
reg.defaultTypeEncoders = defaultEncoders
624+
b.ReportAllocs()
625+
b.ResetTimer()
626+
for i := 0; i < b.N; i++ {
627+
for _, t := range typesToTest {
628+
reg.lookupTypeEncoder(t)
629+
}
630+
}
631+
})
632+
}
633+
634+
// Helper function for running benchmarks concurrently
635+
runBenchmarkAsync := func(b *testing.B, name string, defaultEncoders bool) {
636+
b.Run(name, func(b *testing.B) {
637+
reg := NewRegistry()
638+
reg.defaultTypeEncoders = defaultEncoders
639+
b.ReportAllocs()
640+
b.ResetTimer()
641+
b.RunParallel(func(pb *testing.PB) {
642+
for pb.Next() {
643+
for _, t := range typesToTest {
644+
reg.lookupTypeEncoder(t)
645+
}
646+
}
647+
})
648+
})
649+
}
650+
651+
// Sequential benchmarks
652+
runBenchmark(b, "DefaultTypeEncodersTrueSequential", true)
653+
runBenchmark(b, "DefaultTypeEncodersFalseSequential", false)
654+
655+
// Concurrent benchmarks
656+
runBenchmarkAsync(b, "DefaultTypeEncodersTrueConcurrent", true)
657+
runBenchmarkAsync(b, "DefaultTypeEncodersFalseConcurrent", false)
658+
}

0 commit comments

Comments
 (0)