99 "context"
1010 "encoding/binary"
1111 "fmt"
12+ "math"
1213 "math/rand"
1314 "strings"
1415 "sync/atomic"
@@ -45,6 +46,7 @@ type vecIndexOptions struct {
4546 preBatchSz int // Insert batch size pre-index creation
4647 beamSizes []int // Beamsizes to verify with
4748 minRecall []float64 // Minimum recall@10 threshold
49+ rwSplit float64 // Percentage of concurrent read/write workers that should be readers
4850}
4951
5052// makeVecIndexTestName generates test name from configuration
@@ -167,6 +169,7 @@ func registerVectorIndex(r registry.Registry) {
167169 preBatchSz : 100 ,
168170 beamSizes : []int {8 , 16 , 32 , 64 , 128 },
169171 minRecall : []float64 {0.76 , 0.83 , 0.88 , 0.92 , 0.94 },
172+ rwSplit : .9 ,
170173 },
171174 // Local - no prefix
172175 {
@@ -180,6 +183,7 @@ func registerVectorIndex(r registry.Registry) {
180183 preBatchSz : 100 ,
181184 beamSizes : []int {16 , 32 , 64 },
182185 minRecall : []float64 {0.94 , 0.95 , 0.95 },
186+ rwSplit : .9 ,
183187 },
184188 // Large - no prefix
185189 {
@@ -193,6 +197,7 @@ func registerVectorIndex(r registry.Registry) {
193197 preBatchSz : 100 ,
194198 beamSizes : []int {8 , 16 , 32 , 64 , 128 },
195199 minRecall : []float64 {0.64 , 0.74 , 0.81 , 0.87 , 0.90 },
200+ rwSplit : .9 ,
196201 },
197202 // Standard - with prefix
198203 {
@@ -206,6 +211,7 @@ func registerVectorIndex(r registry.Registry) {
206211 preBatchSz : 100 ,
207212 beamSizes : []int {8 , 16 , 32 , 64 , 128 },
208213 minRecall : []float64 {0.76 , 0.83 , 0.88 , 0.92 , 0.94 },
214+ rwSplit : .9 ,
209215 },
210216 // Local - with prefix
211217 {
@@ -219,6 +225,7 @@ func registerVectorIndex(r registry.Registry) {
219225 preBatchSz : 100 ,
220226 beamSizes : []int {16 , 32 , 64 },
221227 minRecall : []float64 {0.94 , 0.95 , 0.95 },
228+ rwSplit : .9 ,
222229 },
223230 }
224231
@@ -284,6 +291,9 @@ func runVectorIndex(ctx context.Context, t test.Test, c cluster.Cluster, opts ve
284291
285292 t .L ().Printf ("Testing recall of loaded data" )
286293 testRecall (ctx , t , pool , & loader .Data , & opts , metric )
294+
295+ t .L ().Printf ("Testing concurrent reads and writes" )
296+ testConcurrentReadsAndWrites (ctx , t , pool , & loader .Data , & opts , metric )
287297}
288298
289299func testBackfillAndMerge (
@@ -310,7 +320,7 @@ func testBackfillAndMerge(
310320 inserted_phase TEXT NOT NULL,
311321 worker_id INT NOT NULL,
312322 excluded BOOL DEFAULT false,
313- INDEX (excluded),
323+ INDEX (excluded, worker_id ),
314324 PRIMARY KEY (id)
315325 )` , data .Dims ))
316326 require .NoError (t , err )
@@ -494,6 +504,197 @@ func testRecall(
494504 }
495505}
496506
507+ func testConcurrentReadsAndWrites (
508+ ctx context.Context ,
509+ t test.Test ,
510+ pool * pgxpool.Pool ,
511+ data * vecann.Dataset ,
512+ opts * vecIndexOptions ,
513+ metric vecpb.DistanceMetric ,
514+ ) {
515+ numWriters := max (int (math .Round ((1.0 - opts .rwSplit )* float64 (opts .workers ))), 1 )
516+ numReaders := opts .workers - numWriters
517+
518+ t .L ().Printf ("Running %d write workers and %d read workers for %v" , numWriters , numReaders , opts .duration )
519+ timer := time .NewTimer (opts .duration )
520+ done := make (chan struct {})
521+ workers := t .NewGroup ()
522+
523+ // Load the first data file in the dataset and use it (only) for the write workers
524+ data .Reset ()
525+ hasMore , err := data .Next ()
526+ require .NoError (t , err )
527+ require .True (t , hasMore )
528+ rowsPerWriter := data .Train .Count / numWriters
529+
530+ for writer := range numWriters {
531+ workers .Go (func (ctx context.Context , l * logger.Logger ) error {
532+ start := writer * rowsPerWriter
533+ writerRows := data .Train .Slice (start , rowsPerWriter )
534+ startPKVal := data .TrainCount + start
535+
536+ conn , err := pool .Acquire (ctx )
537+ require .NoError (t , err )
538+ defer conn .Release ()
539+
540+ var rowsWritten int
541+ var nextRowOffset int
542+ var deletingRows bool
543+ for {
544+ select {
545+ case <- done :
546+ var writerState strings.Builder
547+ fmt .Fprintf (& writerState , "Writer %d exiting. Wrote %d rows. Currently " , writer , rowsWritten )
548+ if deletingRows {
549+ writerState .WriteString ("deleting." )
550+ } else {
551+ writerState .WriteString ("inserting." )
552+ }
553+ l .Printf (writerState .String ())
554+ return nil
555+ default :
556+ if ! deletingRows {
557+ if err := insertVectors (
558+ ctx ,
559+ conn ,
560+ writer ,
561+ 1 , /* numCats */
562+ startPKVal + nextRowOffset ,
563+ "steady-state" ,
564+ writerRows .Slice (nextRowOffset , 1 /* count */ ),
565+ true , /* excluded */
566+ ); err != nil {
567+ return err
568+ }
569+ nextRowOffset ++
570+ rowsWritten ++
571+ } else {
572+ _ , err = conn .Exec (
573+ ctx ,
574+ "DELETE FROM vecindex_test WHERE excluded = true AND worker_id = $1 LIMIT 10" ,
575+ writer ,
576+ )
577+ var pgErr * pgconn.PgError
578+ if err != nil && errors .As (err , & pgErr ) {
579+ switch pgErr .Code {
580+ case "40001" , "40P01" :
581+ continue
582+ }
583+ }
584+ require .NoError (t , err )
585+
586+ nextRowOffset += 10
587+ }
588+ if nextRowOffset >= writerRows .Count {
589+ nextRowOffset = 0
590+ deletingRows = ! deletingRows
591+ }
592+ }
593+ }
594+ })
595+ }
596+
597+ maxResults := 10
598+ operator := getOperatorForMetric (metric )
599+ var queryBuilder strings.Builder
600+ queryBuilder .WriteString ("SELECT id FROM (SELECT excluded, id, embedding FROM vecindex_test@vecidx" )
601+ if opts .prefixCount > 0 {
602+ // For multi-prefix tests, we only run load against category 0. This simplifies the test code and
603+ // maximizes contention between readers and writers.
604+ queryBuilder .WriteString (" WHERE category = 0" )
605+ }
606+ // Fetch enough results that we should see most if not all of the canonical rows even if the writer
607+ // workers have completely filled in additional copies of the dataset.
608+ fmt .Fprintf (& queryBuilder , " ORDER BY embedding %s $1 LIMIT %d)" , operator , maxResults * (numWriters + 1 ))
609+ queryBuilder .WriteString (" WHERE NOT excluded" ) // Only look at canonical rows
610+ fmt .Fprintf (& queryBuilder , " ORDER BY embedding %s $1 LIMIT %d" , operator , maxResults )
611+ searchSQL := queryBuilder .String ()
612+
613+ for reader := range numReaders {
614+ workers .Go (func (ctx context.Context , l * logger.Logger ) error {
615+ results := make ([]cspann.KeyBytes , maxResults )
616+ primaryKeys := make ([]byte , maxResults * 6 )
617+ truth := make ([]cspann.KeyBytes , maxResults )
618+
619+ conn , err := pool .Acquire (ctx )
620+ if err != nil {
621+ return err
622+ }
623+ defer conn .Release ()
624+
625+ beamIdx := rand .Intn (len (opts .beamSizes ))
626+ minRecall := opts .minRecall [beamIdx ]
627+
628+ var sumRecall float64
629+ var searches int
630+
631+ _ , err = conn .Exec (ctx , fmt .Sprintf ("SET vector_search_beam_size = %d" , opts .beamSizes [beamIdx ]))
632+ if err != nil {
633+ return err
634+ }
635+
636+ for {
637+ select {
638+ case <- done :
639+ avgRecall := sumRecall / float64 (searches )
640+ if avgRecall < minRecall {
641+ return errors .AssertionFailedf (
642+ "Average recall (%f) is less than minimum (%f) for worker %d with beam size %d" ,
643+ avgRecall ,
644+ minRecall ,
645+ reader ,
646+ opts .beamSizes [beamIdx ],
647+ )
648+ }
649+ l .Printf (
650+ "Reader %d exiting with average recall of %.2f over %d searches with beam size %d" ,
651+ reader ,
652+ avgRecall * 100 ,
653+ searches ,
654+ opts .beamSizes [beamIdx ],
655+ )
656+ return nil
657+ default :
658+ queryIdx := rand .Intn (data .Test .Count )
659+ queryVec := data .Test .At (queryIdx )
660+
661+ rows , err := conn .Query (ctx , searchSQL , queryVec )
662+ if err != nil {
663+ return err
664+ }
665+
666+ results = results [:0 ]
667+ for rows .Next () {
668+ var id []byte
669+ err = rows .Scan (& id )
670+ require .NoError (t , err )
671+ results = append (results , id )
672+ }
673+ if err = rows .Err (); err != nil {
674+ return err
675+ }
676+
677+ primaryKeys = primaryKeys [:0 ]
678+ truth = truth [:0 ]
679+ for n := range maxResults {
680+ primaryKeys = appendCanonicalKey (primaryKeys , 0 , int (data.Neighbors [queryIdx ][n ]))
681+ truth = append (truth , primaryKeys [len (primaryKeys )- 6 :])
682+ }
683+
684+ sumRecall += vecann .CalculateRecall (results , truth )
685+ searches ++
686+ }
687+ }
688+ })
689+ }
690+
691+ endTime := <- timer .C
692+ t .L ().Printf ("Shutting down workers at %v" , endTime )
693+ close (done )
694+
695+ workers .Wait ()
696+ }
697+
497698func insertVectors (
498699 ctx context.Context ,
499700 conn * pgxpool.Conn ,
0 commit comments