
Commit 81f0490

roachtest/vecindex: add a concurrent reader/writer subtest
This subtest spins up a configurable number of readers and writers to drive vector search load to the database.

Each writer inserts rows from the first train data file in single-row batches until it has inserted all of the rows in that file, at which point it switches into delete mode and starts deleting rows in 10-row batches. When all rows have been deleted, the writer once again becomes an inserter and the process repeats.

Each reader randomly selects a beam size from the sizes configured for the test, then repeatedly searches for random vectors drawn from the dataset's test data. The reader ignores rows inserted by the writer threads to avoid skewing results too heavily. To do this, it searches for more vectors than called for and then filters the output to remove vectors written by the insert workers. When the read worker exits, it validates its recall rate against the expected rate for the number of searches it performed.

For multi-prefix tests, this subtest only reads and writes to the first prefix to ensure the maximum amount of contention.

Fixes: cockroachdb#154590

Release note: None
1 parent a869411 commit 81f0490
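To make the over-fetch-and-filter step concrete, this is the shape of the search each reader issues, sketched for a no-prefix run with maxResults = 10 and two writers (illustrative numbers; the real query is built at runtime in the diff below and uses whichever distance operator getOperatorForMetric selects, for which <-> stands in here):

    SELECT id FROM (
        SELECT excluded, id, embedding
        FROM vecindex_test@vecidx
        ORDER BY embedding <-> $1 LIMIT 30  -- over-fetch: 10 * (2 writers + 1)
    ) WHERE NOT excluded                    -- drop rows written by the insert workers
    ORDER BY embedding <-> $1 LIMIT 10      -- keep the 10 nearest canonical rows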

File tree

1 file changed (+202 −1)

pkg/cmd/roachtest/tests/vecindex.go

Lines changed: 202 additions & 1 deletion
@@ -9,6 +9,7 @@ import (
 	"context"
 	"encoding/binary"
 	"fmt"
+	"math"
 	"math/rand"
 	"strings"
 	"sync/atomic"
@@ -45,6 +46,7 @@ type vecIndexOptions struct {
 	preBatchSz int       // Insert batch size pre-index creation
 	beamSizes  []int     // Beam sizes to verify with
 	minRecall  []float64 // Minimum recall@10 threshold
+	rwSplit    float64   // Fraction of concurrent read/write workers that should be readers
 }
 
 // makeVecIndexTestName generates test name from configuration
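As a worked example of the new rwSplit knob, using the formula from testConcurrentReadsAndWrites below (the pool of 20 workers is an assumed, illustrative value):

    // Illustrative only: opts.workers = 20, opts.rwSplit = 0.9.
    numWriters := max(int(math.Round((1.0-0.9)*float64(20))), 1) // round(2.0) == 2 writers
    numReaders := 20 - numWriters                                // 18 readers

Because of the max(..., 1), even a tiny worker pool always gets at least one writer.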
@@ -167,6 +169,7 @@ func registerVectorIndex(r registry.Registry) {
 			preBatchSz: 100,
 			beamSizes:  []int{8, 16, 32, 64, 128},
 			minRecall:  []float64{0.76, 0.83, 0.88, 0.92, 0.94},
+			rwSplit:    .9,
 		},
 		// Local - no prefix
 		{
@@ -180,6 +183,7 @@ func registerVectorIndex(r registry.Registry) {
 			preBatchSz: 100,
 			beamSizes:  []int{16, 32, 64},
 			minRecall:  []float64{0.94, 0.95, 0.95},
+			rwSplit:    .9,
 		},
 		// Large - no prefix
 		{
@@ -193,6 +197,7 @@ func registerVectorIndex(r registry.Registry) {
 			preBatchSz: 100,
 			beamSizes:  []int{8, 16, 32, 64, 128},
 			minRecall:  []float64{0.64, 0.74, 0.81, 0.87, 0.90},
+			rwSplit:    .9,
 		},
 		// Standard - with prefix
 		{
@@ -206,6 +211,7 @@ func registerVectorIndex(r registry.Registry) {
 			preBatchSz: 100,
 			beamSizes:  []int{8, 16, 32, 64, 128},
 			minRecall:  []float64{0.76, 0.83, 0.88, 0.92, 0.94},
+			rwSplit:    .9,
 		},
 		// Local - with prefix
 		{
@@ -219,6 +225,7 @@ func registerVectorIndex(r registry.Registry) {
 			preBatchSz: 100,
 			beamSizes:  []int{16, 32, 64},
 			minRecall:  []float64{0.94, 0.95, 0.95},
+			rwSplit:    .9,
 		},
 	}
 
@@ -284,6 +291,9 @@ func runVectorIndex(ctx context.Context, t test.Test, c cluster.Cluster, opts ve
 
 	t.L().Printf("Testing recall of loaded data")
 	testRecall(ctx, t, pool, &loader.Data, &opts, metric)
+
+	t.L().Printf("Testing concurrent reads and writes")
+	testConcurrentReadsAndWrites(ctx, t, pool, &loader.Data, &opts, metric)
 }
 
 func testBackfillAndMerge(
@@ -310,7 +320,7 @@ func testBackfillAndMerge(
 			inserted_phase TEXT NOT NULL,
 			worker_id INT NOT NULL,
 			excluded BOOL DEFAULT false,
-			INDEX (excluded),
+			INDEX (excluded, worker_id),
 			PRIMARY KEY (id)
 		)`, data.Dims))
 	require.NoError(t, err)
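Note that the widened index matches the filter in the writers' batched delete, added below:

    DELETE FROM vecindex_test WHERE excluded = true AND worker_id = $1 LIMIT 10

so each worker's 10-row deletes can be located via (excluded, worker_id) rather than by scanning all excluded rows.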
@@ -494,6 +504,197 @@ func testRecall(
 	}
 }
 
+func testConcurrentReadsAndWrites(
+	ctx context.Context,
+	t test.Test,
+	pool *pgxpool.Pool,
+	data *vecann.Dataset,
+	opts *vecIndexOptions,
+	metric vecpb.DistanceMetric,
+) {
+	numWriters := max(int(math.Round((1.0-opts.rwSplit)*float64(opts.workers))), 1)
+	numReaders := opts.workers - numWriters
+
+	t.L().Printf("Running %d write workers and %d read workers for %v", numWriters, numReaders, opts.duration)
+	timer := time.NewTimer(opts.duration)
+	done := make(chan struct{})
+	workers := t.NewGroup()
+
+	// Load the first data file in the dataset and use it (only) for the write workers
+	data.Reset()
+	hasMore, err := data.Next()
+	require.NoError(t, err)
+	require.True(t, hasMore)
+	rowsPerWriter := data.Train.Count / numWriters
+
+	for writer := range numWriters {
+		workers.Go(func(ctx context.Context, l *logger.Logger) error {
+			start := writer * rowsPerWriter
+			writerRows := data.Train.Slice(start, rowsPerWriter)
+			startPKVal := data.TrainCount + start
+
+			conn, err := pool.Acquire(ctx)
+			require.NoError(t, err)
+			defer conn.Release()
+
+			var rowsWritten int
+			var nextRowOffset int
+			var deletingRows bool
+			for {
+				select {
+				case <-done:
+					var writerState strings.Builder
+					fmt.Fprintf(&writerState, "Writer %d exiting. Wrote %d rows. Currently ", writer, rowsWritten)
+					if deletingRows {
+						writerState.WriteString("deleting.")
+					} else {
+						writerState.WriteString("inserting.")
+					}
+					l.Printf(writerState.String())
+					return nil
+				default:
+					if !deletingRows {
+						if err := insertVectors(
+							ctx,
+							conn,
+							writer,
+							1, /* numCats */
+							startPKVal+nextRowOffset,
+							"steady-state",
+							writerRows.Slice(nextRowOffset, 1 /* count */),
+							true, /* excluded */
+						); err != nil {
+							return err
+						}
+						nextRowOffset++
+						rowsWritten++
+					} else {
+						_, err = conn.Exec(
+							ctx,
+							"DELETE FROM vecindex_test WHERE excluded = true AND worker_id = $1 LIMIT 10",
+							writer,
+						)
+						var pgErr *pgconn.PgError
+						if err != nil && errors.As(err, &pgErr) {
+							switch pgErr.Code {
+							case "40001", "40P01":
+								continue
+							}
+						}
+						require.NoError(t, err)
+
+						nextRowOffset += 10
+					}
+					if nextRowOffset >= writerRows.Count {
+						nextRowOffset = 0
+						deletingRows = !deletingRows
+					}
+				}
+			}
+		})
+	}
+
+	maxResults := 10
+	operator := getOperatorForMetric(metric)
+	var queryBuilder strings.Builder
+	queryBuilder.WriteString("SELECT id FROM (SELECT excluded, id, embedding FROM vecindex_test@vecidx")
+	if opts.prefixCount > 0 {
+		// For multi-prefix tests, we only run load against category 0. This simplifies the test code and
+		// maximizes contention between readers and writers.
+		queryBuilder.WriteString(" WHERE category = 0")
+	}
+	// Fetch enough results that we should see most if not all of the canonical rows even if the writer
+	// workers have completely filled in additional copies of the dataset.
+	fmt.Fprintf(&queryBuilder, " ORDER BY embedding %s $1 LIMIT %d)", operator, maxResults*(numWriters+1))
+	queryBuilder.WriteString(" WHERE NOT excluded") // Only look at canonical rows
+	fmt.Fprintf(&queryBuilder, " ORDER BY embedding %s $1 LIMIT %d", operator, maxResults)
+	searchSQL := queryBuilder.String()
+
+	for reader := range numReaders {
+		workers.Go(func(ctx context.Context, l *logger.Logger) error {
+			results := make([]cspann.KeyBytes, maxResults)
+			primaryKeys := make([]byte, maxResults*6)
+			truth := make([]cspann.KeyBytes, maxResults)
+
+			conn, err := pool.Acquire(ctx)
+			if err != nil {
+				return err
+			}
+			defer conn.Release()
+
+			beamIdx := rand.Intn(len(opts.beamSizes))
+			minRecall := opts.minRecall[beamIdx]
+
+			var sumRecall float64
+			var searches int
+
+			_, err = conn.Exec(ctx, fmt.Sprintf("SET vector_search_beam_size = %d", opts.beamSizes[beamIdx]))
+			if err != nil {
+				return err
+			}
+
+			for {
+				select {
+				case <-done:
+					avgRecall := sumRecall / float64(searches)
+					if avgRecall < minRecall {
+						return errors.AssertionFailedf(
+							"Average recall (%f) is less than minimum (%f) for worker %d with beam size %d",
+							avgRecall,
+							minRecall,
+							reader,
+							opts.beamSizes[beamIdx],
+						)
+					}
+					l.Printf(
+						"Reader %d exiting with average recall of %.2f over %d searches with beam size %d",
+						reader,
+						avgRecall*100,
+						searches,
+						opts.beamSizes[beamIdx],
+					)
+					return nil
+				default:
+					queryIdx := rand.Intn(data.Test.Count)
+					queryVec := data.Test.At(queryIdx)
+
+					rows, err := conn.Query(ctx, searchSQL, queryVec)
+					if err != nil {
+						return err
+					}
+
+					results = results[:0]
+					for rows.Next() {
+						var id []byte
+						err = rows.Scan(&id)
+						require.NoError(t, err)
+						results = append(results, id)
+					}
+					if err = rows.Err(); err != nil {
+						return err
+					}
+
+					primaryKeys = primaryKeys[:0]
+					truth = truth[:0]
+					for n := range maxResults {
+						primaryKeys = appendCanonicalKey(primaryKeys, 0, int(data.Neighbors[queryIdx][n]))
+						truth = append(truth, primaryKeys[len(primaryKeys)-6:])
+					}
+
+					sumRecall += vecann.CalculateRecall(results, truth)
+					searches++
+				}
+			}
+		})
+	}
+
+	endTime := <-timer.C
+	t.L().Printf("Shutting down workers at %v", endTime)
+	close(done)
+
+	workers.Wait()
+}
+
 func insertVectors(
 	ctx context.Context,
 	conn *pgxpool.Conn,