Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
259 changes: 259 additions & 0 deletions bson/bson_binary_vector_spec_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
// Copyright (C) MongoDB, Inc. 2024-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bson

import (
"encoding/hex"
"encoding/json"
"fmt"
"math"
"os"
"path"
"testing"

"go.mongodb.org/mongo-driver/v2/internal/require"
)

// bsonBinaryVectorDir is the relative path of the directory that is scanned for the BSON
// binary vector spec test JSON files.
const bsonBinaryVectorDir = "../testdata/bson-binary-vector/"

// bsonBinaryVectorTests is the top-level shape decoded from one spec test JSON file.
type bsonBinaryVectorTests struct {
// Description labels the whole file; it is used as the name of the file-level subtest.
Description string `json:"description"`
// TestKey is the document key under which every test vector in the file is stored.
TestKey string `json:"test_key"`
// Tests holds the individual cases of the file.
Tests []bsonBinaryVectorTestCase `json:"tests"`
}

// bsonBinaryVectorTestCase is a single entry of the "tests" array in a spec test JSON file.
type bsonBinaryVectorTestCase struct {
	// Description names the case. It is used as the subtest name and as the lookup key of
	// the skip tables in runBsonBinaryVectorTest.
	Description string `json:"description"`
	// Valid is decoded from the file.
	// NOTE(review): no visible runner code reads this field; skipping is driven by
	// Description instead - confirm that is intended.
	Valid bool `json:"valid"`
	// Vector holds the raw JSON elements of the vector. Numbers arrive as float64 and the
	// strings "inf" / "-inf" stand for the infinities (see convertSlice).
	Vector []any `json:"vector"`
	// DtypeHex selects the vector data type; "0x03", "0x27" and "0x10" are recognized.
	DtypeHex string `json:"dtype_hex"`
	// DtypeAlias is decoded from the file.
	// NOTE(review): not read by the visible runner code.
	DtypeAlias string `json:"dtype_alias"`
	// Padding is copied into the bit padding of packed-bit vectors.
	Padding int `json:"padding"`
	// CanonicalBson is the hex encoding of the expected BSON document.
	CanonicalBson string `json:"canonical_bson"`
}

// TestBsonBinaryVectorSpec runs every JSON spec file found in bsonBinaryVectorDir through
// runBsonBinaryVectorTest, then exercises the invalid-padding cases that the file-driven
// runner skips ("run in alternative case").
func TestBsonBinaryVectorSpec(t *testing.T) {
	t.Parallel()

	jsonFiles, err := findJSONFilesInDir(bsonBinaryVectorDir)
	require.NoErrorf(t, err, "error finding JSON files in %s: %v", bsonBinaryVectorDir, err)

	for _, name := range jsonFiles {
		specPath := path.Join(bsonBinaryVectorDir, name)
		raw, err := os.ReadFile(specPath)
		require.NoErrorf(t, err, "reading test file %s", specPath)

		// Declared inside the loop so each parallel subtest closes over its own suite.
		var suite bsonBinaryVectorTests
		err = json.Unmarshal(raw, &suite)
		require.NoErrorf(t, err, "parsing test file %s", specPath)

		t.Run(suite.Description, func(t *testing.T) {
			t.Parallel()

			for i := range suite.Tests {
				// Copy the case so the parallel closure does not share a loop variable.
				tc := suite.Tests[i]
				t.Run(tc.Description, func(t *testing.T) {
					t.Parallel()

					runBsonBinaryVectorTest(t, suite.TestKey, tc)
				})
			}
		})
	}

	// requireUnmarshalError marshals a document whose "vector" binary holds only the two
	// header bytes (dtype, padding), unmarshals it into a Vector field, and requires the
	// resulting error to contain wantMsg.
	requireUnmarshalError := func(t *testing.T, dtype, padding byte, wantMsg string) {
		doc := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{dtype, padding}}}}
		encoded, err := Marshal(doc)
		require.NoError(t, err, "marshaling test BSON")

		var got struct {
			Vector Vector
		}
		err = Unmarshal(encoded, &got)
		require.ErrorContains(t, err, wantMsg)
	}

	t.Run("FLOAT32 with padding", func(t *testing.T) {
		t.Parallel()

		t.Run("Unmarshaling", func(t *testing.T) {
			requireUnmarshalError(t, Float32Vector, 3, errNonZeroVectorPadding.Error())
		})
	})

	t.Run("INT8 with padding", func(t *testing.T) {
		t.Parallel()

		t.Run("Unmarshaling", func(t *testing.T) {
			requireUnmarshalError(t, Int8Vector, 3, errNonZeroVectorPadding.Error())
		})
	})

	t.Run("Padding specified with no vector data PACKED_BIT", func(t *testing.T) {
		t.Parallel()

		t.Run("Marshaling", func(t *testing.T) {
			_, err := NewPackedBitVector(nil, 1)
			require.EqualError(t, err, errNonZeroVectorPadding.Error())
		})
		t.Run("Unmarshaling", func(t *testing.T) {
			requireUnmarshalError(t, PackedBitVector, 1, errNonZeroVectorPadding.Error())
		})
	})

	t.Run("Exceeding maximum padding PACKED_BIT", func(t *testing.T) {
		t.Parallel()

		t.Run("Marshaling", func(t *testing.T) {
			_, err := NewPackedBitVector(nil, 8)
			require.EqualError(t, err, errVectorPaddingTooLarge.Error())
		})
		t.Run("Unmarshaling", func(t *testing.T) {
			requireUnmarshalError(t, PackedBitVector, 8, errVectorPaddingTooLarge.Error())
		})
	})
}

// TODO: This test may be added into the spec tests.
func TestFloat32VectorWithInsufficientData(t *testing.T) {
t.Parallel()

val := Binary{Subtype: TypeBinaryVector}

for _, tc := range [][]byte{
{Float32Vector, 0, 42},
{Float32Vector, 0, 42, 42},
{Float32Vector, 0, 42, 42, 42},

{Float32Vector, 0, 42, 42, 42, 42, 42},
{Float32Vector, 0, 42, 42, 42, 42, 42, 42},
{Float32Vector, 0, 42, 42, 42, 42, 42, 42, 42},
} {
t.Run(fmt.Sprintf("marshaling %d bytes", len(tc)-2), func(t *testing.T) {
val.Data = tc
b, err := Marshal(D{{"vector", val}})
require.NoError(t, err, "marshaling test BSON")
var got struct {
Vector Vector
}
err = Unmarshal(b, &got)
require.ErrorContains(t, err, errInsufficientVectorData.Error())
})
}
}

// convertSlice converts the decoded JSON elements in s to a slice of T with the same length.
//
// A float64 element is converted directly. The strings "inf" and "-inf" map to positive and
// negative infinity. Any other element (another string or another dynamic type) is treated
// as NaN before the conversion to T. The result of converting NaN, an infinity, or an
// out-of-range value to an integer T is whatever the Go numeric conversion yields.
func convertSlice[T int8 | float32 | byte](s []interface{}) []T {
	out := make([]T, len(s))
	for idx := range s {
		num := math.NaN()
		if fv, isFloat := s[idx].(float64); isFloat {
			num = fv
		} else if sv, isString := s[idx].(string); isString {
			switch sv {
			case "inf":
				num = math.Inf(1)
			case "-inf":
				num = math.Inf(-1)
			}
		}
		out[idx] = T(num)
	}
	return out
}

// runBsonBinaryVectorTest runs one spec case in both directions.
//
// It builds the expected Vector from test.DtypeHex, test.Vector and test.Padding, stores it
// under testKey, and decodes test.CanonicalBson from hex. The "Unmarshaling" subtest
// requires that decoding the canonical BSON yields the expected map; the "Marshaling"
// subtest requires that encoding the expected map yields the canonical BSON. Cases whose
// description appears in a subtest's skip table are skipped with the recorded reason. An
// unrecognized dtype_hex fails the test immediately.
func runBsonBinaryVectorTest(t *testing.T, testKey string, test bsonBinaryVectorTestCase) {
	var want Vector
	switch test.DtypeHex {
	case "0x03":
		want = Vector{
			dType:    Int8Vector,
			int8Data: convertSlice[int8](test.Vector),
		}
	case "0x27":
		want = Vector{
			dType:       Float32Vector,
			float32Data: convertSlice[float32](test.Vector),
		}
	case "0x10":
		want = Vector{
			dType:      PackedBitVector,
			bitData:    convertSlice[byte](test.Vector),
			bitPadding: uint8(test.Padding),
		}
	default:
		t.Fatalf("unsupported vector type: %s", test.DtypeHex)
	}
	testVector := map[string]Vector{testKey: want}

	testBSON, err := hex.DecodeString(test.CanonicalBson)
	require.NoError(t, err, "decoding canonical BSON")

	// Case descriptions that are not run through the generic round trip, with the reason
	// reported when skipping. The two tables differ only in the reason given for the
	// "... with padding" cases.
	unmarshalSkips := map[string]string{
		"FLOAT32 with padding":                             "run in alternative case",
		"Overflow Vector INT8":                             "compile-time restriction",
		"Underflow Vector INT8":                            "compile-time restriction",
		"INT8 with padding":                                "run in alternative case",
		"INT8 with float inputs":                           "compile-time restriction",
		"Overflow Vector PACKED_BIT":                       "compile-time restriction",
		"Underflow Vector PACKED_BIT":                      "compile-time restriction",
		"Vector with float values PACKED_BIT":              "compile-time restriction",
		"Padding specified with no vector data PACKED_BIT": "run in alternative case",
		"Exceeding maximum padding PACKED_BIT":             "run in alternative case",
		"Negative padding PACKED_BIT":                      "compile-time restriction",
	}
	marshalSkips := map[string]string{
		"FLOAT32 with padding":                             "private padding field",
		"Overflow Vector INT8":                             "compile-time restriction",
		"Underflow Vector INT8":                            "compile-time restriction",
		"INT8 with padding":                                "private padding field",
		"INT8 with float inputs":                           "compile-time restriction",
		"Overflow Vector PACKED_BIT":                       "compile-time restriction",
		"Underflow Vector PACKED_BIT":                      "compile-time restriction",
		"Vector with float values PACKED_BIT":              "compile-time restriction",
		"Padding specified with no vector data PACKED_BIT": "run in alternative case",
		"Exceeding maximum padding PACKED_BIT":             "run in alternative case",
		"Negative padding PACKED_BIT":                      "compile-time restriction",
	}

	// skipIfListed skips the calling subtest when the case description is a key of reasons.
	skipIfListed := func(t *testing.T, reasons map[string]string) {
		t.Helper()
		reason, listed := reasons[test.Description]
		if !listed {
			return
		}
		t.Skipf("skip test case %s: %s", test.Description, reason)
	}

	t.Run("Unmarshaling", func(t *testing.T) {
		skipIfListed(t, unmarshalSkips)

		t.Parallel()

		var decoded map[string]Vector
		err := Unmarshal(testBSON, &decoded)
		require.NoError(t, err)
		require.Equal(t, testVector, decoded)
	})

	t.Run("Marshaling", func(t *testing.T) {
		skipIfListed(t, marshalSkips)

		t.Parallel()

		encoded, err := Marshal(testVector)
		require.NoError(t, err)
		require.Equal(t, testBSON, encoded)
	})
}
38 changes: 10 additions & 28 deletions bson/bson_corpus_spec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,15 +217,15 @@ func normalizeRelaxedDouble(t *testing.T, key string, rEJ string) string {
// bsonToNative decodes the BSON bytes b into a native D document and returns it.
//
// A decoding error is reported through require.NoErrorf; bType and testDesc are only used to
// label that failure message. The error is asserted once: the earlier expectNoError call
// duplicated this check and formatted its message with fmt.Sprintf even on success.
func bsonToNative(t *testing.T, b []byte, bType, testDesc string) D {
	var doc D
	err := Unmarshal(b, &doc)
	require.NoErrorf(t, err, "%s: decoding %s BSON", testDesc, bType)
	return doc
}

// nativeToBSON encodes the native Document (doc) into canonical BSON and compares it to the expected
// canonical BSON (cB)
func nativeToBSON(t *testing.T, cB []byte, doc D, testDesc, bType, docSrcDesc string) {
actual, err := Marshal(doc)
expectNoError(t, err, fmt.Sprintf("%s: encoding %s BSON", testDesc, bType))
require.NoErrorf(t, err, "%s: encoding %s BSON", testDesc, bType)

if diff := cmp.Diff(cB, actual); diff != "" {
t.Errorf("%s: 'native_to_bson(%s) = cB' failed (-want, +got):\n-%v\n+%v\n",
Expand Down Expand Up @@ -261,7 +261,7 @@ func jsonToBytes(ej, ejType, testDesc string) ([]byte, error) {
// nativeToJSON encodes the native Document (doc) into an extended JSON string
func nativeToJSON(t *testing.T, ej string, doc D, testDesc, ejType, ejShortName, docSrcDesc string) {
actualEJ, err := MarshalExtJSON(doc, ejType != "relaxed", true)
expectNoError(t, err, fmt.Sprintf("%s: encoding %s extended JSON", testDesc, ejType))
require.NoErrorf(t, err, "%s: encoding %s extended JSON", testDesc, ejType)

if diff := cmp.Diff(ej, string(actualEJ)); diff != "" {
t.Errorf("%s: 'native_to_%s_extended_json(%s) = %s' failed (-want, +got):\n%s\n",
Expand All @@ -288,7 +288,7 @@ func runTest(t *testing.T, file string) {
t.Run(v.Description, func(t *testing.T) {
// get canonical BSON
cB, err := hex.DecodeString(v.CanonicalBson)
expectNoError(t, err, fmt.Sprintf("%s: reading canonical BSON", v.Description))
require.NoErrorf(t, err, "%s: reading canonical BSON", v.Description)

// get canonical extended JSON
var compactEJ bytes.Buffer
Expand Down Expand Up @@ -341,7 +341,7 @@ func runTest(t *testing.T, file string) {
/*** degenerate BSON round-trip tests (if exists) ***/
if v.DegenerateBSON != nil {
dB, err := hex.DecodeString(*v.DegenerateBSON)
expectNoError(t, err, fmt.Sprintf("%s: reading degenerate BSON", v.Description))
require.NoErrorf(t, err, "%s: reading degenerate BSON", v.Description)

doc = bsonToNative(t, dB, "degenerate", v.Description)

Expand Down Expand Up @@ -377,7 +377,7 @@ func runTest(t *testing.T, file string) {
for _, d := range test.DecodeErrors {
t.Run(d.Description, func(t *testing.T) {
b, err := hex.DecodeString(d.Bson)
expectNoError(t, err, d.Description)
require.NoError(t, err, d.Description)

var doc D
err = Unmarshal(b, &doc)
Expand All @@ -392,12 +392,12 @@ func runTest(t *testing.T, file string) {
invalidDBPtr := ok && !utf8.ValidString(dbPtr.DB)

if invalidString || invalidDBPtr {
expectNoError(t, err, d.Description)
require.NoError(t, err, d.Description)
return
}
}

expectError(t, err, fmt.Sprintf("%s: expected decode error", d.Description))
require.Errorf(t, err, "%s: expected decode error", d.Description)
})
}
})
Expand All @@ -418,7 +418,7 @@ func runTest(t *testing.T, file string) {
if strings.Contains(p.Description, "Null") {
_, err = Marshal(doc)
}
expectError(t, err, fmt.Sprintf("%s: expected parse error", p.Description))
require.Errorf(t, err, "%s: expected parse error", p.Description)
default:
t.Errorf("Update test to check for parse errors for type %s", test.BsonType)
t.Fail()
Expand All @@ -431,31 +431,13 @@ func runTest(t *testing.T, file string) {

// Test_BsonCorpus calls runTest for every JSON file that findJSONFilesInDir reports for
// dataDir. Failure to list the files is reported through require.NoErrorf; the manual
// err != nil / t.Fatalf check that preceded it tested the same condition and was redundant.
func Test_BsonCorpus(t *testing.T) {
	jsonFiles, err := findJSONFilesInDir(dataDir)
	require.NoErrorf(t, err, "error finding JSON files in %s: %v", dataDir, err)

	for _, file := range jsonFiles {
		runTest(t, file)
	}
}

// expectNoError stops the test immediately when err is non-nil, reporting desc together
// with the error. It does nothing when err is nil.
func expectNoError(t *testing.T, err error, desc string) {
	// Mark as helper so the failure is attributed to the caller's line.
	t.Helper()
	if err != nil {
		t.Fatalf("%s: Unexpected error: %v", desc, err)
	}
}

// expectError stops the test immediately when err is nil, reporting desc. It does nothing
// when err is non-nil.
func expectError(t *testing.T, err error, desc string) {
	if err != nil {
		return
	}
	t.Helper()
	t.Fatalf("%s: Expected error", desc)
}

func TestRelaxedUUIDValidation(t *testing.T) {
testCases := []struct {
description string
Expand Down
Loading
Loading