Skip to content

Commit 0e63c74

Browse files
committed
Update Vector interface
1 parent 73018a5 commit 0e63c74

File tree

5 files changed

+245
-176
lines changed

5 files changed

+245
-176
lines changed

bson/bson_binary_vector_spec_test.go

Lines changed: 31 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -70,23 +70,23 @@ func Test_BsonBinaryVector(t *testing.T) {
7070
val := Binary{Subtype: TypeBinaryVector}
7171

7272
for _, tc := range [][]byte{
73-
{Float32Vector, 0, 42},
74-
{Float32Vector, 0, 42, 42},
75-
{Float32Vector, 0, 42, 42, 42},
73+
{byte(Float32Vector), 0, 42},
74+
{byte(Float32Vector), 0, 42, 42},
75+
{byte(Float32Vector), 0, 42, 42, 42},
7676

77-
{Float32Vector, 0, 42, 42, 42, 42, 42},
78-
{Float32Vector, 0, 42, 42, 42, 42, 42, 42},
79-
{Float32Vector, 0, 42, 42, 42, 42, 42, 42, 42},
77+
{byte(Float32Vector), 0, 42, 42, 42, 42, 42},
78+
{byte(Float32Vector), 0, 42, 42, 42, 42, 42, 42},
79+
{byte(Float32Vector), 0, 42, 42, 42, 42, 42, 42, 42},
8080
} {
8181
t.Run(fmt.Sprintf("marshaling %d bytes", len(tc)-2), func(t *testing.T) {
8282
val.Data = tc
8383
b, err := Marshal(D{{"vector", val}})
8484
require.NoError(t, err, "marshaling test BSON")
8585
var got struct {
86-
Vector Vector[float32]
86+
Vector Vector
8787
}
8888
err = Unmarshal(b, &got)
89-
require.ErrorContains(t, err, ErrInsufficientVectorData.Error())
89+
require.ErrorContains(t, err, errInsufficientVectorData.Error())
9090
})
9191
}
9292
})
@@ -95,39 +95,37 @@ func Test_BsonBinaryVector(t *testing.T) {
9595
t.Parallel()
9696

9797
t.Run("Marshaling", func(t *testing.T) {
98-
val := BitVector{Padding: 1}
99-
_, err := Marshal(val)
100-
require.EqualError(t, err, ErrNonZeroVectorPadding.Error())
98+
_, err := NewPackedBitVector(nil, 1)
99+
require.EqualError(t, err, errNonZeroVectorPadding.Error())
101100
})
102101
t.Run("Unmarshaling", func(t *testing.T) {
103-
val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{PackedBitVector, 1}}}}
102+
val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{byte(PackedBitVector), 1}}}}
104103
b, err := Marshal(val)
105104
require.NoError(t, err, "marshaling test BSON")
106105
var got struct {
107-
Vector Vector[float32]
106+
Vector Vector
108107
}
109108
err = Unmarshal(b, &got)
110-
require.ErrorContains(t, err, ErrNonZeroVectorPadding.Error())
109+
require.ErrorContains(t, err, errNonZeroVectorPadding.Error())
111110
})
112111
})
113112

114113
t.Run("Exceeding maximum padding PACKED_BIT", func(t *testing.T) {
115114
t.Parallel()
116115

117116
t.Run("Marshaling", func(t *testing.T) {
118-
val := BitVector{Padding: 8}
119-
_, err := Marshal(val)
120-
require.EqualError(t, err, ErrVectorPaddingTooLarge.Error())
117+
_, err := NewPackedBitVector(nil, 8)
118+
require.EqualError(t, err, errVectorPaddingTooLarge.Error())
121119
})
122120
t.Run("Unmarshaling", func(t *testing.T) {
123-
val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{PackedBitVector, 8}}}}
121+
val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{byte(PackedBitVector), 8}}}}
124122
b, err := Marshal(val)
125123
require.NoError(t, err, "marshaling test BSON")
126124
var got struct {
127-
Vector Vector[float32]
125+
Vector Vector
128126
}
129127
err = Unmarshal(b, &got)
130-
require.ErrorContains(t, err, ErrVectorPaddingTooLarge.Error())
128+
require.ErrorContains(t, err, errVectorPaddingTooLarge.Error())
131129
})
132130
})
133131
}
@@ -156,22 +154,23 @@ func runBsonBinaryVectorTest(t *testing.T, testKey string, test bsonBinaryVector
156154
t.Skipf("skip invalid case %s", test.Description)
157155
}
158156

159-
var testVector interface{}
157+
testVector := make(map[string]Vector)
160158
switch alias := test.DtypeHex; alias {
161159
case "0x03":
162-
testVector = map[string]Vector[int8]{
163-
testKey: {convertSlice[int8](test.Vector)},
160+
testVector[testKey] = Vector{
161+
dType: Int8Vector,
162+
int8Data: convertSlice[int8](test.Vector),
164163
}
165164
case "0x27":
166-
testVector = map[string]Vector[float32]{
167-
testKey: {convertSlice[float32](test.Vector)},
165+
testVector[testKey] = Vector{
166+
dType: Float32Vector,
167+
float32Data: convertSlice[float32](test.Vector),
168168
}
169169
case "0x10":
170-
testVector = map[string]BitVector{
171-
testKey: {
172-
Padding: uint8(test.Padding),
173-
Data: convertSlice[byte](test.Vector),
174-
},
170+
testVector[testKey] = Vector{
171+
dType: PackedBitVector,
172+
bitData: convertSlice[byte](test.Vector),
173+
bitPadding: uint8(test.Padding),
175174
}
176175
default:
177176
t.Fatalf("unsupported vector type: %s", alias)
@@ -183,18 +182,8 @@ func runBsonBinaryVectorTest(t *testing.T, testKey string, test bsonBinaryVector
183182
t.Run("Unmarshaling", func(t *testing.T) {
184183
t.Parallel()
185184

186-
var got interface{}
187-
switch alias := test.DtypeHex; alias {
188-
case "0x03":
189-
got = make(map[string]Vector[int8])
190-
case "0x27":
191-
got = make(map[string]Vector[float32])
192-
case "0x10":
193-
got = make(map[string]BitVector)
194-
default:
195-
t.Fatalf("unsupported type: %s", alias)
196-
}
197-
err := Unmarshal(testBSON, got)
185+
var got map[string]Vector
186+
err := Unmarshal(testBSON, &got)
198187
require.NoError(t, err)
199188
require.Equal(t, testVector, got)
200189
})

bson/default_value_decoders.go

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,7 @@ func registerDefaultDecoders(reg *Registry) {
4242

4343
reg.RegisterTypeDecoder(tD, ValueDecoderFunc(dDecodeValue))
4444
reg.RegisterTypeDecoder(tBinary, decodeAdapter{binaryDecodeValue, binaryDecodeType})
45-
reg.RegisterTypeDecoder(tInt8Vector, decodeAdapter{vectorDecodeValue, vectorDecodeType})
46-
reg.RegisterTypeDecoder(tFloat32Vector, decodeAdapter{vectorDecodeValue, vectorDecodeType})
47-
reg.RegisterTypeDecoder(tBitVector, decodeAdapter{vectorDecodeValue, vectorDecodeType})
45+
reg.RegisterTypeDecoder(tVector, decodeAdapter{vectorDecodeValue, vectorDecodeType})
4846
reg.RegisterTypeDecoder(tUndefined, decodeAdapter{undefinedDecodeValue, undefinedDecodeType})
4947
reg.RegisterTypeDecoder(tDateTime, decodeAdapter{dateTimeDecodeValue, dateTimeDecodeType})
5048
reg.RegisterTypeDecoder(tNull, decodeAdapter{nullDecodeValue, nullDecodeType})
@@ -561,10 +559,10 @@ func binaryDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) erro
561559
}
562560

563561
func vectorDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
564-
if t != tInt8Vector && t != tFloat32Vector && t != tBitVector {
562+
if t != tVector {
565563
return emptyValue, ValueDecoderError{
566564
Name: "VectorDecodeValue",
567-
Types: []reflect.Type{tInt8Vector, tFloat32Vector, tBitVector},
565+
Types: []reflect.Type{tVector},
568566
Received: reflect.Zero(t),
569567
}
570568
}
@@ -585,10 +583,10 @@ func vectorDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.
585583
// vectorDecodeValue is the ValueDecoderFunc for Vector.
586584
func vectorDecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error {
587585
t := val.Type()
588-
if !val.CanSet() || (t != tInt8Vector && t != tFloat32Vector && t != tBitVector) {
586+
if !val.CanSet() || t != tVector {
589587
return ValueDecoderError{
590588
Name: "VectorDecodeValue",
591-
Types: []reflect.Type{tInt8Vector, tFloat32Vector, tBitVector},
589+
Types: []reflect.Type{tVector},
592590
Received: val,
593591
}
594592
}

bson/default_value_encoders.go

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,7 @@ func registerDefaultEncoders(reg *Registry) {
7070
reg.RegisterTypeEncoder(tJavaScript, ValueEncoderFunc(javaScriptEncodeValue))
7171
reg.RegisterTypeEncoder(tSymbol, ValueEncoderFunc(symbolEncodeValue))
7272
reg.RegisterTypeEncoder(tBinary, ValueEncoderFunc(binaryEncodeValue))
73-
reg.RegisterTypeEncoder(tInt8Vector, ValueEncoderFunc(vectorEncodeValue))
74-
reg.RegisterTypeEncoder(tFloat32Vector, ValueEncoderFunc(vectorEncodeValue))
75-
reg.RegisterTypeEncoder(tBitVector, ValueEncoderFunc(vectorEncodeValue))
73+
reg.RegisterTypeEncoder(tVector, ValueEncoderFunc(vectorEncodeValue))
7674
reg.RegisterTypeEncoder(tUndefined, ValueEncoderFunc(undefinedEncodeValue))
7775
reg.RegisterTypeEncoder(tDateTime, ValueEncoderFunc(dateTimeEncodeValue))
7876
reg.RegisterTypeEncoder(tNull, ValueEncoderFunc(nullEncodeValue))
@@ -370,26 +368,14 @@ func binaryEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error
370368
// vectorEncodeValue is the ValueEncoderFunc for Vector.
371369
func vectorEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
372370
t := val.Type()
373-
if !val.IsValid() || (t != tInt8Vector && t != tFloat32Vector && t != tBitVector) {
371+
if !val.IsValid() || t != tVector {
374372
return ValueEncoderError{Name: "VectorEncodeValue",
375-
Types: []reflect.Type{tInt8Vector, tFloat32Vector, tBitVector},
373+
Types: []reflect.Type{tVector},
376374
Received: val,
377375
}
378376
}
379-
var b Binary
380-
var err error
381-
switch v := val.Interface().(type) {
382-
case Vector[int8]:
383-
b, err = NewBinaryFromVector(v)
384-
case Vector[float32]:
385-
b, err = NewBinaryFromVector(v)
386-
case BitVector:
387-
b, err = NewBinaryFromVector(v)
388-
}
389-
if err != nil {
390-
return err
391-
}
392-
377+
v := val.Interface().(Vector)
378+
b := v.Binary()
393379
return vw.WriteBinaryWithSubtype(b.Data, b.Subtype)
394380
}
395381

bson/types.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,7 @@ var tJavaScript = reflect.TypeOf(JavaScript(""))
107107
var tSymbol = reflect.TypeOf(Symbol(""))
108108
var tTimestamp = reflect.TypeOf(Timestamp{})
109109
var tDecimal = reflect.TypeOf(Decimal128{})
110-
var tInt8Vector = reflect.TypeOf(Vector[int8]{})
111-
var tFloat32Vector = reflect.TypeOf(Vector[float32]{})
112-
var tBitVector = reflect.TypeOf(BitVector{})
110+
var tVector = reflect.TypeOf(Vector{})
113111
var tMinKey = reflect.TypeOf(MinKey{})
114112
var tMaxKey = reflect.TypeOf(MaxKey{})
115113
var tD = reflect.TypeOf(D{})

0 commit comments

Comments
 (0)