Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
258 changes: 258 additions & 0 deletions bson/bson_binary_vector_spec_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
// Copyright (C) MongoDB, Inc. 2024-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bson

import (
"encoding/hex"
"encoding/json"
"fmt"
"math"
"os"
"path"
"testing"

"go.mongodb.org/mongo-driver/v2/internal/require"
)

const bsonBinaryVectorDir = "../testdata/bson-binary-vector/"

type bsonBinaryVectorTests struct {
Description string `json:"description"`
TestKey string `json:"test_key"`
Tests []bsonBinaryVectorTestCase `json:"tests"`
}

type bsonBinaryVectorTestCase struct {
Description string `json:"description"`
Valid bool `json:"valid"`
Vector []interface{} `json:"vector"`
DtypeHex string `json:"dtype_hex"`
DtypeAlias string `json:"dtype_alias"`
Padding int `json:"padding"`
CanonicalBson string `json:"canonical_bson"`
}

func TestBsonBinaryVector(t *testing.T) {
t.Parallel()

jsonFiles, err := findJSONFilesInDir(bsonBinaryVectorDir)
require.NoErrorf(t, err, "error finding JSON files in %s: %v", bsonBinaryVectorDir, err)

for _, file := range jsonFiles {
filepath := path.Join(bsonBinaryVectorDir, file)
content, err := os.ReadFile(filepath)
require.NoErrorf(t, err, "reading test file %s", filepath)

var tests bsonBinaryVectorTests
require.NoErrorf(t, json.Unmarshal(content, &tests), "parsing test file %s", filepath)

t.Run(tests.Description, func(t *testing.T) {
t.Parallel()

for _, test := range tests.Tests {
test := test
t.Run(test.Description, func(t *testing.T) {
t.Parallel()

runBsonBinaryVectorTest(t, tests.TestKey, test)
})
}
})
}

t.Run("Insufficient vector data FLOAT32", func(t *testing.T) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest making non-spec tests their own test.

func TestInsufficientVectorDataFloat32(*testing.T)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO, these should be included in the spec tests. I will submit a spec PR later.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have a pattern for naming specification tests Test{Focus}Spec. For example, TestConnStringSpec, TestURIOptionsSpec, TestPollingSRVRecordsSpec, TestServerSelectionRTTSpec, etc. We should call this test TestBsonBinaryVectorSpec and decouple it from non-spec tests. If we are going to make these subtests spec tests in the future, we should at least add a comment with the corresponding DRIVERS ticket so that they can be removed as duplicates after implementation.

t.Parallel()

val := Binary{Subtype: TypeBinaryVector}

for _, tc := range [][]byte{
{Float32Vector, 0, 42},
{Float32Vector, 0, 42, 42},
{Float32Vector, 0, 42, 42, 42},

{Float32Vector, 0, 42, 42, 42, 42, 42},
{Float32Vector, 0, 42, 42, 42, 42, 42, 42},
{Float32Vector, 0, 42, 42, 42, 42, 42, 42, 42},
} {
t.Run(fmt.Sprintf("marshaling %d bytes", len(tc)-2), func(t *testing.T) {
val.Data = tc
b, err := Marshal(D{{"vector", val}})
require.NoError(t, err, "marshaling test BSON")
var got struct {
Vector Vector
}
err = Unmarshal(b, &got)
require.ErrorContains(t, err, errInsufficientVectorData.Error())
})
}
})

t.Run("FLOAT32 with padding", func(t *testing.T) {
t.Parallel()

t.Run("Unmarshaling", func(t *testing.T) {
val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{Float32Vector, 3}}}}
b, err := Marshal(val)
require.NoError(t, err, "marshaling test BSON")
var got struct {
Vector Vector
}
err = Unmarshal(b, &got)
require.ErrorContains(t, err, errNonZeroVectorPadding.Error())
})
})

t.Run("INT8 with padding", func(t *testing.T) {
t.Parallel()

t.Run("Unmarshaling", func(t *testing.T) {
val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{Int8Vector, 3}}}}
b, err := Marshal(val)
require.NoError(t, err, "marshaling test BSON")
var got struct {
Vector Vector
}
err = Unmarshal(b, &got)
require.ErrorContains(t, err, errNonZeroVectorPadding.Error())
})
})

t.Run("Padding specified with no vector data PACKED_BIT", func(t *testing.T) {
t.Parallel()

t.Run("Marshaling", func(t *testing.T) {
_, err := NewPackedBitVector(nil, 1)
require.EqualError(t, err, errNonZeroVectorPadding.Error())
})
t.Run("Unmarshaling", func(t *testing.T) {
val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{PackedBitVector, 1}}}}
b, err := Marshal(val)
require.NoError(t, err, "marshaling test BSON")
var got struct {
Vector Vector
}
err = Unmarshal(b, &got)
require.ErrorContains(t, err, errNonZeroVectorPadding.Error())
})
})

t.Run("Exceeding maximum padding PACKED_BIT", func(t *testing.T) {
t.Parallel()

t.Run("Marshaling", func(t *testing.T) {
_, err := NewPackedBitVector(nil, 8)
require.EqualError(t, err, errVectorPaddingTooLarge.Error())
})
t.Run("Unmarshaling", func(t *testing.T) {
val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{PackedBitVector, 8}}}}
b, err := Marshal(val)
require.NoError(t, err, "marshaling test BSON")
var got struct {
Vector Vector
}
err = Unmarshal(b, &got)
require.ErrorContains(t, err, errVectorPaddingTooLarge.Error())
})
})
}

func convertSlice[T int8 | float32 | byte](s []interface{}) []T {
v := make([]T, len(s))
for i, e := range s {
f := math.NaN()
switch val := e.(type) {
case float64:
f = val
case string:
if val == "inf" {
f = math.Inf(0)
} else if val == "-inf" {
f = math.Inf(-1)
}
}
v[i] = T(f)
}
return v
}

func runBsonBinaryVectorTest(t *testing.T, testKey string, test bsonBinaryVectorTestCase) {
testVector := make(map[string]Vector)
switch alias := test.DtypeHex; alias {
case "0x03":
testVector[testKey] = Vector{
dType: Int8Vector,
int8Data: convertSlice[int8](test.Vector),
}
case "0x27":
testVector[testKey] = Vector{
dType: Float32Vector,
float32Data: convertSlice[float32](test.Vector),
}
case "0x10":
testVector[testKey] = Vector{
dType: PackedBitVector,
bitData: convertSlice[byte](test.Vector),
bitPadding: uint8(test.Padding),
}
default:
t.Fatalf("unsupported vector type: %s", alias)
}

testBSON, err := hex.DecodeString(test.CanonicalBson)
require.NoError(t, err, "decoding canonical BSON")

t.Run("Unmarshaling", func(t *testing.T) {
skipCases := map[string]string{
"FLOAT32 with padding": "run in alternative case",
"Overflow Vector INT8": "compile-time restriction",
"Underflow Vector INT8": "compile-time restriction",
"INT8 with padding": "run in alternative case",
"INT8 with float inputs": "compile-time restriction",
"Overflow Vector PACKED_BIT": "compile-time restriction",
"Underflow Vector PACKED_BIT": "compile-time restriction",
"Vector with float values PACKED_BIT": "compile-time restriction",
"Padding specified with no vector data PACKED_BIT": "run in alternative case",
"Exceeding maximum padding PACKED_BIT": "run in alternative case",
"Negative padding PACKED_BIT": "compile-time restriction",
}
if reason, ok := skipCases[test.Description]; ok {
t.Skipf("skip test case %s: %s", test.Description, reason)
}

t.Parallel()

var got map[string]Vector
err := Unmarshal(testBSON, &got)
require.NoError(t, err)
require.Equal(t, testVector, got)
})

t.Run("Marshaling", func(t *testing.T) {
skipCases := map[string]string{
"FLOAT32 with padding": "private padding field",
"Overflow Vector INT8": "compile-time restriction",
"Underflow Vector INT8": "compile-time restriction",
"INT8 with padding": "private padding field",
"INT8 with float inputs": "compile-time restriction",
"Overflow Vector PACKED_BIT": "compile-time restriction",
"Underflow Vector PACKED_BIT": "compile-time restriction",
"Vector with float values PACKED_BIT": "compile-time restriction",
"Padding specified with no vector data PACKED_BIT": "run in alternative case",
"Exceeding maximum padding PACKED_BIT": "run in alternative case",
"Negative padding PACKED_BIT": "compile-time restriction",
}
if reason, ok := skipCases[test.Description]; ok {
t.Skipf("skip test case %s: %s", test.Description, reason)
}

t.Parallel()

got, err := Marshal(testVector)
require.NoError(t, err)
require.Equal(t, testBSON, got)
})
}
38 changes: 10 additions & 28 deletions bson/bson_corpus_spec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,15 +217,15 @@ func normalizeRelaxedDouble(t *testing.T, key string, rEJ string) string {
func bsonToNative(t *testing.T, b []byte, bType, testDesc string) D {
var doc D
err := Unmarshal(b, &doc)
expectNoError(t, err, fmt.Sprintf("%s: decoding %s BSON", testDesc, bType))
require.NoErrorf(t, err, "%s: decoding %s BSON", testDesc, bType)
return doc
}

// nativeToBSON encodes the native Document (doc) into canonical BSON and compares it to the expected
// canonical BSON (cB)
func nativeToBSON(t *testing.T, cB []byte, doc D, testDesc, bType, docSrcDesc string) {
actual, err := Marshal(doc)
expectNoError(t, err, fmt.Sprintf("%s: encoding %s BSON", testDesc, bType))
require.NoErrorf(t, err, "%s: encoding %s BSON", testDesc, bType)

if diff := cmp.Diff(cB, actual); diff != "" {
t.Errorf("%s: 'native_to_bson(%s) = cB' failed (-want, +got):\n-%v\n+%v\n",
Expand Down Expand Up @@ -261,7 +261,7 @@ func jsonToBytes(ej, ejType, testDesc string) ([]byte, error) {
// nativeToJSON encodes the native Document (doc) into an extended JSON string
func nativeToJSON(t *testing.T, ej string, doc D, testDesc, ejType, ejShortName, docSrcDesc string) {
actualEJ, err := MarshalExtJSON(doc, ejType != "relaxed", true)
expectNoError(t, err, fmt.Sprintf("%s: encoding %s extended JSON", testDesc, ejType))
require.NoErrorf(t, err, "%s: encoding %s extended JSON", testDesc, ejType)

if diff := cmp.Diff(ej, string(actualEJ)); diff != "" {
t.Errorf("%s: 'native_to_%s_extended_json(%s) = %s' failed (-want, +got):\n%s\n",
Expand All @@ -288,7 +288,7 @@ func runTest(t *testing.T, file string) {
t.Run(v.Description, func(t *testing.T) {
// get canonical BSON
cB, err := hex.DecodeString(v.CanonicalBson)
expectNoError(t, err, fmt.Sprintf("%s: reading canonical BSON", v.Description))
require.NoErrorf(t, err, "%s: reading canonical BSON", v.Description)

// get canonical extended JSON
var compactEJ bytes.Buffer
Expand Down Expand Up @@ -341,7 +341,7 @@ func runTest(t *testing.T, file string) {
/*** degenerate BSON round-trip tests (if exists) ***/
if v.DegenerateBSON != nil {
dB, err := hex.DecodeString(*v.DegenerateBSON)
expectNoError(t, err, fmt.Sprintf("%s: reading degenerate BSON", v.Description))
require.NoErrorf(t, err, "%s: reading degenerate BSON", v.Description)

doc = bsonToNative(t, dB, "degenerate", v.Description)

Expand Down Expand Up @@ -377,7 +377,7 @@ func runTest(t *testing.T, file string) {
for _, d := range test.DecodeErrors {
t.Run(d.Description, func(t *testing.T) {
b, err := hex.DecodeString(d.Bson)
expectNoError(t, err, d.Description)
require.NoError(t, err, d.Description)

var doc D
err = Unmarshal(b, &doc)
Expand All @@ -392,12 +392,12 @@ func runTest(t *testing.T, file string) {
invalidDBPtr := ok && !utf8.ValidString(dbPtr.DB)

if invalidString || invalidDBPtr {
expectNoError(t, err, d.Description)
require.NoError(t, err, d.Description)
return
}
}

expectError(t, err, fmt.Sprintf("%s: expected decode error", d.Description))
require.Errorf(t, err, "%s: expected decode error", d.Description)
})
}
})
Expand All @@ -418,7 +418,7 @@ func runTest(t *testing.T, file string) {
if strings.Contains(p.Description, "Null") {
_, err = Marshal(doc)
}
expectError(t, err, fmt.Sprintf("%s: expected parse error", p.Description))
require.Errorf(t, err, "%s: expected parse error", p.Description)
default:
t.Errorf("Update test to check for parse errors for type %s", test.BsonType)
t.Fail()
Expand All @@ -431,31 +431,13 @@ func runTest(t *testing.T, file string) {

func Test_BsonCorpus(t *testing.T) {
jsonFiles, err := findJSONFilesInDir(dataDir)
if err != nil {
t.Fatalf("error finding JSON files in %s: %v", dataDir, err)
}
require.NoErrorf(t, err, "error finding JSON files in %s: %v", dataDir, err)

for _, file := range jsonFiles {
runTest(t, file)
}
}

func expectNoError(t *testing.T, err error, desc string) {
if err != nil {
t.Helper()
t.Errorf("%s: Unepexted error: %v", desc, err)
t.FailNow()
}
}

func expectError(t *testing.T, err error, desc string) {
if err == nil {
t.Helper()
t.Errorf("%s: Expected error", desc)
t.FailNow()
}
}

func TestRelaxedUUIDValidation(t *testing.T) {
testCases := []struct {
description string
Expand Down
Loading
Loading