@@ -22,6 +22,7 @@ import (
2222 "github.com/cockroachdb/cockroach/pkg/roachprod"
2323 "github.com/cockroachdb/cockroach/pkg/roachprod/install"
2424 "github.com/cockroachdb/cockroach/pkg/roachprod/logger"
25+ "github.com/cockroachdb/cockroach/pkg/sql/vecindex/cspann"
2526 "github.com/cockroachdb/cockroach/pkg/sql/vecindex/vecpb"
2627 "github.com/cockroachdb/cockroach/pkg/util/timeutil"
2728 "github.com/cockroachdb/cockroach/pkg/util/vector"
@@ -96,14 +97,18 @@ func getOpClass(metric vecpb.DistanceMetric) string {
9697 }
9798}
9899
99- // makeCanonicalKey generates 4-byte primary key from dataset index
100- func makeCanonicalKey (category int , datasetIdx int ) []byte {
101- key := make ([]byte , 0 , 6 )
// appendCanonicalKey appends the 6-byte canonical primary key for a row to key
// and returns the extended slice. The encoding is the category as a big-endian
// uint16 followed by the dataset index as a big-endian uint32.
//
// Both values are narrowed with plain integer conversions, so a category that
// does not fit in 16 bits or a dataset index that does not fit in 32 bits wraps
// around instead of producing an error.
func appendCanonicalKey(key []byte, category int, datasetIdx int) []byte {
	key = binary.BigEndian.AppendUint16(key, uint16(category))
	key = binary.BigEndian.AppendUint32(key, uint32(datasetIdx))
	return key
}
106105
106+ // makeCanonicalKey generates 4-byte primary key from dataset index
107+ func makeCanonicalKey (category int , datasetIdx int ) []byte {
108+ key := make ([]byte , 0 , 6 )
109+ return appendCanonicalKey (key , category , datasetIdx )
110+ }
111+
// backfillState represents the current state of the index backfill.
type backfillState int32
109114
@@ -276,6 +281,9 @@ func runVectorIndex(ctx context.Context, t test.Test, c cluster.Cluster, opts ve
276281
277282 t .L ().Printf ("Creating schema and loading data" )
278283 testBackfillAndMerge (ctx , t , c , pool , & loader .Data , & opts , metric )
284+
285+ t .L ().Printf ("Testing recall of loaded data" )
286+ testRecall (ctx , t , pool , & loader .Data , & opts , metric )
279287}
280288
281289func testBackfillAndMerge (
@@ -396,12 +404,96 @@ func testBackfillAndMerge(
396404 }
397405 // Wait for this batch of loaders
398406 m .Wait ()
407+ fileStart += data .Train .Count
399408 }
400409
401410 // Wait for create index to finish
402411 ci .Wait ()
403412}
404413
414+ func testRecall (
415+ ctx context.Context ,
416+ t test.Test ,
417+ pool * pgxpool.Pool ,
418+ data * vecann.Dataset ,
419+ opts * vecIndexOptions ,
420+ metric vecpb.DistanceMetric ,
421+ ) {
422+ conn , err := pool .Acquire (ctx )
423+ require .NoError (t , err )
424+ defer conn .Release ()
425+
426+ maxResults := 10
427+ operator := getOperatorForMetric (metric )
428+
429+ var categories int
430+ var searchSQL string
431+ var args []any
432+ var hasPrefix bool
433+ if opts .prefixCount > 0 {
434+ categories = opts .prefixCount
435+ searchSQL = fmt .Sprintf (
436+ "SELECT id FROM vecindex_test@vecidx WHERE category = $1 " +
437+ "ORDER BY embedding %s $2 LIMIT %d" , operator , maxResults )
438+ args = make ([]any , 2 )
439+ hasPrefix = true
440+ } else {
441+ categories = 1
442+ searchSQL = fmt .Sprintf (
443+ "SELECT id FROM vecindex_test@vecidx ORDER BY embedding %s $1 LIMIT %d" , operator , maxResults )
444+ args = make ([]any , 1 )
445+ }
446+
447+ results := make ([]cspann.KeyBytes , maxResults )
448+ recalls := make ([]float64 , categories )
449+ primaryKeys := make ([]byte , maxResults * 6 )
450+ truth := make ([]cspann.KeyBytes , maxResults )
451+
452+ for i , beamSize := range opts .beamSizes {
453+ recalls = recalls [:0 ]
454+ minRecall := opts .minRecall [i ]
455+ _ , err = conn .Exec (ctx , fmt .Sprintf ("SET vector_search_beam_size = %d" , beamSize ))
456+ require .NoError (t , err )
457+
458+ for cat := range categories {
459+ var sumRecall float64
460+ for j := range data .Test .Count {
461+ queryVec := data .Test .At (j )
462+ args = args [:0 ]
463+ if hasPrefix {
464+ args = append (args , cat )
465+ }
466+ args = append (args , queryVec )
467+
468+ rows , err := conn .Query (ctx , searchSQL , args ... )
469+ require .NoError (t , err )
470+
471+ results = results [:0 ]
472+ for rows .Next () {
473+ var id []byte
474+ err = rows .Scan (& id )
475+ require .NoError (t , err )
476+ results = append (results , id )
477+ }
478+ require .NoError (t , rows .Err ())
479+
480+ primaryKeys = primaryKeys [:0 ]
481+ truth = truth [:0 ]
482+ for n := range maxResults {
483+ primaryKeys = appendCanonicalKey (primaryKeys , cat , int (data.Neighbors [j ][n ]))
484+ truth = append (truth , primaryKeys [len (primaryKeys )- 6 :])
485+ }
486+
487+ sumRecall += vecann .CalculateRecall (results , truth )
488+ }
489+ avgRecall := sumRecall / float64 (data .Test .Count )
490+ require .GreaterOrEqualf (t , avgRecall , minRecall , "at beam size %d" , beamSize )
491+ recalls = append (recalls , avgRecall )
492+ }
493+ t .L ().Printf ("beam size=%d : %v" , beamSize , recalls )
494+ }
495+ }
496+
405497func insertVectors (
406498 ctx context.Context ,
407499 conn * pgxpool.Conn ,
0 commit comments