Skip to content

Commit 3462dac

Browse files
committed
vecbench: update vecbench and vecann to handle segmented datasets
A recent change updated vector datasets to segment huge sets of training vectors across multiple files. This commit updates vecbench and vecann to consume these multi-file datasets. Vecbench will separately read each file and insert its vectors into the store, so that memory usage stays low. Vecann will load all files into memory so that it can randomly choose vectors to insert and search. Epic: CRDB-42943 Release note: None Release note (cli change):
1 parent ac400e7 commit 3462dac

File tree

3 files changed

+279
-183
lines changed

3 files changed

+279
-183
lines changed

pkg/cmd/vecbench/main.go

Lines changed: 93 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -188,8 +188,7 @@ type vectorBench struct {
188188
datasetName string
189189
distanceMetric vecpb.DistanceMetric
190190
provider VectorProvider
191-
buildData vecann.Dataset
192-
searchData vecann.SearchDataset
191+
data vecann.Dataset
193192
}
194193

195194
// newVectorBench creates a new VecBench instance for the given dataset.
@@ -218,11 +217,11 @@ func newVectorBench(ctx context.Context, stopper *stop.Stopper, datasetName stri
218217
// SearchIndex downloads, builds, and searches an index for the given dataset.
219218
func (vb *vectorBench) SearchIndex() {
220219
// Ensure the data needed for the search is available.
221-
vb.ensureDataset(vb.ctx, true /* forSearch */)
220+
vb.ensureDataset(vb.ctx)
222221

223222
var err error
224223
vb.provider, err = newVectorProvider(
225-
vb.stopper, vb.datasetName, vb.searchData.Test.Dims, vb.distanceMetric)
224+
vb.stopper, vb.datasetName, vb.data.Dims, vb.distanceMetric)
226225
if err != nil {
227226
panic(err)
228227
}
@@ -249,10 +248,10 @@ func (vb *vectorBench) SearchIndex() {
249248

250249
// Search for test vectors.
251250
var sumRecall, sumVectors, sumLeafVectors, sumFullVectors, sumPartitions float64
252-
count := vb.searchData.Test.Count
251+
count := vb.data.Test.Count
253252
for i := range count {
254253
// Calculate truth set for the vector.
255-
queryVector := vb.searchData.Test.At(i)
254+
queryVector := vb.data.Test.At(i)
256255

257256
var stats cspann.SearchStats
258257
prediction, err := vb.provider.Search(vb.ctx, state, queryVector, &stats)
@@ -264,7 +263,7 @@ func (vb *vectorBench) SearchIndex() {
264263
truth := make([]cspann.KeyBytes, maxResults)
265264
for neighbor := range maxResults {
266265
primaryKey := primaryKeys[neighbor*4 : neighbor*4+4]
267-
binary.BigEndian.PutUint32(primaryKey, uint32(vb.searchData.Neighbors[i][neighbor]))
266+
binary.BigEndian.PutUint32(primaryKey, uint32(vb.data.Neighbors[i][neighbor]))
268267
truth[neighbor] = primaryKey
269268
}
270269

@@ -287,7 +286,7 @@ func (vb *vectorBench) SearchIndex() {
287286
fmt.Printf(White+"%s\n"+Reset, vb.datasetName)
288287
fmt.Printf(
289288
White+"%d train vectors, %d test vectors, %d dimensions, %d/%d min/max partitions, build beam size %d\n"+Reset,
290-
vb.searchData.Count, vb.searchData.Test.Count, vb.searchData.Test.Dims,
289+
vb.data.TrainCount, vb.data.Test.Count, vb.data.Dims,
291290
minPartitionSize, maxPartitionSize, *flagBeamSize)
292291
fmt.Println(vb.provider.FormatStats())
293292

@@ -308,12 +307,12 @@ func (vb *vectorBench) SearchIndex() {
308307
// it is rebuilt from scratch.
309308
func (vb *vectorBench) BuildIndex() {
310309
// Ensure dataset is downloaded and cached.
311-
vb.ensureDataset(vb.ctx, false /* forSearch */)
310+
vb.ensureDataset(vb.ctx)
312311

313312
// Construct the vector provider.
314313
var err error
315314
vb.provider, err = newVectorProvider(
316-
vb.stopper, vb.datasetName, vb.buildData.Train.Dims, vb.distanceMetric)
315+
vb.stopper, vb.datasetName, vb.data.Dims, vb.distanceMetric)
317316
if err != nil {
318317
panic(err)
319318
}
@@ -324,14 +323,6 @@ func (vb *vectorBench) BuildIndex() {
324323
panic(err)
325324
}
326325

327-
// Create unique primary key for each vector.
328-
primaryKeys := make([]cspann.KeyBytes, vb.buildData.Train.Count)
329-
keyBuf := make(cspann.KeyBytes, vb.buildData.Train.Count*4)
330-
for i := range vb.buildData.Train.Count {
331-
primaryKeys[i] = keyBuf[i*4 : i*4+4]
332-
binary.BigEndian.PutUint32(primaryKeys[i], uint32(i))
333-
}
334-
335326
// Compute percentile latencies.
336327
estimator := NewPercentileEstimator(1000)
337328

@@ -355,71 +346,96 @@ func (vb *vectorBench) BuildIndex() {
355346
fmt.Printf(White+"Building index for dataset: %s\n"+Reset, vb.datasetName)
356347
startAt := crtime.NowMono()
357348

358-
// Insert vectors into the provider on multiple goroutines.
349+
// Insert vectors into the provider using batches of training vectors.
359350
var insertCount atomic.Uint64
360-
procs := runtime.GOMAXPROCS(-1)
361-
countPerProc := (vb.buildData.Train.Count + procs) / procs
351+
var lastInserted int
362352
batchSize := *flagBatchSize
363-
for i := 0; i < vb.buildData.Train.Count; i += countPerProc {
364-
end := min(i+countPerProc, vb.buildData.Train.Count)
365-
go func(start, end int) {
366-
// Break vector group into batches that each insert a batch of vectors.
367-
for j := start; j < end; j += batchSize {
368-
startMono := crtime.NowMono()
369-
vectors := vb.buildData.Train.Slice(j, min(j+batchSize, end)-j)
370-
err := vb.provider.InsertVectors(vb.ctx, primaryKeys[j:j+vectors.Count], vectors)
371-
if err != nil {
372-
panic(err)
373-
}
374-
estimator.Add(startMono.Elapsed().Seconds() / float64(vectors.Count))
375-
insertCount.Add(uint64(vectors.Count))
376-
}
377-
}(i, end)
378-
}
379353

380-
// Compute ops per second.
381-
var lastInserted int
354+
// Reset the dataset to start from the beginning
355+
vb.data.Reset()
382356

383-
// Update progress every second.
384-
lastProgressAt := startAt
385357
for {
386-
time.Sleep(time.Second)
387-
388-
// Calculate exactly how long it's been since last progress update.
389-
now := crtime.NowMono()
390-
sinceProgress := now.Sub(lastProgressAt)
391-
lastProgressAt = now
392-
393-
// Calculate ops per second over the last second.
394-
totalInserted := int(insertCount.Load())
395-
opsPerSec := float64(totalInserted-lastInserted) / sinceProgress.Seconds()
396-
lastInserted = totalInserted
397-
398-
cp.AddSample(throughput, opsPerSec)
399-
cp.AddSample(p50, estimator.Estimate(0.50)*1000)
400-
cp.AddSample(p90, estimator.Estimate(0.90)*1000)
401-
cp.AddSample(p99, estimator.Estimate(0.99)*1000)
402-
403-
// Add provider-specific metric samples.
404-
metrics, err := vb.provider.GetMetrics()
358+
// Get next batch of train vectors.
359+
hasMore, err := vb.data.Next()
405360
if err != nil {
406361
panic(err)
407362
}
408-
for i, metric := range metrics {
409-
cp.AddSample(metricIds[i], metric.Value)
363+
if !hasMore {
364+
// No more batches.
365+
break
410366
}
411-
cp.Plot()
412-
413-
if !*flagHideProgress {
414-
sinceStart := now.Sub(startAt)
415-
fmt.Printf(White+"\rInserted %d / %d vectors (%.2f%%) in %v"+Reset,
416-
totalInserted, vb.buildData.Train.Count,
417-
(float64(totalInserted)/float64(vb.buildData.Train.Count))*100,
418-
sinceStart.Truncate(time.Second))
367+
trainBatch := vb.data.Train
368+
insertedBefore := int(insertCount.Load())
369+
370+
// Create primary keys for this batch
371+
primaryKeys := make([]cspann.KeyBytes, trainBatch.Count)
372+
keyBuf := make(cspann.KeyBytes, trainBatch.Count*4)
373+
for i := range trainBatch.Count {
374+
primaryKeys[i] = keyBuf[i*4 : i*4+4]
375+
binary.BigEndian.PutUint32(primaryKeys[i], uint32(insertedBefore+i))
419376
}
420377

421-
if totalInserted >= vb.buildData.Train.Count {
422-
break
378+
procs := runtime.GOMAXPROCS(-1)
379+
countPerProc := (vb.data.Train.Count + procs) / procs
380+
for i := 0; i < vb.data.Train.Count; i += countPerProc {
381+
end := min(i+countPerProc, vb.data.Train.Count)
382+
go func(start, end int) {
383+
// Break vector group into batches that each insert a batch of vectors.
384+
for j := start; j < end; j += batchSize {
385+
startMono := crtime.NowMono()
386+
vectors := vb.data.Train.Slice(j, min(j+batchSize, end)-j)
387+
err := vb.provider.InsertVectors(vb.ctx, primaryKeys[j:j+vectors.Count], vectors)
388+
if err != nil {
389+
panic(err)
390+
}
391+
estimator.Add(startMono.Elapsed().Seconds() / float64(vectors.Count))
392+
insertCount.Add(uint64(vectors.Count))
393+
}
394+
}(i, end)
395+
}
396+
397+
// Update progress every second.
398+
lastProgressAt := startAt
399+
for {
400+
time.Sleep(time.Second)
401+
402+
// Calculate exactly how long it's been since last progress update.
403+
now := crtime.NowMono()
404+
sinceProgress := now.Sub(lastProgressAt)
405+
lastProgressAt = now
406+
407+
// Calculate ops per second over the last second.
408+
totalInserted := int(insertCount.Load())
409+
opsPerSec := float64(totalInserted-lastInserted) / sinceProgress.Seconds()
410+
lastInserted = totalInserted
411+
412+
cp.AddSample(throughput, opsPerSec)
413+
cp.AddSample(p50, estimator.Estimate(0.50)*1000)
414+
cp.AddSample(p90, estimator.Estimate(0.90)*1000)
415+
cp.AddSample(p99, estimator.Estimate(0.99)*1000)
416+
417+
// Add provider-specific metric samples.
418+
metrics, err := vb.provider.GetMetrics()
419+
if err != nil {
420+
panic(err)
421+
}
422+
for i, metric := range metrics {
423+
cp.AddSample(metricIds[i], metric.Value)
424+
}
425+
cp.Plot()
426+
427+
if !*flagHideProgress {
428+
sinceStart := now.Sub(startAt)
429+
fmt.Printf(White+"\rInserted %d / %d vectors (%.2f%%) in %v"+Reset,
430+
totalInserted, vb.data.TrainCount,
431+
(float64(totalInserted)/float64(vb.data.TrainCount))*100,
432+
sinceStart.Truncate(time.Second))
433+
}
434+
435+
// Check if we've inserted all vectors in the batch.
436+
if (totalInserted - insertedBefore) >= vb.data.Train.Count {
437+
break
438+
}
423439
}
424440
}
425441

@@ -432,9 +448,8 @@ func (vb *vectorBench) BuildIndex() {
432448
}
433449

434450
// ensureDataset ensures that the dataset has been downloaded and cached to
435-
// disk. It also loads the data into memory. If "forSearch" is true, then only
436-
// test vectors are loaded, not train vectors.
437-
func (vb *vectorBench) ensureDataset(ctx context.Context, forSearch bool) {
451+
// disk. It also loads the data into memory.
452+
func (vb *vectorBench) ensureDataset(ctx context.Context) {
438453
loader := vecann.DatasetLoader{
439454
DatasetName: vb.datasetName,
440455
OnProgress: func(ctx context.Context, format string, args ...any) {
@@ -450,20 +465,11 @@ func (vb *vectorBench) ensureDataset(ctx context.Context, forSearch bool) {
450465
},
451466
}
452467

453-
if forSearch {
454-
// Only need to load the search data, not the insert data. Loading
455-
// the insert data can take 10+ seconds for large datasets.
456-
if err := loader.LoadForSearch(ctx); err != nil {
457-
panic(err)
458-
}
459-
} else {
460-
if err := loader.Load(ctx); err != nil {
461-
panic(err)
462-
}
468+
if err := loader.Load(ctx); err != nil {
469+
panic(err)
463470
}
464471

465-
vb.buildData = loader.Data
466-
vb.searchData = loader.SearchData
472+
vb.data = loader.Data
467473
}
468474

469475
// newVectorProvider creates a new in-memory or SQL based vector provider that

0 commit comments

Comments
 (0)