Skip to content

Commit a869411

Browse files
committed
roachtest/vecindex: test recall of vector ann data
This addition to the vecindex roachtest tests the recall of nearest neighbors from the test data provided with each data set. Each test has a configurable set of beam sizes to test and a minimum recall correctness that is acceptable for each beam size. Tests that load multiple prefixes test each prefix. Informs: cockroachdb#154590 Release note: None
1 parent 992e202 commit a869411

File tree

2 files changed

+96
-3
lines changed

2 files changed

+96
-3
lines changed

pkg/cmd/roachtest/tests/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,7 @@ go_library(
285285
"//pkg/sql/pgwire/pgerror",
286286
"//pkg/sql/sem/tree",
287287
"//pkg/sql/ttl/ttlbase",
288+
"//pkg/sql/vecindex/cspann",
288289
"//pkg/sql/vecindex/vecpb",
289290
"//pkg/storage/enginepb",
290291
"//pkg/storage/fs",

pkg/cmd/roachtest/tests/vecindex.go

Lines changed: 95 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"github.com/cockroachdb/cockroach/pkg/roachprod"
2323
"github.com/cockroachdb/cockroach/pkg/roachprod/install"
2424
"github.com/cockroachdb/cockroach/pkg/roachprod/logger"
25+
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/cspann"
2526
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/vecpb"
2627
"github.com/cockroachdb/cockroach/pkg/util/timeutil"
2728
"github.com/cockroachdb/cockroach/pkg/util/vector"
@@ -96,14 +97,18 @@ func getOpClass(metric vecpb.DistanceMetric) string {
9697
}
9798
}
9899

99-
// makeCanonicalKey generates 4-byte primary key from dataset index
100-
func makeCanonicalKey(category int, datasetIdx int) []byte {
101-
key := make([]byte, 0, 6)
100+
// appendCanonicalKey appends the canonical 6-byte primary key for the given
// (category, datasetIdx) pair to key and returns the extended slice: a
// big-endian uint16 category followed by a big-endian uint32 dataset index.
func appendCanonicalKey(key []byte, category int, datasetIdx int) []byte {
	key = binary.BigEndian.AppendUint16(key, uint16(category))
	return binary.BigEndian.AppendUint32(key, uint32(datasetIdx))
}
106105

106+
// makeCanonicalKey generates 4-byte primary key from dataset index
107+
func makeCanonicalKey(category int, datasetIdx int) []byte {
108+
key := make([]byte, 0, 6)
109+
return appendCanonicalKey(key, category, datasetIdx)
110+
}
111+
107112
// backfillState represents the current state of index backfill
108113
type backfillState int32
109114

@@ -276,6 +281,9 @@ func runVectorIndex(ctx context.Context, t test.Test, c cluster.Cluster, opts ve
276281

277282
t.L().Printf("Creating schema and loading data")
278283
testBackfillAndMerge(ctx, t, c, pool, &loader.Data, &opts, metric)
284+
285+
t.L().Printf("Testing recall of loaded data")
286+
testRecall(ctx, t, pool, &loader.Data, &opts, metric)
279287
}
280288

281289
func testBackfillAndMerge(
@@ -396,12 +404,96 @@ func testBackfillAndMerge(
396404
}
397405
// Wait for this batch of loaders
398406
m.Wait()
407+
fileStart += data.Train.Count
399408
}
400409

401410
// Wait for create index to finish
402411
ci.Wait()
403412
}
404413

414+
func testRecall(
415+
ctx context.Context,
416+
t test.Test,
417+
pool *pgxpool.Pool,
418+
data *vecann.Dataset,
419+
opts *vecIndexOptions,
420+
metric vecpb.DistanceMetric,
421+
) {
422+
conn, err := pool.Acquire(ctx)
423+
require.NoError(t, err)
424+
defer conn.Release()
425+
426+
maxResults := 10
427+
operator := getOperatorForMetric(metric)
428+
429+
var categories int
430+
var searchSQL string
431+
var args []any
432+
var hasPrefix bool
433+
if opts.prefixCount > 0 {
434+
categories = opts.prefixCount
435+
searchSQL = fmt.Sprintf(
436+
"SELECT id FROM vecindex_test@vecidx WHERE category = $1 "+
437+
"ORDER BY embedding %s $2 LIMIT %d", operator, maxResults)
438+
args = make([]any, 2)
439+
hasPrefix = true
440+
} else {
441+
categories = 1
442+
searchSQL = fmt.Sprintf(
443+
"SELECT id FROM vecindex_test@vecidx ORDER BY embedding %s $1 LIMIT %d", operator, maxResults)
444+
args = make([]any, 1)
445+
}
446+
447+
results := make([]cspann.KeyBytes, maxResults)
448+
recalls := make([]float64, categories)
449+
primaryKeys := make([]byte, maxResults*6)
450+
truth := make([]cspann.KeyBytes, maxResults)
451+
452+
for i, beamSize := range opts.beamSizes {
453+
recalls = recalls[:0]
454+
minRecall := opts.minRecall[i]
455+
_, err = conn.Exec(ctx, fmt.Sprintf("SET vector_search_beam_size = %d", beamSize))
456+
require.NoError(t, err)
457+
458+
for cat := range categories {
459+
var sumRecall float64
460+
for j := range data.Test.Count {
461+
queryVec := data.Test.At(j)
462+
args = args[:0]
463+
if hasPrefix {
464+
args = append(args, cat)
465+
}
466+
args = append(args, queryVec)
467+
468+
rows, err := conn.Query(ctx, searchSQL, args...)
469+
require.NoError(t, err)
470+
471+
results = results[:0]
472+
for rows.Next() {
473+
var id []byte
474+
err = rows.Scan(&id)
475+
require.NoError(t, err)
476+
results = append(results, id)
477+
}
478+
require.NoError(t, rows.Err())
479+
480+
primaryKeys = primaryKeys[:0]
481+
truth = truth[:0]
482+
for n := range maxResults {
483+
primaryKeys = appendCanonicalKey(primaryKeys, cat, int(data.Neighbors[j][n]))
484+
truth = append(truth, primaryKeys[len(primaryKeys)-6:])
485+
}
486+
487+
sumRecall += vecann.CalculateRecall(results, truth)
488+
}
489+
avgRecall := sumRecall / float64(data.Test.Count)
490+
require.GreaterOrEqualf(t, avgRecall, minRecall, "at beam size %d", beamSize)
491+
recalls = append(recalls, avgRecall)
492+
}
493+
t.L().Printf("beam size=%d : %v", beamSize, recalls)
494+
}
495+
}
496+
405497
func insertVectors(
406498
ctx context.Context,
407499
conn *pgxpool.Conn,

0 commit comments

Comments
 (0)