Skip to content

Commit 43b1a3a

Browse files
committed
vecindex: add datadriven tests for vecindex
Refactor the C-SPANN datadriven tests so they can be shared by other Store implementations. Add new datadriven test for vecindex that tests recall for several interesting datasets. Epic: CRDB-42943 Release note: None
1 parent f9e3388 commit 43b1a3a

File tree

15 files changed

+1017
-501
lines changed

15 files changed

+1017
-501
lines changed

pkg/cmd/vecbench/mem_provider.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,9 @@ func (m *MemProvider) Search(
177177

178178
// Get result keys.
179179
results := searchSet.PopResults()
180+
if len(results) > memState.maxResults {
181+
results = results[:memState.maxResults]
182+
}
180183
keys = make([]cspann.KeyBytes, len(results))
181184
for i, res := range results {
182185
keys[i] = []byte(res.ChildKey.KeyBytes)

pkg/sql/vecindex/BUILD.bazel

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ go_library(
3232
"//pkg/sql/vecindex/vecpb",
3333
"//pkg/sql/vecindex/vecstore",
3434
"//pkg/sql/vecindex/vecstore/vecstorepb",
35-
"//pkg/util/buildutil",
3635
"//pkg/util/errorutil/unimplemented",
3736
"//pkg/util/log",
3837
"//pkg/util/metric",
@@ -51,7 +50,7 @@ go_test(
5150
"searcher_test.go",
5251
"vecindex_test.go",
5352
],
54-
data = ["//pkg/sql/vecindex/cspann:datasets"],
53+
data = ["//pkg/sql/vecindex/cspann:datasets"] + glob(["testdata/**"]),
5554
embed = [":vecindex"],
5655
exec_properties = select({
5756
"//build/toolchains:is_heavy": {"test.Pool": "large"},
@@ -83,6 +82,7 @@ go_test(
8382
"//pkg/sql/sem/tree",
8483
"//pkg/sql/types",
8584
"//pkg/sql/vecindex/cspann",
85+
"//pkg/sql/vecindex/cspann/commontest",
8686
"//pkg/sql/vecindex/cspann/quantize",
8787
"//pkg/sql/vecindex/cspann/testutils",
8888
"//pkg/sql/vecindex/vecencoding",
@@ -98,6 +98,7 @@ go_test(
9898
"//pkg/util/log",
9999
"//pkg/util/randutil",
100100
"//pkg/util/vector",
101+
"@com_github_cockroachdb_datadriven//:datadriven",
101102
"@com_github_stretchr_testify//require",
102103
],
103104
)

pkg/sql/vecindex/cspann/commontest/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library")
33
go_library(
44
name = "commontest",
55
srcs = [
6+
"indextests.go",
67
"storetests.go",
78
"utils.go",
89
],
@@ -16,6 +17,7 @@ go_library(
1617
"//pkg/sql/vecindex/vecpb",
1718
"//pkg/util/encoding",
1819
"//pkg/util/vector",
20+
"@com_github_cockroachdb_datadriven//:datadriven",
1921
"@com_github_cockroachdb_errors//:errors",
2022
"@com_github_stretchr_testify//require",
2123
"@com_github_stretchr_testify//suite",
Lines changed: 293 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,293 @@
1+
// Copyright 2025 The Cockroach Authors.
2+
//
3+
// Use of this software is governed by the CockroachDB Software License
4+
// included in the /LICENSE file.
5+
6+
package commontest
7+
8+
import (
9+
"context"
10+
"fmt"
11+
"math/rand"
12+
"strings"
13+
"testing"
14+
"time"
15+
16+
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/cspann"
17+
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/cspann/testutils"
18+
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/vecpb"
19+
"github.com/cockroachdb/cockroach/pkg/util/vector"
20+
"github.com/cockroachdb/datadriven"
21+
"github.com/stretchr/testify/require"
22+
)
23+
24+
// TestIndex abstracts operations needed by datadriven index tests that use the
25+
// IndexTestState helper.
26+
type TestIndex interface {
27+
// MakeNewIndex returns a newly constructed index with the given options.
28+
MakeNewIndex(
29+
ctx context.Context, dims int, metric vecpb.DistanceMetric, options *cspann.IndexOptions,
30+
) *cspann.Index
31+
32+
// InsertVectors inserts the given set of vectors into the index. Each vector
33+
// is identified by a unique string key.
34+
InsertVectors(
35+
ctx context.Context, treeKey cspann.TreeKey, keys []string, vectors vector.Set,
36+
)
37+
38+
// SearchVectors searches the index for the query vector, returning the key
39+
// values of the top "k" nearest vectors.
40+
SearchVectors(
41+
ctx context.Context,
42+
treeKey cspann.TreeKey,
43+
queryVector vector.T,
44+
beamSize, topK, rerankMultiplier int,
45+
) []string
46+
}
47+
48+
// IndexTestState is a helper that constructs state used by index tests.
49+
type IndexTestState struct {
50+
T *testing.T
51+
Index *cspann.Index
52+
Dataset vector.Set
53+
TrainKeys []string
54+
55+
testIndex TestIndex
56+
}
57+
58+
// NewIndexTestState constructs a new IndexTestState for the given TestIndex.
59+
func NewIndexTestState(t *testing.T, testIndex TestIndex) *IndexTestState {
60+
return &IndexTestState{
61+
T: t,
62+
testIndex: testIndex,
63+
}
64+
}
65+
66+
// NewIndex runs the "new-index" command.
67+
func (s *IndexTestState) NewIndex(
68+
ctx context.Context, d *datadriven.TestData, treeKey cspann.TreeKey,
69+
) int {
70+
var err error
71+
dims := 0
72+
datasetName := ""
73+
trainCount := 0
74+
distanceMetric := vecpb.L2SquaredDistance
75+
options := cspann.IndexOptions{
76+
RotAlgorithm: vecpb.RotGivens,
77+
IsDeterministic: true,
78+
// Disable stalled op timeout, since it can interfere with stepping tests.
79+
StalledOpTimeout: func() time.Duration { return 0 },
80+
// Disable adaptive search for now, until it's fully supported for stores
81+
// other than the in-memory store.
82+
DisableAdaptiveSearch: true,
83+
}
84+
s.Dataset = vector.Set{}
85+
s.TrainKeys = nil
86+
87+
for _, arg := range d.CmdArgs {
88+
switch arg.Key {
89+
case "dataset":
90+
require.Len(s.T, arg.Vals, 1)
91+
datasetName = arg.Vals[0]
92+
93+
case "train-count":
94+
trainCount = testutils.ParseDataDrivenInt(s.T, arg)
95+
96+
case "distance-metric":
97+
require.Len(s.T, arg.Vals, 1)
98+
switch strings.ToLower(arg.Vals[0]) {
99+
case "innerproduct":
100+
distanceMetric = vecpb.InnerProductDistance
101+
case "cosine":
102+
distanceMetric = vecpb.CosineDistance
103+
}
104+
require.NoError(s.T, err)
105+
106+
case "rot-algorithm":
107+
require.Len(s.T, arg.Vals, 1)
108+
switch strings.ToLower(arg.Vals[0]) {
109+
case "matrix":
110+
options.RotAlgorithm = vecpb.RotMatrix
111+
case "givens":
112+
options.RotAlgorithm = vecpb.RotGivens
113+
case "none":
114+
options.RotAlgorithm = vecpb.RotNone
115+
default:
116+
require.Failf(s.T, "unrecognized rot algorithm %s", arg.Vals[0])
117+
}
118+
119+
case "min-partition-size":
120+
options.MinPartitionSize = testutils.ParseDataDrivenInt(s.T, arg)
121+
122+
case "max-partition-size":
123+
options.MaxPartitionSize = testutils.ParseDataDrivenInt(s.T, arg)
124+
125+
case "quality-samples":
126+
options.QualitySamples = testutils.ParseDataDrivenInt(s.T, arg)
127+
128+
case "dims":
129+
dims = testutils.ParseDataDrivenInt(s.T, arg)
130+
131+
case "beam-size":
132+
options.BaseBeamSize = testutils.ParseDataDrivenInt(s.T, arg)
133+
134+
case "read-only":
135+
options.ReadOnly = testutils.ParseDataDrivenFlag(s.T, arg)
136+
}
137+
}
138+
139+
if datasetName != "" {
140+
dataset := testutils.LoadDataset(s.T, datasetName)
141+
142+
if dims != 0 {
143+
// Trim dataset dimensions to make test run faster.
144+
s.Dataset = vector.MakeSet(min(dims, dataset.Dims))
145+
dims = s.Dataset.Dims
146+
for i := range dataset.Count {
147+
s.Dataset.Add(dataset.At(i)[:dims])
148+
}
149+
} else {
150+
s.Dataset = dataset
151+
dims = s.Dataset.Dims
152+
}
153+
} else if dims == 0 {
154+
// Default to 2 dimensions if not specified.
155+
dims = 2
156+
}
157+
158+
s.Index = s.testIndex.MakeNewIndex(ctx, dims, distanceMetric, &options)
159+
160+
if trainCount != 0 {
161+
// Insert train vectors into the index.
162+
vectors := s.Dataset.Slice(0, trainCount)
163+
s.TrainKeys = make([]string, 0, trainCount)
164+
for i := range trainCount {
165+
s.TrainKeys = append(s.TrainKeys, fmt.Sprintf("vec%d", i))
166+
}
167+
s.testIndex.InsertVectors(ctx, treeKey, s.TrainKeys, vectors)
168+
}
169+
170+
return trainCount
171+
}
172+
173+
// Insert runs the "insert" command.
174+
func (s *IndexTestState) Insert(
175+
ctx context.Context, d *datadriven.TestData, treeKey cspann.TreeKey,
176+
) int {
177+
var keys []string
178+
vectors := vector.MakeSet(s.Index.Quantizer().GetDims())
179+
180+
// Parse vectors.
181+
for _, line := range strings.Split(d.Input, "\n") {
182+
line = strings.TrimSpace(line)
183+
if len(line) == 0 {
184+
continue
185+
}
186+
parts := strings.Split(line, ":")
187+
require.Len(s.T, parts, 2)
188+
189+
vec, err := vector.ParseVector(parts[1])
190+
require.NoError(s.T, err)
191+
vectors.Add(vec)
192+
keys = append(keys, parts[0])
193+
}
194+
195+
s.testIndex.InsertVectors(ctx, treeKey, keys, vectors)
196+
197+
return vectors.Count
198+
}
199+
200+
// FormatTree runs the "format-tree" command.
201+
func (s *IndexTestState) FormatTree(
202+
ctx context.Context, d *datadriven.TestData, treeKey cspann.TreeKey,
203+
) string {
204+
var tree string
205+
RunTransaction(ctx, s.T, s.Index.Store(), func(txn cspann.Txn) {
206+
rootPartitionKey := cspann.RootKey
207+
for _, arg := range d.CmdArgs {
208+
switch arg.Key {
209+
case "root":
210+
rootPartitionKey = cspann.PartitionKey(testutils.ParseDataDrivenInt(s.T, arg))
211+
}
212+
}
213+
214+
var err error
215+
options := cspann.FormatOptions{PrimaryKeyStrings: true, RootPartitionKey: rootPartitionKey}
216+
tree, err = s.Index.Format(ctx, treeKey, options)
217+
require.NoError(s.T, err)
218+
})
219+
return tree
220+
}
221+
222+
// Recall runs the "rcall" command.
223+
func (s *IndexTestState) Recall(
224+
ctx context.Context, d *datadriven.TestData, treeKey cspann.TreeKey,
225+
) (topK, numSamples int, recall float64) {
226+
topK = 1
227+
numSamples = 50
228+
beamSize := 1
229+
rerankMultiplier := -1
230+
var samples []int
231+
seed := 42
232+
for _, arg := range d.CmdArgs {
233+
switch arg.Key {
234+
case "use-dataset":
235+
// Use single designated sample.
236+
offset := testutils.ParseDataDrivenInt(s.T, arg)
237+
numSamples = 1
238+
samples = []int{offset}
239+
240+
case "samples":
241+
numSamples = testutils.ParseDataDrivenInt(s.T, arg)
242+
243+
case "seed":
244+
seed = testutils.ParseDataDrivenInt(s.T, arg)
245+
246+
case "beam-size":
247+
beamSize = testutils.ParseDataDrivenInt(s.T, arg)
248+
249+
case "topk":
250+
topK = testutils.ParseDataDrivenInt(s.T, arg)
251+
252+
case "rerank-multiplier":
253+
rerankMultiplier = testutils.ParseDataDrivenInt(s.T, arg)
254+
}
255+
}
256+
257+
dataVectors := s.Dataset.Slice(0, len(s.TrainKeys))
258+
259+
// Construct random list of offsets into the test vectors in the dataset (i.e.
260+
// all vectors not part of the training set).
261+
if samples == nil {
262+
// Shuffle the remaining dataset vectors.
263+
rng := rand.New(rand.NewSource(int64(seed)))
264+
remaining := make([]int, s.Dataset.Count-len(s.TrainKeys))
265+
for i := range remaining {
266+
remaining[i] = len(s.TrainKeys) + i
267+
}
268+
rng.Shuffle(len(remaining), func(i, j int) {
269+
remaining[i], remaining[j] = remaining[j], remaining[i]
270+
})
271+
272+
// Pick numSamples randomly from the remaining set
273+
samples = make([]int, numSamples)
274+
copy(samples, remaining[:numSamples])
275+
}
276+
277+
// Search for sampled dataset vectors within a transaction.
278+
var sumRecall float64
279+
for i := range samples {
280+
// Calculate truth set for the vector.
281+
queryVector := s.Dataset.At(samples[i])
282+
283+
truth := testutils.CalculateTruth(
284+
topK, s.Index.Quantizer().GetDistanceMetric(), queryVector, dataVectors, s.TrainKeys)
285+
286+
prediction := s.testIndex.SearchVectors(
287+
ctx, treeKey, queryVector, beamSize, topK, rerankMultiplier)
288+
289+
sumRecall += testutils.CalculateRecall(prediction, truth)
290+
}
291+
292+
return topK, numSamples, sumRecall / float64(numSamples) * 100
293+
}

0 commit comments

Comments
 (0)