diff --git a/bson/bson_binary_vector_spec_test.go b/bson/bson_binary_vector_spec_test.go new file mode 100644 index 0000000000..b30aac1a83 --- /dev/null +++ b/bson/bson_binary_vector_spec_test.go @@ -0,0 +1,259 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bson + +import ( + "encoding/hex" + "encoding/json" + "fmt" + "math" + "os" + "path" + "testing" + + "go.mongodb.org/mongo-driver/v2/internal/require" +) + +const bsonBinaryVectorDir = "../testdata/bson-binary-vector/" + +type bsonBinaryVectorTests struct { + Description string `json:"description"` + TestKey string `json:"test_key"` + Tests []bsonBinaryVectorTestCase `json:"tests"` +} + +type bsonBinaryVectorTestCase struct { + Description string `json:"description"` + Valid bool `json:"valid"` + Vector []interface{} `json:"vector"` + DtypeHex string `json:"dtype_hex"` + DtypeAlias string `json:"dtype_alias"` + Padding int `json:"padding"` + CanonicalBson string `json:"canonical_bson"` +} + +func TestBsonBinaryVectorSpec(t *testing.T) { + t.Parallel() + + jsonFiles, err := findJSONFilesInDir(bsonBinaryVectorDir) + require.NoErrorf(t, err, "error finding JSON files in %s: %v", bsonBinaryVectorDir, err) + + for _, file := range jsonFiles { + filepath := path.Join(bsonBinaryVectorDir, file) + content, err := os.ReadFile(filepath) + require.NoErrorf(t, err, "reading test file %s", filepath) + + var tests bsonBinaryVectorTests + require.NoErrorf(t, json.Unmarshal(content, &tests), "parsing test file %s", filepath) + + t.Run(tests.Description, func(t *testing.T) { + t.Parallel() + + for _, test := range tests.Tests { + test := test + t.Run(test.Description, func(t *testing.T) { + t.Parallel() + + runBsonBinaryVectorTest(t, tests.TestKey, test) + }) + } + }) + } + + t.Run("FLOAT32 
with padding", func(t *testing.T) { + t.Parallel() + + t.Run("Unmarshaling", func(t *testing.T) { + val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{Float32Vector, 3}}}} + b, err := Marshal(val) + require.NoError(t, err, "marshaling test BSON") + var got struct { + Vector Vector + } + err = Unmarshal(b, &got) + require.ErrorContains(t, err, errNonZeroVectorPadding.Error()) + }) + }) + + t.Run("INT8 with padding", func(t *testing.T) { + t.Parallel() + + t.Run("Unmarshaling", func(t *testing.T) { + val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{Int8Vector, 3}}}} + b, err := Marshal(val) + require.NoError(t, err, "marshaling test BSON") + var got struct { + Vector Vector + } + err = Unmarshal(b, &got) + require.ErrorContains(t, err, errNonZeroVectorPadding.Error()) + }) + }) + + t.Run("Padding specified with no vector data PACKED_BIT", func(t *testing.T) { + t.Parallel() + + t.Run("Marshaling", func(t *testing.T) { + _, err := NewPackedBitVector(nil, 1) + require.EqualError(t, err, errNonZeroVectorPadding.Error()) + }) + t.Run("Unmarshaling", func(t *testing.T) { + val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{PackedBitVector, 1}}}} + b, err := Marshal(val) + require.NoError(t, err, "marshaling test BSON") + var got struct { + Vector Vector + } + err = Unmarshal(b, &got) + require.ErrorContains(t, err, errNonZeroVectorPadding.Error()) + }) + }) + + t.Run("Exceeding maximum padding PACKED_BIT", func(t *testing.T) { + t.Parallel() + + t.Run("Marshaling", func(t *testing.T) { + _, err := NewPackedBitVector(nil, 8) + require.EqualError(t, err, errVectorPaddingTooLarge.Error()) + }) + t.Run("Unmarshaling", func(t *testing.T) { + val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{PackedBitVector, 8}}}} + b, err := Marshal(val) + require.NoError(t, err, "marshaling test BSON") + var got struct { + Vector Vector + } + err = Unmarshal(b, &got) + require.ErrorContains(t, err, 
errVectorPaddingTooLarge.Error()) + }) + }) +} + +// TODO: This test may be added into the spec tests. +func TestFloat32VectorWithInsufficientData(t *testing.T) { + t.Parallel() + + val := Binary{Subtype: TypeBinaryVector} + + for _, tc := range [][]byte{ + {Float32Vector, 0, 42}, + {Float32Vector, 0, 42, 42}, + {Float32Vector, 0, 42, 42, 42}, + + {Float32Vector, 0, 42, 42, 42, 42, 42}, + {Float32Vector, 0, 42, 42, 42, 42, 42, 42}, + {Float32Vector, 0, 42, 42, 42, 42, 42, 42, 42}, + } { + t.Run(fmt.Sprintf("marshaling %d bytes", len(tc)-2), func(t *testing.T) { + val.Data = tc + b, err := Marshal(D{{"vector", val}}) + require.NoError(t, err, "marshaling test BSON") + var got struct { + Vector Vector + } + err = Unmarshal(b, &got) + require.ErrorContains(t, err, errInsufficientVectorData.Error()) + }) + } +} + +func convertSlice[T int8 | float32 | byte](s []interface{}) []T { + v := make([]T, len(s)) + for i, e := range s { + f := math.NaN() + switch val := e.(type) { + case float64: + f = val + case string: + if val == "inf" { + f = math.Inf(0) + } else if val == "-inf" { + f = math.Inf(-1) + } + } + v[i] = T(f) + } + return v +} + +func runBsonBinaryVectorTest(t *testing.T, testKey string, test bsonBinaryVectorTestCase) { + testVector := make(map[string]Vector) + switch alias := test.DtypeHex; alias { + case "0x03": + testVector[testKey] = Vector{ + dType: Int8Vector, + int8Data: convertSlice[int8](test.Vector), + } + case "0x27": + testVector[testKey] = Vector{ + dType: Float32Vector, + float32Data: convertSlice[float32](test.Vector), + } + case "0x10": + testVector[testKey] = Vector{ + dType: PackedBitVector, + bitData: convertSlice[byte](test.Vector), + bitPadding: uint8(test.Padding), + } + default: + t.Fatalf("unsupported vector type: %s", alias) + } + + testBSON, err := hex.DecodeString(test.CanonicalBson) + require.NoError(t, err, "decoding canonical BSON") + + t.Run("Unmarshaling", func(t *testing.T) { + skipCases := map[string]string{ + "FLOAT32 with 
padding": "run in alternative case", + "Overflow Vector INT8": "compile-time restriction", + "Underflow Vector INT8": "compile-time restriction", + "INT8 with padding": "run in alternative case", + "INT8 with float inputs": "compile-time restriction", + "Overflow Vector PACKED_BIT": "compile-time restriction", + "Underflow Vector PACKED_BIT": "compile-time restriction", + "Vector with float values PACKED_BIT": "compile-time restriction", + "Padding specified with no vector data PACKED_BIT": "run in alternative case", + "Exceeding maximum padding PACKED_BIT": "run in alternative case", + "Negative padding PACKED_BIT": "compile-time restriction", + } + if reason, ok := skipCases[test.Description]; ok { + t.Skipf("skip test case %s: %s", test.Description, reason) + } + + t.Parallel() + + var got map[string]Vector + err := Unmarshal(testBSON, &got) + require.NoError(t, err) + require.Equal(t, testVector, got) + }) + + t.Run("Marshaling", func(t *testing.T) { + skipCases := map[string]string{ + "FLOAT32 with padding": "private padding field", + "Overflow Vector INT8": "compile-time restriction", + "Underflow Vector INT8": "compile-time restriction", + "INT8 with padding": "private padding field", + "INT8 with float inputs": "compile-time restriction", + "Overflow Vector PACKED_BIT": "compile-time restriction", + "Underflow Vector PACKED_BIT": "compile-time restriction", + "Vector with float values PACKED_BIT": "compile-time restriction", + "Padding specified with no vector data PACKED_BIT": "run in alternative case", + "Exceeding maximum padding PACKED_BIT": "run in alternative case", + "Negative padding PACKED_BIT": "compile-time restriction", + } + if reason, ok := skipCases[test.Description]; ok { + t.Skipf("skip test case %s: %s", test.Description, reason) + } + + t.Parallel() + + got, err := Marshal(testVector) + require.NoError(t, err) + require.Equal(t, testBSON, got) + }) +} diff --git a/bson/bson_corpus_spec_test.go b/bson/bson_corpus_spec_test.go index 
a0d5a5aa38..043aa2f019 100644 --- a/bson/bson_corpus_spec_test.go +++ b/bson/bson_corpus_spec_test.go @@ -217,7 +217,7 @@ func normalizeRelaxedDouble(t *testing.T, key string, rEJ string) string { func bsonToNative(t *testing.T, b []byte, bType, testDesc string) D { var doc D err := Unmarshal(b, &doc) - expectNoError(t, err, fmt.Sprintf("%s: decoding %s BSON", testDesc, bType)) + require.NoErrorf(t, err, "%s: decoding %s BSON", testDesc, bType) return doc } @@ -225,7 +225,7 @@ func bsonToNative(t *testing.T, b []byte, bType, testDesc string) D { // canonical BSON (cB) func nativeToBSON(t *testing.T, cB []byte, doc D, testDesc, bType, docSrcDesc string) { actual, err := Marshal(doc) - expectNoError(t, err, fmt.Sprintf("%s: encoding %s BSON", testDesc, bType)) + require.NoErrorf(t, err, "%s: encoding %s BSON", testDesc, bType) if diff := cmp.Diff(cB, actual); diff != "" { t.Errorf("%s: 'native_to_bson(%s) = cB' failed (-want, +got):\n-%v\n+%v\n", @@ -261,7 +261,7 @@ func jsonToBytes(ej, ejType, testDesc string) ([]byte, error) { // nativeToJSON encodes the native Document (doc) into an extended JSON string func nativeToJSON(t *testing.T, ej string, doc D, testDesc, ejType, ejShortName, docSrcDesc string) { actualEJ, err := MarshalExtJSON(doc, ejType != "relaxed", true) - expectNoError(t, err, fmt.Sprintf("%s: encoding %s extended JSON", testDesc, ejType)) + require.NoErrorf(t, err, "%s: encoding %s extended JSON", testDesc, ejType) if diff := cmp.Diff(ej, string(actualEJ)); diff != "" { t.Errorf("%s: 'native_to_%s_extended_json(%s) = %s' failed (-want, +got):\n%s\n", @@ -288,7 +288,7 @@ func runTest(t *testing.T, file string) { t.Run(v.Description, func(t *testing.T) { // get canonical BSON cB, err := hex.DecodeString(v.CanonicalBson) - expectNoError(t, err, fmt.Sprintf("%s: reading canonical BSON", v.Description)) + require.NoErrorf(t, err, "%s: reading canonical BSON", v.Description) // get canonical extended JSON var compactEJ bytes.Buffer @@ -341,7 +341,7 @@ func 
runTest(t *testing.T, file string) { /*** degenerate BSON round-trip tests (if exists) ***/ if v.DegenerateBSON != nil { dB, err := hex.DecodeString(*v.DegenerateBSON) - expectNoError(t, err, fmt.Sprintf("%s: reading degenerate BSON", v.Description)) + require.NoErrorf(t, err, "%s: reading degenerate BSON", v.Description) doc = bsonToNative(t, dB, "degenerate", v.Description) @@ -377,7 +377,7 @@ func runTest(t *testing.T, file string) { for _, d := range test.DecodeErrors { t.Run(d.Description, func(t *testing.T) { b, err := hex.DecodeString(d.Bson) - expectNoError(t, err, d.Description) + require.NoError(t, err, d.Description) var doc D err = Unmarshal(b, &doc) @@ -392,12 +392,12 @@ func runTest(t *testing.T, file string) { invalidDBPtr := ok && !utf8.ValidString(dbPtr.DB) if invalidString || invalidDBPtr { - expectNoError(t, err, d.Description) + require.NoError(t, err, d.Description) return } } - expectError(t, err, fmt.Sprintf("%s: expected decode error", d.Description)) + require.Errorf(t, err, "%s: expected decode error", d.Description) }) } }) @@ -418,7 +418,7 @@ func runTest(t *testing.T, file string) { if strings.Contains(p.Description, "Null") { _, err = Marshal(doc) } - expectError(t, err, fmt.Sprintf("%s: expected parse error", p.Description)) + require.Errorf(t, err, "%s: expected parse error", p.Description) default: t.Errorf("Update test to check for parse errors for type %s", test.BsonType) t.Fail() @@ -431,31 +431,13 @@ func runTest(t *testing.T, file string) { func Test_BsonCorpus(t *testing.T) { jsonFiles, err := findJSONFilesInDir(dataDir) - if err != nil { - t.Fatalf("error finding JSON files in %s: %v", dataDir, err) - } + require.NoErrorf(t, err, "error finding JSON files in %s: %v", dataDir, err) for _, file := range jsonFiles { runTest(t, file) } } -func expectNoError(t *testing.T, err error, desc string) { - if err != nil { - t.Helper() - t.Errorf("%s: Unepexted error: %v", desc, err) - t.FailNow() - } -} - -func expectError(t *testing.T, 
err error, desc string) { - if err == nil { - t.Helper() - t.Errorf("%s: Expected error", desc) - t.FailNow() - } -} - func TestRelaxedUUIDValidation(t *testing.T) { testCases := []struct { description string diff --git a/bson/default_value_decoders.go b/bson/default_value_decoders.go index 2f195329ca..dfff145219 100644 --- a/bson/default_value_decoders.go +++ b/bson/default_value_decoders.go @@ -42,6 +42,7 @@ func registerDefaultDecoders(reg *Registry) { reg.RegisterTypeDecoder(tD, ValueDecoderFunc(dDecodeValue)) reg.RegisterTypeDecoder(tBinary, decodeAdapter{binaryDecodeValue, binaryDecodeType}) + reg.RegisterTypeDecoder(tVector, decodeAdapter{vectorDecodeValue, vectorDecodeType}) reg.RegisterTypeDecoder(tUndefined, decodeAdapter{undefinedDecodeValue, undefinedDecodeType}) reg.RegisterTypeDecoder(tDateTime, decodeAdapter{dateTimeDecodeValue, dateTimeDecodeType}) reg.RegisterTypeDecoder(tNull, decodeAdapter{nullDecodeValue, nullDecodeType}) @@ -501,14 +502,8 @@ func symbolDecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) er return nil } -func binaryDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { - if t != tBinary { - return emptyValue, ValueDecoderError{ - Name: "BinaryDecodeValue", - Types: []reflect.Type{tBinary}, - Received: reflect.Zero(t), - } - } +func binaryDecode(vr ValueReader) (Binary, error) { + var b Binary var data []byte var subtype byte @@ -521,13 +516,31 @@ func binaryDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect. 
case TypeUndefined: err = vr.ReadUndefined() default: - return emptyValue, fmt.Errorf("cannot decode %v into a Binary", vrType) + return b, fmt.Errorf("cannot decode %v into a Binary", vrType) } if err != nil { - return emptyValue, err + return b, err + } + b.Subtype = subtype + b.Data = data + + return b, nil +} + +func binaryDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { + if t != tBinary { + return emptyValue, ValueDecoderError{ + Name: "BinaryDecodeValue", + Types: []reflect.Type{tBinary}, + Received: reflect.Zero(t), + } } - return reflect.ValueOf(Binary{Subtype: subtype, Data: data}), nil + b, err := binaryDecode(vr) + if err != nil { + return emptyValue, err + } + return reflect.ValueOf(b), nil } // binaryDecodeValue is the ValueDecoderFunc for Binary. @@ -545,6 +558,48 @@ func binaryDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) erro return nil } +func vectorDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { + if t != tVector { + return emptyValue, ValueDecoderError{ + Name: "VectorDecodeValue", + Types: []reflect.Type{tVector}, + Received: reflect.Zero(t), + } + } + + b, err := binaryDecode(vr) + if err != nil { + return emptyValue, err + } + + v, err := NewVectorFromBinary(b) + if err != nil { + return emptyValue, err + } + + return reflect.ValueOf(v), nil +} + +// vectorDecodeValue is the ValueDecoderFunc for Vector. 
+func vectorDecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error { + // Test CanSet first: it is false for the zero Value, on which Type would panic. + if !val.CanSet() || val.Type() != tVector { + return ValueDecoderError{ + Name: "VectorDecodeValue", + Types: []reflect.Type{tVector}, + Received: val, + } + } + + elem, err := vectorDecodeType(dctx, vr, tVector) + if err != nil { + return err + } + + val.Set(elem) + return nil +} + func undefinedDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tUndefined { return emptyValue, ValueDecoderError{ diff --git a/bson/default_value_encoders.go b/bson/default_value_encoders.go index 9835738be3..bd5a20f2f9 100644 --- a/bson/default_value_encoders.go +++ b/bson/default_value_encoders.go @@ -70,6 +70,7 @@ func registerDefaultEncoders(reg *Registry) { reg.RegisterTypeEncoder(tJavaScript, ValueEncoderFunc(javaScriptEncodeValue)) reg.RegisterTypeEncoder(tSymbol, ValueEncoderFunc(symbolEncodeValue)) reg.RegisterTypeEncoder(tBinary, ValueEncoderFunc(binaryEncodeValue)) + reg.RegisterTypeEncoder(tVector, ValueEncoderFunc(vectorEncodeValue)) reg.RegisterTypeEncoder(tUndefined, ValueEncoderFunc(undefinedEncodeValue)) reg.RegisterTypeEncoder(tDateTime, ValueEncoderFunc(dateTimeEncodeValue)) reg.RegisterTypeEncoder(tNull, ValueEncoderFunc(nullEncodeValue)) @@ -364,6 +365,20 @@ func binaryEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error return vw.WriteBinaryWithSubtype(b.Data, b.Subtype) } +// vectorEncodeValue is the ValueEncoderFunc for Vector. +func vectorEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { + // Test IsValid first: Type panics when called on the zero (invalid) Value. + if !val.IsValid() || val.Type() != tVector { + return ValueEncoderError{Name: "VectorEncodeValue", + Types: []reflect.Type{tVector}, + Received: val, + } + } + v := val.Interface().(Vector) + b := v.Binary() + return vw.WriteBinaryWithSubtype(b.Data, b.Subtype) +} + // undefinedEncodeValue is the ValueEncoderFunc for Undefined.
func undefinedEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tUndefined { diff --git a/bson/extjson_parser_test.go b/bson/extjson_parser_test.go index 51a7de7aef..ae6bbf7bb3 100644 --- a/bson/extjson_parser_test.go +++ b/bson/extjson_parser_test.go @@ -45,6 +45,24 @@ type readKeyValueTestCase struct { valEFs []expectedErrorFunc } +// expectNoError reports desc and stops the test immediately when err is non-nil. +func expectNoError(t *testing.T, err error, desc string) { + if err != nil { + t.Helper() + t.Errorf("%s: Unexpected error: %v", desc, err) + t.FailNow() + } +} + +// expectError reports desc and stops the test immediately when err is nil. +func expectError(t *testing.T, err error, desc string) { + if err == nil { + t.Helper() + t.Errorf("%s: Expected error", desc) + t.FailNow() + } +} + func expectSpecificError(expected error) expectedErrorFunc { return func(t *testing.T, err error, desc string) { if !errors.Is(err, expected) { diff --git a/bson/json_scanner_test.go b/bson/json_scanner_test.go index 58f6e64594..b46d5aac9c 100644 --- a/bson/json_scanner_test.go +++ b/bson/json_scanner_test.go @@ -12,6 +12,7 @@ import ( "testing/iotest" "github.com/google/go-cmp/cmp" + "go.mongodb.org/mongo-driver/v2/internal/require" ) func jttDiff(t *testing.T, expected, actual jsonTokenType, desc string) { @@ -289,7 +290,7 @@ func TestJsonScannerValidInputs(t *testing.T) { for _, token := range tc.tokens { c, err := js.nextToken() - expectNoError(t, err, tc.desc) + require.NoError(t, err, tc.desc) jttDiff(t, token.t, c.t, tc.desc) jtvDiff(t, token.v, c.v, tc.desc) } @@ -303,7 +304,7 @@ func TestJsonScannerValidInputs(t *testing.T) { for _, token := range tc.tokens { c, err := js.nextToken() - expectNoError(t, err, tc.desc) + require.NoError(t, err, tc.desc) jttDiff(t, token.t, c.t, tc.desc) jtvDiff(t, token.v, c.v, tc.desc) } @@ -354,7 +355,7 @@ func TestJsonScannerInvalidInputs(t *testing.T) { c, err := js.nextToken() expectNilToken(t, c, tc.desc) - expectError(t, err, tc.desc) + require.Error(t, err, tc.desc) }) } } diff --git a/bson/types.go b/bson/types.go index
dedc95a596..c2883aa4ef 100644 --- a/bson/types.go +++ b/bson/types.go @@ -72,6 +72,7 @@ const ( TypeBinaryEncrypted byte = 0x06 TypeBinaryColumn byte = 0x07 TypeBinarySensitive byte = 0x08 + TypeBinaryVector byte = 0x09 TypeBinaryUserDefined byte = 0x80 ) @@ -106,6 +107,7 @@ var tJavaScript = reflect.TypeOf(JavaScript("")) var tSymbol = reflect.TypeOf(Symbol("")) var tTimestamp = reflect.TypeOf(Timestamp{}) var tDecimal = reflect.TypeOf(Decimal128{}) +var tVector = reflect.TypeOf(Vector{}) var tMinKey = reflect.TypeOf(MinKey{}) var tMaxKey = reflect.TypeOf(MaxKey{}) var tD = reflect.TypeOf(D{}) diff --git a/bson/vector.go b/bson/vector.go new file mode 100644 index 0000000000..31a10bd5be --- /dev/null +++ b/bson/vector.go @@ -0,0 +1,268 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bson + +import ( + "encoding/binary" + "errors" + "fmt" + "math" +) + +// BSON binary vector types as described in https://bsonspec.org/spec.html. +const ( + Int8Vector byte = 0x03 + Float32Vector byte = 0x27 + PackedBitVector byte = 0x10 +) + +// These are vector conversion errors. +var ( + errInsufficientVectorData = errors.New("insufficient data") + errNonZeroVectorPadding = errors.New("padding must be 0") + errVectorPaddingTooLarge = errors.New("padding cannot be larger than 7") +) + +type vectorTypeError struct { + Method string + Type byte +} + +// Error implements the error interface. +func (vte vectorTypeError) Error() string { + t := "invalid" + switch vte.Type { + case Int8Vector: + t = "int8" + case Float32Vector: + t = "float32" + case PackedBitVector: + t = "packed bit" + } + return fmt.Sprintf("cannot call %s, on a type %s vector", vte.Method, t) +} + +// Vector represents a densely packed array of numbers / bits. 
+type Vector struct { + dType byte // vector data type, e.g. Int8Vector + int8Data []int8 // elements returned when dType is Int8Vector + float32Data []float32 // elements returned when dType is Float32Vector + bitData []byte // packed bits returned when dType is PackedBitVector + bitPadding uint8 // number of bits to ignore in the final byte of bitData +} + +// Type returns the vector's data type, e.g. Int8Vector, Float32Vector, or PackedBitVector. +func (v Vector) Type() byte { + return v.dType +} + +// Int8 returns the int8 slice held by the vector; the slice is returned as is, not copied. +// It panics if v is not an int8 vector. +func (v Vector) Int8() []int8 { + d, ok := v.Int8OK() + if !ok { + panic(vectorTypeError{"bson.Vector.Int8", v.dType}) + } + return d +} + +// Int8OK is the same as Int8, but returns a boolean instead of panicking. +func (v Vector) Int8OK() ([]int8, bool) { + if v.dType != Int8Vector { + return nil, false + } + return v.int8Data, true +} + +// Float32 returns the float32 slice held by the vector; the slice is returned as is, not copied. +// It panics if v is not a float32 vector. +func (v Vector) Float32() []float32 { + d, ok := v.Float32OK() + if !ok { + panic(vectorTypeError{"bson.Vector.Float32", v.dType}) + } + return d +} + +// Float32OK is the same as Float32, but returns a boolean instead of panicking. +func (v Vector) Float32OK() ([]float32, bool) { + if v.dType != Float32Vector { + return nil, false + } + return v.float32Data, true +} + +// PackedBit returns the byte slice representing the binary quantized (packed bit) vector and the byte padding, which +// is the number of bits in the final byte that are to be ignored. +// It panics if v is not a packed bit vector. +func (v Vector) PackedBit() ([]byte, uint8) { + d, p, ok := v.PackedBitOK() + if !ok { + panic(vectorTypeError{"bson.Vector.PackedBit", v.dType}) + } + return d, p +} + +// PackedBitOK is the same as PackedBit, but returns a boolean instead of panicking. +func (v Vector) PackedBitOK() ([]byte, uint8, bool) { + if v.dType != PackedBitVector { + return nil, 0, false + } + return v.bitData, v.bitPadding, true +} + +// Binary returns the BSON Binary representation of the Vector. It panics if the vector's data type is invalid.
+func (v Vector) Binary() Binary { + switch v.Type() { + case Int8Vector: + return binaryFromInt8Vector(v.Int8()) + case Float32Vector: + return binaryFromFloat32Vector(v.Float32()) + case PackedBitVector: + return binaryFromBitVector(v.PackedBit()) + default: + panic(fmt.Sprintf("invalid Vector data type: %d", v.dType)) // e.g. the zero Vector, whose dType is 0 + } +} + +func binaryFromInt8Vector(v []int8) Binary { + data := make([]byte, len(v)+2) // 2 header bytes (data type, padding) followed by one byte per element + data[0] = Int8Vector + data[1] = 0 // padding byte, always 0 for int8 vectors + + for i, e := range v { + data[i+2] = byte(e) + } + + return Binary{ + Subtype: TypeBinaryVector, + Data: data, + } +} + +func binaryFromFloat32Vector(v []float32) Binary { + data := make([]byte, 2, len(v)*4+2) // 2 header bytes now, capacity for 4 bytes per element + data[0] = Float32Vector + data[1] = 0 // padding byte, always 0 for float32 vectors + var a [4]byte + for _, e := range v { + binary.LittleEndian.PutUint32(a[:], math.Float32bits(e)) // IEEE 754 bit pattern, little-endian + data = append(data, a[:]...) + } + + return Binary{ + Subtype: TypeBinaryVector, + Data: data, + } +} + +func binaryFromBitVector(bits []byte, padding uint8) Binary { + data := make([]byte, len(bits)+2) // 2 header bytes (data type, padding) followed by the packed bits + data[0] = PackedBitVector + data[1] = padding + copy(data[2:], bits) + return Binary{ + Subtype: TypeBinaryVector, + Data: data, + } +} + +// NewVector constructs a Vector from a slice of int8 or float32. The data is copied into the Vector. +func NewVector[T int8 | float32](data []T) Vector { + var v Vector + switch a := any(data).(type) { + case []int8: + v.dType = Int8Vector + v.int8Data = make([]int8, len(data)) + copy(v.int8Data, a) + case []float32: + v.dType = Float32Vector + v.float32Data = make([]float32, len(data)) + copy(v.float32Data, a) + default: + panic(fmt.Errorf("unsupported type %T", data)) // not reachable while T is constrained to int8 | float32 + } + return v +} + +// NewPackedBitVector constructs a packed bit Vector from a copy of bits and padding, the number of bits in the final byte to ignore. It returns an error if padding is larger than 7, or is non-zero while bits is empty.
+func NewPackedBitVector(bits []byte, padding uint8) (Vector, error) { + var v Vector + if padding > 7 { + return v, errVectorPaddingTooLarge + } + if padding > 0 && len(bits) == 0 { + return v, errNonZeroVectorPadding + } + v.dType = PackedBitVector + v.bitData = make([]byte, len(bits)) + copy(v.bitData, bits) + v.bitPadding = padding + return v, nil +} + +// NewVectorFromBinary unpacks a BSON Binary into a Vector. It returns an error if b does not have the vector subtype, holds fewer than 2 bytes, or has an unknown data type or invalid padding or payload. +func NewVectorFromBinary(b Binary) (Vector, error) { + var v Vector + if b.Subtype != TypeBinaryVector { + return v, errors.New("not a vector") + } + if len(b.Data) < 2 { + return v, errInsufficientVectorData + } + switch t := b.Data[0]; t { // first byte is the data type; helpers get the padding byte plus payload + case Int8Vector: + return newInt8Vector(b.Data[1:]) + case Float32Vector: + return newFloat32Vector(b.Data[1:]) + case PackedBitVector: + return newBitVector(b.Data[1:]) + default: + return v, fmt.Errorf("invalid Vector data type: %d", t) + } +} + +func newInt8Vector(b []byte) (Vector, error) { + var v Vector + if len(b) == 0 { + return v, errInsufficientVectorData + } + if padding := b[0]; padding > 0 { + return v, errNonZeroVectorPadding + } + s := make([]int8, 0, len(b)-1) + for i := 1; i < len(b); i++ { + s = append(s, int8(b[i])) + } + return NewVector(s), nil +} + +func newFloat32Vector(b []byte) (Vector, error) { + var v Vector + if len(b) == 0 { + return v, errInsufficientVectorData + } + if padding := b[0]; padding > 0 { + return v, errNonZeroVectorPadding + } + l := (len(b) - 1) / 4 // b[0] is the padding byte; the rest must be whole 4-byte elements + if l*4 != len(b)-1 { + return v, errInsufficientVectorData + } + s := make([]float32, 0, l) + for i := 1; i < len(b); i += 4 { + s = append(s, math.Float32frombits(binary.LittleEndian.Uint32(b[i:i+4]))) + } + return NewVector(s), nil +} + +func newBitVector(b []byte) (Vector, error) { + if len(b) == 0 { + return Vector{}, errInsufficientVectorData + } + return NewPackedBitVector(b[1:], b[0]) // b[0] is the padding, b[1:] the packed bits +} diff --git a/testdata/bson-binary-vector/float32.json b/testdata/bson-binary-vector/float32.json new file mode 100644 index
0000000000..d423f9e2bd --- /dev/null +++ b/testdata/bson-binary-vector/float32.json @@ -0,0 +1,50 @@ +{ + "description": "Tests of Binary subtype 9, Vectors, with dtype FLOAT32", + "test_key": "vector", + "tests": [ + { + "description": "Simple Vector FLOAT32", + "valid": true, + "vector": [127.0, 7.0], + "dtype_hex": "0x27", + "dtype_alias": "FLOAT32", + "padding": 0, + "canonical_bson": "1C00000005766563746F72000A0000000927000000FE420000E04000" + }, + { + "description": "Vector with decimals and negative value FLOAT32", + "valid": true, + "vector": [127.7, -7.7], + "dtype_hex": "0x27", + "dtype_alias": "FLOAT32", + "padding": 0, + "canonical_bson": "1C00000005766563746F72000A0000000927006666FF426666F6C000" + }, + { + "description": "Empty Vector FLOAT32", + "valid": true, + "vector": [], + "dtype_hex": "0x27", + "dtype_alias": "FLOAT32", + "padding": 0, + "canonical_bson": "1400000005766563746F72000200000009270000" + }, + { + "description": "Infinity Vector FLOAT32", + "valid": true, + "vector": ["-inf", 0.0, "inf"], + "dtype_hex": "0x27", + "dtype_alias": "FLOAT32", + "padding": 0, + "canonical_bson": "2000000005766563746F72000E000000092700000080FF000000000000807F00" + }, + { + "description": "FLOAT32 with padding", + "valid": false, + "vector": [127.0, 7.0], + "dtype_hex": "0x27", + "dtype_alias": "FLOAT32", + "padding": 3 + } + ] +} diff --git a/testdata/bson-binary-vector/int8.json b/testdata/bson-binary-vector/int8.json new file mode 100644 index 0000000000..d849819992 --- /dev/null +++ b/testdata/bson-binary-vector/int8.json @@ -0,0 +1,56 @@ +{ + "description": "Tests of Binary subtype 9, Vectors, with dtype INT8", + "test_key": "vector", + "tests": [ + { + "description": "Simple Vector INT8", + "valid": true, + "vector": [127, 7], + "dtype_hex": "0x03", + "dtype_alias": "INT8", + "padding": 0, + "canonical_bson": "1600000005766563746F7200040000000903007F0700" + }, + { + "description": "Empty Vector INT8", + "valid": true, + "vector": [], + "dtype_hex": 
"0x03", + "dtype_alias": "INT8", + "padding": 0, + "canonical_bson": "1400000005766563746F72000200000009030000" + }, + { + "description": "Overflow Vector INT8", + "valid": false, + "vector": [128], + "dtype_hex": "0x03", + "dtype_alias": "INT8", + "padding": 0 + }, + { + "description": "Underflow Vector INT8", + "valid": false, + "vector": [-129], + "dtype_hex": "0x03", + "dtype_alias": "INT8", + "padding": 0 + }, + { + "description": "INT8 with padding", + "valid": false, + "vector": [127, 7], + "dtype_hex": "0x03", + "dtype_alias": "INT8", + "padding": 3 + }, + { + "description": "INT8 with float inputs", + "valid": false, + "vector": [127.77, 7.77], + "dtype_hex": "0x03", + "dtype_alias": "INT8", + "padding": 0 + } + ] +} diff --git a/testdata/bson-binary-vector/packed_bit.json b/testdata/bson-binary-vector/packed_bit.json new file mode 100644 index 0000000000..0d5dae52b4 --- /dev/null +++ b/testdata/bson-binary-vector/packed_bit.json @@ -0,0 +1,97 @@ +{ + "description": "Tests of Binary subtype 9, Vectors, with dtype PACKED_BIT", + "test_key": "vector", + "tests": [ + { + "description": "Padding specified with no vector data PACKED_BIT", + "valid": false, + "vector": [], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 1 + }, + { + "description": "Simple Vector PACKED_BIT", + "valid": true, + "vector": [127, 7], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 0, + "canonical_bson": "1600000005766563746F7200040000000910007F0700" + }, + { + "description": "Empty Vector PACKED_BIT", + "valid": true, + "vector": [], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 0, + "canonical_bson": "1400000005766563746F72000200000009100000" + }, + { + "description": "PACKED_BIT with padding", + "valid": true, + "vector": [127, 7], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 3, + "canonical_bson": "1600000005766563746F7200040000000910037F0700" + }, + { + "description": "Overflow Vector PACKED_BIT", + 
"valid": false, + "vector": [256], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 0 + }, + { + "description": "Underflow Vector PACKED_BIT", + "valid": false, + "vector": [-1], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 0 + }, + { + "description": "Vector with float values PACKED_BIT", + "valid": false, + "vector": [127.5], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 0 + }, + { + "description": "Padding specified with no vector data PACKED_BIT", + "valid": false, + "vector": [], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 1 + }, + { + "description": "Exceeding maximum padding PACKED_BIT", + "valid": false, + "vector": [1], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 8 + }, + { + "description": "Negative padding PACKED_BIT", + "valid": false, + "vector": [1], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": -1 + }, + { + "description": "Vector with float values PACKED_BIT", + "valid": false, + "vector": [127.5], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 0 + } + ] +} diff --git a/testdata/bson-corpus/binary.json b/testdata/bson-corpus/binary.json index 20aaef743b..0e0056f3a2 100644 --- a/testdata/bson-corpus/binary.json +++ b/testdata/bson-corpus/binary.json @@ -74,6 +74,36 @@ "description": "$type query operator (conflicts with legacy $binary form with $type field)", "canonical_bson": "180000000378001000000010247479706500020000000000", "canonical_extjson": "{\"x\" : { \"$type\" : {\"$numberInt\": \"2\"}}}" + }, + { + "description": "subtype 0x09 Vector FLOAT32", + "canonical_bson": "170000000578000A0000000927000000FE420000E04000", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"JwAAAP5CAADgQA==\", \"subType\": \"09\"}}}" + }, + { + "description": "subtype 0x09 Vector INT8", + "canonical_bson": "11000000057800040000000903007F0700", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": 
\"AwB/Bw==\", \"subType\": \"09\"}}}" + }, + { + "description": "subtype 0x09 Vector PACKED_BIT", + "canonical_bson": "11000000057800040000000910007F0700", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"EAB/Bw==\", \"subType\": \"09\"}}}" + }, + { + "description": "subtype 0x09 Vector (Zero-length) FLOAT32", + "canonical_bson": "0F0000000578000200000009270000", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"JwA=\", \"subType\": \"09\"}}}" + }, + { + "description": "subtype 0x09 Vector (Zero-length) INT8", + "canonical_bson": "0F0000000578000200000009030000", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"AwA=\", \"subType\": \"09\"}}}" + }, + { + "description": "subtype 0x09 Vector (Zero-length) PACKED_BIT", + "canonical_bson": "0F0000000578000200000009100000", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"EAA=\", \"subType\": \"09\"}}}" } ], "decodeErrors": [