@@ -188,8 +188,7 @@ type vectorBench struct {
188
188
datasetName string
189
189
distanceMetric vecpb.DistanceMetric
190
190
provider VectorProvider
191
- buildData vecann.Dataset
192
- searchData vecann.SearchDataset
191
+ data vecann.Dataset
193
192
}
194
193
195
194
// newVectorBench creates a new vectorBench instance for the given dataset.
@@ -218,11 +217,11 @@ func newVectorBench(ctx context.Context, stopper *stop.Stopper, datasetName stri
218
217
// SearchIndex downloads, builds, and searches an index for the given dataset.
219
218
func (vb * vectorBench ) SearchIndex () {
220
219
// Ensure the data needed for the search is available.
221
- vb .ensureDataset (vb .ctx , true /* forSearch */ )
220
+ vb .ensureDataset (vb .ctx )
222
221
223
222
var err error
224
223
vb .provider , err = newVectorProvider (
225
- vb .stopper , vb .datasetName , vb .searchData . Test .Dims , vb .distanceMetric )
224
+ vb .stopper , vb .datasetName , vb .data .Dims , vb .distanceMetric )
226
225
if err != nil {
227
226
panic (err )
228
227
}
@@ -249,10 +248,10 @@ func (vb *vectorBench) SearchIndex() {
249
248
250
249
// Search for test vectors.
251
250
var sumRecall , sumVectors , sumLeafVectors , sumFullVectors , sumPartitions float64
252
- count := vb .searchData .Test .Count
251
+ count := vb .data .Test .Count
253
252
for i := range count {
254
253
// Calculate truth set for the vector.
255
- queryVector := vb .searchData .Test .At (i )
254
+ queryVector := vb .data .Test .At (i )
256
255
257
256
var stats cspann.SearchStats
258
257
prediction , err := vb .provider .Search (vb .ctx , state , queryVector , & stats )
@@ -264,7 +263,7 @@ func (vb *vectorBench) SearchIndex() {
264
263
truth := make ([]cspann.KeyBytes , maxResults )
265
264
for neighbor := range maxResults {
266
265
primaryKey := primaryKeys [neighbor * 4 : neighbor * 4 + 4 ]
267
- binary .BigEndian .PutUint32 (primaryKey , uint32 (vb .searchData .Neighbors [i ][neighbor ]))
266
+ binary .BigEndian .PutUint32 (primaryKey , uint32 (vb .data .Neighbors [i ][neighbor ]))
268
267
truth [neighbor ] = primaryKey
269
268
}
270
269
@@ -287,7 +286,7 @@ func (vb *vectorBench) SearchIndex() {
287
286
fmt .Printf (White + "%s\n " + Reset , vb .datasetName )
288
287
fmt .Printf (
289
288
White + "%d train vectors, %d test vectors, %d dimensions, %d/%d min/max partitions, build beam size %d\n " + Reset ,
290
- vb .searchData . Count , vb .searchData .Test .Count , vb .searchData . Test .Dims ,
289
+ vb .data . TrainCount , vb .data .Test .Count , vb .data .Dims ,
291
290
minPartitionSize , maxPartitionSize , * flagBeamSize )
292
291
fmt .Println (vb .provider .FormatStats ())
293
292
@@ -308,12 +307,12 @@ func (vb *vectorBench) SearchIndex() {
308
307
// it is rebuilt from scratch.
309
308
func (vb * vectorBench ) BuildIndex () {
310
309
// Ensure dataset is downloaded and cached.
311
- vb .ensureDataset (vb .ctx , false /* forSearch */ )
310
+ vb .ensureDataset (vb .ctx )
312
311
313
312
// Construct the vector provider.
314
313
var err error
315
314
vb .provider , err = newVectorProvider (
316
- vb .stopper , vb .datasetName , vb .buildData . Train .Dims , vb .distanceMetric )
315
+ vb .stopper , vb .datasetName , vb .data .Dims , vb .distanceMetric )
317
316
if err != nil {
318
317
panic (err )
319
318
}
@@ -324,14 +323,6 @@ func (vb *vectorBench) BuildIndex() {
324
323
panic (err )
325
324
}
326
325
327
- // Create unique primary key for each vector.
328
- primaryKeys := make ([]cspann.KeyBytes , vb .buildData .Train .Count )
329
- keyBuf := make (cspann.KeyBytes , vb .buildData .Train .Count * 4 )
330
- for i := range vb .buildData .Train .Count {
331
- primaryKeys [i ] = keyBuf [i * 4 : i * 4 + 4 ]
332
- binary .BigEndian .PutUint32 (primaryKeys [i ], uint32 (i ))
333
- }
334
-
335
326
// Compute percentile latencies.
336
327
estimator := NewPercentileEstimator (1000 )
337
328
@@ -355,71 +346,96 @@ func (vb *vectorBench) BuildIndex() {
355
346
fmt .Printf (White + "Building index for dataset: %s\n " + Reset , vb .datasetName )
356
347
startAt := crtime .NowMono ()
357
348
358
- // Insert vectors into the provider on multiple goroutines .
349
+ // Insert vectors into the provider using batches of training vectors .
359
350
var insertCount atomic.Uint64
360
- procs := runtime .GOMAXPROCS (- 1 )
361
- countPerProc := (vb .buildData .Train .Count + procs ) / procs
351
+ var lastInserted int
362
352
batchSize := * flagBatchSize
363
- for i := 0 ; i < vb .buildData .Train .Count ; i += countPerProc {
364
- end := min (i + countPerProc , vb .buildData .Train .Count )
365
- go func (start , end int ) {
366
- // Break vector group into batches that each insert a batch of vectors.
367
- for j := start ; j < end ; j += batchSize {
368
- startMono := crtime .NowMono ()
369
- vectors := vb .buildData .Train .Slice (j , min (j + batchSize , end )- j )
370
- err := vb .provider .InsertVectors (vb .ctx , primaryKeys [j :j + vectors .Count ], vectors )
371
- if err != nil {
372
- panic (err )
373
- }
374
- estimator .Add (startMono .Elapsed ().Seconds () / float64 (vectors .Count ))
375
- insertCount .Add (uint64 (vectors .Count ))
376
- }
377
- }(i , end )
378
- }
379
353
380
- // Compute ops per second.
381
- var lastInserted int
354
+ // Reset the dataset to start from the beginning
355
+ vb . data . Reset ()
382
356
383
- // Update progress every second.
384
- lastProgressAt := startAt
385
357
for {
386
- time .Sleep (time .Second )
387
-
388
- // Calculate exactly how long it's been since last progress update.
389
- now := crtime .NowMono ()
390
- sinceProgress := now .Sub (lastProgressAt )
391
- lastProgressAt = now
392
-
393
- // Calculate ops per second over the last second.
394
- totalInserted := int (insertCount .Load ())
395
- opsPerSec := float64 (totalInserted - lastInserted ) / sinceProgress .Seconds ()
396
- lastInserted = totalInserted
397
-
398
- cp .AddSample (throughput , opsPerSec )
399
- cp .AddSample (p50 , estimator .Estimate (0.50 )* 1000 )
400
- cp .AddSample (p90 , estimator .Estimate (0.90 )* 1000 )
401
- cp .AddSample (p99 , estimator .Estimate (0.99 )* 1000 )
402
-
403
- // Add provider-specific metric samples.
404
- metrics , err := vb .provider .GetMetrics ()
358
+ // Get next batch of train vectors.
359
+ hasMore , err := vb .data .Next ()
405
360
if err != nil {
406
361
panic (err )
407
362
}
408
- for i , metric := range metrics {
409
- cp .AddSample (metricIds [i ], metric .Value )
363
+ if ! hasMore {
364
+ // No more batches.
365
+ break
410
366
}
411
- cp .Plot ()
412
-
413
- if ! * flagHideProgress {
414
- sinceStart := now .Sub (startAt )
415
- fmt .Printf (White + "\r Inserted %d / %d vectors (%.2f%%) in %v" + Reset ,
416
- totalInserted , vb .buildData .Train .Count ,
417
- (float64 (totalInserted )/ float64 (vb .buildData .Train .Count ))* 100 ,
418
- sinceStart .Truncate (time .Second ))
367
+ trainBatch := vb .data .Train
368
+ insertedBefore := int (insertCount .Load ())
369
+
370
+ // Create primary keys for this batch
371
+ primaryKeys := make ([]cspann.KeyBytes , trainBatch .Count )
372
+ keyBuf := make (cspann.KeyBytes , trainBatch .Count * 4 )
373
+ for i := range trainBatch .Count {
374
+ primaryKeys [i ] = keyBuf [i * 4 : i * 4 + 4 ]
375
+ binary .BigEndian .PutUint32 (primaryKeys [i ], uint32 (insertedBefore + i ))
419
376
}
420
377
421
- if totalInserted >= vb .buildData .Train .Count {
422
- break
378
+ procs := runtime .GOMAXPROCS (- 1 )
379
+ countPerProc := (vb .data .Train .Count + procs ) / procs
380
+ for i := 0 ; i < vb .data .Train .Count ; i += countPerProc {
381
+ end := min (i + countPerProc , vb .data .Train .Count )
382
+ go func (start , end int ) {
383
+ // Break vector group into batches that each insert a batch of vectors.
384
+ for j := start ; j < end ; j += batchSize {
385
+ startMono := crtime .NowMono ()
386
+ vectors := vb .data .Train .Slice (j , min (j + batchSize , end )- j )
387
+ err := vb .provider .InsertVectors (vb .ctx , primaryKeys [j :j + vectors .Count ], vectors )
388
+ if err != nil {
389
+ panic (err )
390
+ }
391
+ estimator .Add (startMono .Elapsed ().Seconds () / float64 (vectors .Count ))
392
+ insertCount .Add (uint64 (vectors .Count ))
393
+ }
394
+ }(i , end )
395
+ }
396
+
397
+ // Update progress every second.
398
+ lastProgressAt := startAt
399
+ for {
400
+ time .Sleep (time .Second )
401
+
402
+ // Calculate exactly how long it's been since last progress update.
403
+ now := crtime .NowMono ()
404
+ sinceProgress := now .Sub (lastProgressAt )
405
+ lastProgressAt = now
406
+
407
+ // Calculate ops per second over the last second.
408
+ totalInserted := int (insertCount .Load ())
409
+ opsPerSec := float64 (totalInserted - lastInserted ) / sinceProgress .Seconds ()
410
+ lastInserted = totalInserted
411
+
412
+ cp .AddSample (throughput , opsPerSec )
413
+ cp .AddSample (p50 , estimator .Estimate (0.50 )* 1000 )
414
+ cp .AddSample (p90 , estimator .Estimate (0.90 )* 1000 )
415
+ cp .AddSample (p99 , estimator .Estimate (0.99 )* 1000 )
416
+
417
+ // Add provider-specific metric samples.
418
+ metrics , err := vb .provider .GetMetrics ()
419
+ if err != nil {
420
+ panic (err )
421
+ }
422
+ for i , metric := range metrics {
423
+ cp .AddSample (metricIds [i ], metric .Value )
424
+ }
425
+ cp .Plot ()
426
+
427
+ if ! * flagHideProgress {
428
+ sinceStart := now .Sub (startAt )
429
+ fmt .Printf (White + "\r Inserted %d / %d vectors (%.2f%%) in %v" + Reset ,
430
+ totalInserted , vb .data .TrainCount ,
431
+ (float64 (totalInserted )/ float64 (vb .data .TrainCount ))* 100 ,
432
+ sinceStart .Truncate (time .Second ))
433
+ }
434
+
435
+ // Check if we've inserted all vectors in the batch.
436
+ if (totalInserted - insertedBefore ) >= vb .data .Train .Count {
437
+ break
438
+ }
423
439
}
424
440
}
425
441
@@ -432,9 +448,8 @@ func (vb *vectorBench) BuildIndex() {
432
448
}
433
449
434
450
// ensureDataset ensures that the dataset has been downloaded and cached to
435
- // disk. It also loads the data into memory. If "forSearch" is true, then only
436
- // test vectors are loaded, not train vectors.
437
- func (vb * vectorBench ) ensureDataset (ctx context.Context , forSearch bool ) {
451
+ // disk. It also loads the data into memory.
452
+ func (vb * vectorBench ) ensureDataset (ctx context.Context ) {
438
453
loader := vecann.DatasetLoader {
439
454
DatasetName : vb .datasetName ,
440
455
OnProgress : func (ctx context.Context , format string , args ... any ) {
@@ -450,20 +465,11 @@ func (vb *vectorBench) ensureDataset(ctx context.Context, forSearch bool) {
450
465
},
451
466
}
452
467
453
- if forSearch {
454
- // Only need to load the search data, not the insert data. Loading
455
- // the insert data can take 10+ seconds for large datasets.
456
- if err := loader .LoadForSearch (ctx ); err != nil {
457
- panic (err )
458
- }
459
- } else {
460
- if err := loader .Load (ctx ); err != nil {
461
- panic (err )
462
- }
468
+ if err := loader .Load (ctx ); err != nil {
469
+ panic (err )
463
470
}
464
471
465
- vb .buildData = loader .Data
466
- vb .searchData = loader .SearchData
472
+ vb .data = loader .Data
467
473
}
468
474
469
475
// newVectorProvider creates a new in-memory or SQL based vector provider that
0 commit comments