Skip to content

Commit 2f7fc4f

Browse files
committed
update test cases; update error message.
1 parent 0e63c74 commit 2f7fc4f

File tree

2 files changed

+99
-45
lines changed

2 files changed

+99
-45
lines changed

bson/bson_binary_vector_spec_test.go

Lines changed: 69 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ type bsonBinaryVectorTestCase struct {
3636
CanonicalBson string `json:"canonical_bson"`
3737
}
3838

39-
func Test_BsonBinaryVector(t *testing.T) {
39+
func TestBsonBinaryVector(t *testing.T) {
4040
t.Parallel()
4141

4242
jsonFiles, err := findJSONFilesInDir(bsonBinaryVectorDir)
@@ -91,6 +91,36 @@ func Test_BsonBinaryVector(t *testing.T) {
9191
}
9292
})
9393

94+
t.Run("FLOAT32 with padding", func(t *testing.T) {
95+
t.Parallel()
96+
97+
t.Run("Unmarshaling", func(t *testing.T) {
98+
val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{byte(Float32Vector), 3}}}}
99+
b, err := Marshal(val)
100+
require.NoError(t, err, "marshaling test BSON")
101+
var got struct {
102+
Vector Vector
103+
}
104+
err = Unmarshal(b, &got)
105+
require.ErrorContains(t, err, errNonZeroVectorPadding.Error())
106+
})
107+
})
108+
109+
t.Run("INT8 with padding", func(t *testing.T) {
110+
t.Parallel()
111+
112+
t.Run("Unmarshaling", func(t *testing.T) {
113+
val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{byte(Int8Vector), 3}}}}
114+
b, err := Marshal(val)
115+
require.NoError(t, err, "marshaling test BSON")
116+
var got struct {
117+
Vector Vector
118+
}
119+
err = Unmarshal(b, &got)
120+
require.ErrorContains(t, err, errNonZeroVectorPadding.Error())
121+
})
122+
})
123+
94124
t.Run("Padding specified with no vector data PACKED_BIT", func(t *testing.T) {
95125
t.Parallel()
96126

@@ -134,13 +164,13 @@ func convertSlice[T int8 | float32 | byte](s []interface{}) []T {
134164
v := make([]T, len(s))
135165
for i, e := range s {
136166
f := math.NaN()
137-
switch v := e.(type) {
167+
switch val := e.(type) {
138168
case float64:
139-
f = v
169+
f = val
140170
case string:
141-
if v == "inf" {
171+
if val == "inf" {
142172
f = math.Inf(0)
143-
} else if v == "-inf" {
173+
} else if val == "-inf" {
144174
f = math.Inf(-1)
145175
}
146176
}
@@ -150,10 +180,6 @@ func convertSlice[T int8 | float32 | byte](s []interface{}) []T {
150180
}
151181

152182
func runBsonBinaryVectorTest(t *testing.T, testKey string, test bsonBinaryVectorTestCase) {
153-
if !test.Valid {
154-
t.Skipf("skip invalid case %s", test.Description)
155-
}
156-
157183
testVector := make(map[string]Vector)
158184
switch alias := test.DtypeHex; alias {
159185
case "0x03":
@@ -180,6 +206,23 @@ func runBsonBinaryVectorTest(t *testing.T, testKey string, test bsonBinaryVector
180206
require.NoError(t, err, "decoding canonical BSON")
181207

182208
t.Run("Unmarshaling", func(t *testing.T) {
209+
skipCases := map[string]string{
210+
"FLOAT32 with padding": "run in alternative case",
211+
"Overflow Vector INT8": "compile-time restriction",
212+
"Underflow Vector INT8": "compile-time restriction",
213+
"INT8 with padding": "run in alternative case",
214+
"INT8 with float inputs": "compile-time restriction",
215+
"Overflow Vector PACKED_BIT": "compile-time restriction",
216+
"Underflow Vector PACKED_BIT": "compile-time restriction",
217+
"Vector with float values PACKED_BIT": "compile-time restriction",
218+
"Padding specified with no vector data PACKED_BIT": "run in alternative case",
219+
"Exceeding maximum padding PACKED_BIT": "run in alternative case",
220+
"Negative padding PACKED_BIT": "compile-time restriction",
221+
}
222+
if reason, ok := skipCases[test.Description]; ok {
223+
t.Skipf("skip test case %s: %s", test.Description, reason)
224+
}
225+
183226
t.Parallel()
184227

185228
var got map[string]Vector
@@ -189,6 +232,23 @@ func runBsonBinaryVectorTest(t *testing.T, testKey string, test bsonBinaryVector
189232
})
190233

191234
t.Run("Marshaling", func(t *testing.T) {
235+
skipCases := map[string]string{
236+
"FLOAT32 with padding": "private padding field",
237+
"Overflow Vector INT8": "compile-time restriction",
238+
"Underflow Vector INT8": "compile-time restriction",
239+
"INT8 with padding": "private padding field",
240+
"INT8 with float inputs": "compile-time restriction",
241+
"Overflow Vector PACKED_BIT": "compile-time restriction",
242+
"Underflow Vector PACKED_BIT": "compile-time restriction",
243+
"Vector with float values PACKED_BIT": "compile-time restriction",
244+
"Padding specified with no vector data PACKED_BIT": "run in alternative case",
245+
"Exceeding maximum padding PACKED_BIT": "run in alternative case",
246+
"Negative padding PACKED_BIT": "compile-time restriction",
247+
}
248+
if reason, ok := skipCases[test.Description]; ok {
249+
t.Skipf("skip test case %s: %s", test.Description, reason)
250+
}
251+
192252
t.Parallel()
193253

194254
got, err := Marshal(testVector)

bson/vector.go

Lines changed: 30 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,58 +13,50 @@ import (
1313
"math"
1414
)
1515

16-
// VectorDType represents the Vector data type.
17-
type VectorDType byte
18-
1916
// These constants identify the vector data type; the value is stored as the first byte of the vector's binary data.
2017
const (
21-
Int8Vector VectorDType = 0x03
22-
Float32Vector VectorDType = 0x27
23-
PackedBitVector VectorDType = 0x10
18+
Int8Vector byte = 0x03
19+
Float32Vector byte = 0x27
20+
PackedBitVector byte = 0x10
2421
)
2522

26-
// Stringer of VectorDType
27-
func (vt VectorDType) String() string {
28-
switch vt {
29-
case Int8Vector:
30-
return "int8"
31-
case Float32Vector:
32-
return "float32"
33-
case PackedBitVector:
34-
return "packed bit"
35-
default:
36-
return "invalid"
37-
}
38-
}
39-
4023
// These errors are returned when binary data cannot be converted into a Vector.
4124
var (
4225
errInsufficientVectorData = errors.New("insufficient data")
4326
errNonZeroVectorPadding = errors.New("padding must be 0")
44-
errVectorPaddingTooLarge = errors.New("padding larger than 7")
27+
errVectorPaddingTooLarge = errors.New("padding cannot be larger than 7")
4528
)
4629

4730
type vectorTypeError struct {
4831
Method string
49-
Type VectorDType
32+
Type byte
5033
}
5134

5235
// Error implements the error interface.
5336
func (vte vectorTypeError) Error() string {
54-
return "Call of " + vte.Method + " on " + vte.Type.String() + " vector"
37+
t := "invalid"
38+
switch vte.Type {
39+
case Int8Vector:
40+
t = "int8"
41+
case Float32Vector:
42+
t = "float32"
43+
case PackedBitVector:
44+
t = "packed bit"
45+
}
46+
return fmt.Sprintf("cannot call %s, on a type %s vector", vte.Method, t)
5547
}
5648

5749
// Vector represents a densely packed array of numbers / bits.
5850
type Vector struct {
59-
dType VectorDType
51+
dType byte
6052
int8Data []int8
6153
float32Data []float32
6254
bitData []byte
6355
bitPadding uint8
6456
}
6557

6658
// Type returns the vector type.
67-
func (v Vector) Type() VectorDType {
59+
func (v Vector) Type() byte {
6860
return v.dType
6961
}
7062

@@ -123,7 +115,7 @@ func (v Vector) PackedBitOK() ([]byte, uint8, bool) {
123115
return v.bitData, v.bitPadding, true
124116
}
125117

126-
// Binary returns the BSON Binary of the Vector.
118+
// Binary returns the BSON Binary representation of the Vector.
127119
func (v Vector) Binary() Binary {
128120
switch v.Type() {
129121
case Int8Vector:
@@ -133,15 +125,17 @@ func (v Vector) Binary() Binary {
133125
case PackedBitVector:
134126
return binaryFromBitVector(v.PackedBit())
135127
default:
136-
panic("invalid Vector type")
128+
panic(fmt.Sprintf("invalid Vector data type: %d", v.dType))
137129
}
138130
}
139131

140132
func binaryFromInt8Vector(v []int8) Binary {
141-
data := make([]byte, 2, len(v)+2)
142-
copy(data, []byte{byte(Int8Vector), 0})
143-
for _, e := range v {
144-
data = append(data, byte(e))
133+
data := make([]byte, len(v)+2)
134+
data[0] = byte(Int8Vector)
135+
data[1] = 0
136+
137+
for i, e := range v {
138+
data[i+2] = byte(e)
145139
}
146140

147141
return Binary{
@@ -180,12 +174,12 @@ func NewVector[T int8 | float32](data []T) Vector {
180174
switch a := any(data).(type) {
181175
case []int8:
182176
v.dType = Int8Vector
183-
v.int8Data = []int8{}
184-
v.int8Data = append(v.int8Data, a...)
177+
v.int8Data = make([]int8, len(data))
178+
copy(v.int8Data, a)
185179
case []float32:
186180
v.dType = Float32Vector
187-
v.float32Data = []float32{}
188-
v.float32Data = append(v.float32Data, a...)
181+
v.float32Data = make([]float32, len(data))
182+
copy(v.float32Data, a)
189183
default:
190184
panic(fmt.Errorf("unsupported type %T", data))
191185
}
@@ -217,7 +211,7 @@ func NewVectorFromBinary(b Binary) (Vector, error) {
217211
if len(b.Data) < 2 {
218212
return v, errInsufficientVectorData
219213
}
220-
switch t := b.Data[0]; VectorDType(t) {
214+
switch t := b.Data[0]; t {
221215
case Int8Vector:
222216
return newInt8Vector(b.Data[1:])
223217
case Float32Vector:

0 commit comments

Comments
 (0)