Skip to content

Commit d312aa7

Browse files
Improve the computation of accuracy (#408)
* Add methods to compute n-recall@k. Change the ground truth from List<Set<Integer>> to List<List<Integer>> to preserve the order * Include topKGrid in Bench * Add averagePrecisionAtK and meanAveragePrecisionAtK * Decouple runtime measurement from the computation of accuracy and the gathering of other telemetry metrics
1 parent a6b004d commit d312aa7

File tree

11 files changed

+214
-101
lines changed

11 files changed

+214
-101
lines changed

benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/IndexConstructionWithStaticSetBenchmark.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@
2828
import org.slf4j.LoggerFactory;
2929

3030
import java.io.IOException;
31-
import java.util.ArrayList;
32-
import java.util.Set;
31+
import java.util.List;
3332
import java.util.concurrent.TimeUnit;
3433

3534
@BenchmarkMode(Mode.AverageTime)
@@ -42,9 +41,9 @@
4241
public class IndexConstructionWithStaticSetBenchmark {
4342
private static final Logger log = LoggerFactory.getLogger(IndexConstructionWithStaticSetBenchmark.class);
4443
private RandomAccessVectorValues ravv;
45-
private ArrayList<VectorFloat<?>> baseVectors;
46-
private ArrayList<VectorFloat<?>> queryVectors;
47-
private ArrayList<Set<Integer>> groundTruth;
44+
private List<VectorFloat<?>> baseVectors;
45+
private List<VectorFloat<?>> queryVectors;
46+
private List<List<Integer>> groundTruth;
4847
private BuildScoreProvider bsp;
4948
@Param({"16", "32", "64"})
5049
private int M; // graph degree

benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/PQBenchmark.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@
2929
import org.slf4j.LoggerFactory;
3030

3131
import java.io.IOException;
32-
import java.util.ArrayList;
33-
import java.util.Set;
32+
import java.util.List;
3433
import java.util.concurrent.TimeUnit;
3534

3635
@BenchmarkMode(Mode.AverageTime)
@@ -43,9 +42,9 @@
4342
public class PQBenchmark {
4443
private static final Logger log = LoggerFactory.getLogger(PQBenchmark.class);
4544
private RandomAccessVectorValues ravv;
46-
private ArrayList<VectorFloat<?>> baseVectors;
47-
private ArrayList<VectorFloat<?>> queryVectors;
48-
private ArrayList<Set<Integer>> groundTruth;
45+
private List<VectorFloat<?>> baseVectors;
46+
private List<VectorFloat<?>> queryVectors;
47+
private List<List<Integer>> groundTruth;
4948
@Param({"16", "32", "64"})
5049
private int M; // Number of subspaces
5150
int originalDimension;

benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/StaticSetVectorsBenchmark.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@
2828
import org.slf4j.LoggerFactory;
2929

3030
import java.io.IOException;
31-
import java.util.ArrayList;
32-
import java.util.Set;
31+
import java.util.List;
3332
import java.util.concurrent.TimeUnit;
3433

3534
@BenchmarkMode(Mode.AverageTime)
@@ -42,9 +41,9 @@
4241
public class StaticSetVectorsBenchmark {
4342
private static final Logger log = LoggerFactory.getLogger(StaticSetVectorsBenchmark.class);
4443
private RandomAccessVectorValues ravv;
45-
private ArrayList<VectorFloat<?>> baseVectors;
46-
private ArrayList<VectorFloat<?>> queryVectors;
47-
private ArrayList<Set<Integer>> groundTruth;
44+
private List<VectorFloat<?>> baseVectors;
45+
private List<VectorFloat<?>> queryVectors;
46+
private List<List<Integer>> groundTruth;
4847
private GraphIndexBuilder graphIndexBuilder;
4948
private GraphIndex graphIndex;
5049
int originalDimension;

jvector-examples/src/main/java/io/github/jbellis/jvector/example/Bench.java

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ public static void main(String[] args) throws IOException {
4444

4545
var mGrid = List.of(32); // List.of(16, 24, 32, 48, 64, 96, 128);
4646
var efConstructionGrid = List.of(100); // List.of(60, 80, 100, 120, 160, 200, 400, 600, 800);
47+
var topKGrid = List.of(10, 100);
4748
var overqueryGrid = List.of(1.0, 2.0, 5.0); // rerankK = oq * topK
4849
var neighborOverflowGrid = List.of(1.2f); // List.of(1.2f, 2.0f);
4950
var addHierarchyGrid = List.of(true); // List.of(false, true);
@@ -77,15 +78,15 @@ public static void main(String[] args) throws IOException {
7778
"nv-qa-v4-100k",
7879
"colbert-1M",
7980
"gecko-100k");
80-
executeNw(coreFiles, pattern, buildCompression, featureSets, searchCompression, mGrid, efConstructionGrid, neighborOverflowGrid, addHierarchyGrid, overqueryGrid, usePruningGrid);
81+
executeNw(coreFiles, pattern, buildCompression, featureSets, searchCompression, mGrid, efConstructionGrid, neighborOverflowGrid, addHierarchyGrid, topKGrid, overqueryGrid, usePruningGrid);
8182

8283
var extraFiles = List.of(
8384
"openai-v3-large-3072-100k",
8485
"openai-v3-large-1536-100k",
8586
"e5-small-v2-100k",
8687
"e5-base-v2-100k",
8788
"e5-large-v2-100k");
88-
executeNw(extraFiles, pattern, buildCompression, featureSets, searchCompression, mGrid, efConstructionGrid, neighborOverflowGrid, addHierarchyGrid, overqueryGrid, usePruningGrid);
89+
executeNw(extraFiles, pattern, buildCompression, featureSets, searchCompression, mGrid, efConstructionGrid, neighborOverflowGrid, addHierarchyGrid, topKGrid, overqueryGrid, usePruningGrid);
8990

9091
// smaller vectors from ann-benchmarks
9192
var hdf5Files = List.of(
@@ -102,7 +103,7 @@ public static void main(String[] args) throws IOException {
102103
for (var f : hdf5Files) {
103104
if (pattern.matcher(f).find()) {
104105
DownloadHelper.maybeDownloadHdf5(f);
105-
Grid.runAll(Hdf5Loader.load(f), mGrid, efConstructionGrid, neighborOverflowGrid, addHierarchyGrid, featureSets, buildCompression, searchCompression, overqueryGrid, usePruningGrid);
106+
Grid.runAll(Hdf5Loader.load(f), mGrid, efConstructionGrid, neighborOverflowGrid, addHierarchyGrid, featureSets, buildCompression, searchCompression, topKGrid, overqueryGrid, usePruningGrid);
106107
}
107108
}
108109

@@ -112,15 +113,15 @@ public static void main(String[] args) throws IOException {
112113
ds -> new PQParameters(ds.getDimension(), 256, true, UNWEIGHTED));
113114
buildCompression = Arrays.asList(__ -> CompressorParameters.NONE);
114115
var grid2d = DataSetCreator.create2DGrid(4_000_000, 10_000, 100);
115-
Grid.runAll(grid2d, mGrid, efConstructionGrid, neighborOverflowGrid, addHierarchyGrid, featureSets, buildCompression, searchCompression, overqueryGrid, usePruningGrid);
116+
Grid.runAll(grid2d, mGrid, efConstructionGrid, neighborOverflowGrid, addHierarchyGrid, featureSets, buildCompression, searchCompression, topKGrid, overqueryGrid, usePruningGrid);
116117
}
117118
}
118119

119-
private static void executeNw(List<String> coreFiles, Pattern pattern, List<Function<DataSet, CompressorParameters>> buildCompression, List<EnumSet<FeatureId>> featureSets, List<Function<DataSet, CompressorParameters>> compressionGrid, List<Integer> mGrid, List<Integer> efConstructionGrid, List<Float> neighborOverflowGrid, List<Boolean> addHierarchyGrid, List<Double> efSearchGrid, List<Boolean> usePruningGrid) throws IOException {
120+
private static void executeNw(List<String> coreFiles, Pattern pattern, List<Function<DataSet, CompressorParameters>> buildCompression, List<EnumSet<FeatureId>> featureSets, List<Function<DataSet, CompressorParameters>> compressionGrid, List<Integer> mGrid, List<Integer> efConstructionGrid, List<Float> neighborOverflowGrid, List<Boolean> addHierarchyGrid, List<Integer> topKGrid, List<Double> efSearchGrid, List<Boolean> usePruningGrid) throws IOException {
120121
for (var nwDatasetName : coreFiles) {
121122
if (pattern.matcher(nwDatasetName).find()) {
122123
var mfd = DownloadHelper.maybeDownloadFvecs(nwDatasetName);
123-
Grid.runAll(mfd.load(), mGrid, efConstructionGrid, neighborOverflowGrid, addHierarchyGrid, featureSets, buildCompression, compressionGrid, efSearchGrid, usePruningGrid);
124+
Grid.runAll(mfd.load(), mGrid, efConstructionGrid, neighborOverflowGrid, addHierarchyGrid, featureSets, buildCompression, compressionGrid, topKGrid, efSearchGrid, usePruningGrid);
124125
}
125126
}
126127
}

jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java

Lines changed: 35 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package io.github.jbellis.jvector.example;
1818

1919
import io.github.jbellis.jvector.disk.ReaderSupplierFactory;
20+
import io.github.jbellis.jvector.example.util.AccuracyMetrics;
2021
import io.github.jbellis.jvector.example.util.CompressorParameters;
2122
import io.github.jbellis.jvector.example.util.DataSet;
2223
import io.github.jbellis.jvector.graph.GraphIndex;
@@ -55,7 +56,6 @@
5556
import java.nio.file.Files;
5657
import java.nio.file.Path;
5758
import java.nio.file.Paths;
58-
import java.util.Arrays;
5959
import java.util.EnumMap;
6060
import java.util.HashMap;
6161
import java.util.IdentityHashMap;
@@ -86,6 +86,7 @@ static void runAll(DataSet ds,
8686
List<? extends Set<FeatureId>> featureSets,
8787
List<Function<DataSet, CompressorParameters>> buildCompressors,
8888
List<Function<DataSet, CompressorParameters>> compressionGrid,
89+
List<Integer> topKGrid,
8990
List<Double> efSearchFactor,
9091
List<Boolean> usePruningGrid) throws IOException
9192
{
@@ -97,7 +98,7 @@ static void runAll(DataSet ds,
9798
for (int efC : efConstructionGrid) {
9899
for (var bc : buildCompressors) {
99100
var compressor = getCompressor(bc, ds);
100-
runOneGraph(featureSets, M, efC, neighborOverflow, addHierarchy, compressor, compressionGrid, efSearchFactor, usePruningGrid, ds, testDirectory);
101+
runOneGraph(featureSets, M, efC, neighborOverflow, addHierarchy, compressor, compressionGrid, topKGrid, efSearchFactor, usePruningGrid, ds, testDirectory);
101102
}
102103
}
103104
}
@@ -123,6 +124,7 @@ static void runOneGraph(List<? extends Set<FeatureId>> featureSets,
123124
boolean addHierarchy,
124125
VectorCompressor<?> buildCompressor,
125126
List<Function<DataSet, CompressorParameters>> compressionGrid,
127+
List<Integer> topKGrid,
126128
List<Double> efSearchOptions,
127129
List<Boolean> usePruningGrid,
128130
DataSet ds,
@@ -151,7 +153,7 @@ static void runOneGraph(List<? extends Set<FeatureId>> featureSets,
151153
indexes.forEach((features, index) -> {
152154
try (var cs = new ConfiguredSystem(ds, index, cv,
153155
index instanceof OnDiskGraphIndex ? ((OnDiskGraphIndex) index).getFeatureSet() : Set.of())) {
154-
testConfiguration(cs, efSearchOptions, usePruningGrid);
156+
testConfiguration(cs, topKGrid, efSearchOptions, usePruningGrid);
155157
} catch (Exception e) {
156158
throw new RuntimeException(e);
157159
}
@@ -361,22 +363,20 @@ private static Map<Set<FeatureId>, GraphIndex> buildInMemory(List<? extends Set<
361363
// avoid recomputing the compressor repeatedly (this is a relatively small memory footprint)
362364
static final Map<String, VectorCompressor<?>> cachedCompressors = new IdentityHashMap<>();
363365

364-
private static void testConfiguration(ConfiguredSystem cs, List<Double> efSearchOptions, List<Boolean> usePruningGrid) {
365-
var topK = cs.ds.groundTruth.get(0).size();
366+
private static void testConfiguration(ConfiguredSystem cs, List<Integer> topKGrid, List<Double> efSearchOptions, List<Boolean> usePruningGrid) {
366367
int queryRuns = 2;
367368
System.out.format("Using %s:%n", cs.index);
368-
for (var overquery : efSearchOptions) {
369-
int rerankK = (int) (topK * overquery);
370-
for (var usePruning : usePruningGrid) {
371-
var startTime = System.nanoTime();
372-
var pqr = performQueries(cs, topK, rerankK, usePruning, queryRuns);
373-
var stopTime = System.nanoTime();
374-
var recall = ((double) pqr.topKFound) / (queryRuns * cs.ds.queryVectors.size() * topK);
375-
System.out.format(" Query top %d/%d recall %.4f in %.2fms after %.2f nodes visited (AVG) and %.2f nodes expanded with pruning=%b%n",
376-
topK, rerankK, recall, (stopTime - startTime) / (queryRuns * 1_000_000.0),
377-
(double) pqr.nodesVisited / (queryRuns * cs.ds.queryVectors.size()),
378-
(double) pqr.nodesExpanded / (queryRuns * cs.ds.queryVectors.size()),
379-
usePruning);
369+
for (var topK : topKGrid) {
370+
for (var overquery : efSearchOptions) {
371+
int rerankK = (int) (topK * overquery);
372+
for (var usePruning : usePruningGrid) {
373+
var pqr = performQueries(cs, topK, rerankK, usePruning, queryRuns);
374+
System.out.format(" Query top %d/%d recall %.4f in %.2fms after %.2f nodes visited (AVG) and %.2f nodes expanded with pruning=%b%n",
375+
topK, rerankK, pqr.recall, pqr.runtime / (1_000_000.0),
376+
(double) pqr.nodesVisited / cs.ds.queryVectors.size(),
377+
(double) pqr.nodesExpanded / cs.ds.queryVectors.size(),
378+
usePruning);
379+
}
380380
}
381381
}
382382
}
@@ -421,44 +421,36 @@ private static VectorCompressor<?> getCompressor(Function<DataSet, CompressorPar
421421
});
422422
}
423423

424-
private static long topKCorrect(int topK, int[] resultNodes, Set<Integer> gt) {
425-
int count = Math.min(resultNodes.length, topK);
426-
var resultSet = Arrays.stream(resultNodes, 0, count)
427-
.boxed()
428-
.collect(Collectors.toSet());
429-
assert resultSet.size() == count : String.format("%s duplicate results out of %s", count - resultSet.size(), count);
430-
return resultSet.stream().filter(gt::contains).count();
431-
}
432-
433-
private static long topKCorrect(int topK, SearchResult.NodeScore[] nn, Set<Integer> gt) {
434-
var a = Arrays.stream(nn).mapToInt(nodeScore -> nodeScore.node).toArray();
435-
return topKCorrect(topK, a, gt);
436-
}
437-
438424
private static ResultSummary performQueries(ConfiguredSystem cs, int topK, int rerankK, boolean usePruning, int queryRuns) {
439-
LongAdder topKfound = new LongAdder();
440425
LongAdder nodesVisited = new LongAdder();
441426
LongAdder nodesExpanded = new LongAdder();
442427
LongAdder nodesExpandedBaseLayer = new LongAdder();
428+
LongAdder runtime = new LongAdder();
429+
double recall = 0;
430+
443431
for (int k = 0; k < queryRuns; k++) {
444-
IntStream.range(0, cs.ds.queryVectors.size()).parallel().forEach(i -> {
432+
var startTime = System.nanoTime();
433+
List<SearchResult> listSR = IntStream.range(0, cs.ds.queryVectors.size()).parallel().mapToObj(i -> {
445434
var queryVector = cs.ds.queryVectors.get(i);
446435
SearchResult sr;
447436
var searcher = cs.getSearcher();
448437
searcher.usePruning(usePruning);
449438
var sf = cs.scoreProviderFor(queryVector, searcher.getView());
450439
sr = searcher.search(sf, topK, rerankK, 0.0f, 0.0f, Bits.ALL);
440+
return sr;
441+
}).collect(Collectors.toList());
442+
var stopTime = System.nanoTime();
451443

452-
// process search result
453-
var gt = cs.ds.groundTruth.get(i);
454-
var n = topKCorrect(topK, sr.getNodes(), gt);
455-
topKfound.add(n);
444+
runtime.add(stopTime - startTime);
445+
// process search result
446+
listSR.stream().parallel().forEach(sr -> {
456447
nodesVisited.add(sr.getVisitedCount());
457448
nodesExpanded.add(sr.getExpandedCount());
458449
nodesExpandedBaseLayer.add(sr.getExpandedCountBaseLayer());
459450
});
451+
recall += AccuracyMetrics.recallFromSearchResults(cs.ds.groundTruth, listSR, topK, topK);
460452
}
461-
return new ResultSummary((int) topKfound.sum(), nodesVisited.sum(), nodesExpanded.sum(), nodesExpandedBaseLayer.sum());
453+
return new ResultSummary(recall / queryRuns, nodesVisited.sum() / queryRuns, nodesExpanded.sum() / queryRuns, nodesExpandedBaseLayer.sum() / queryRuns, runtime.sum() / queryRuns);
462454
}
463455

464456
static class ConfiguredSystem implements AutoCloseable {
@@ -506,16 +498,18 @@ public void close() throws Exception {
506498
}
507499

508500
static class ResultSummary {
509-
final int topKFound;
501+
final double recall;
510502
final long nodesVisited;
511503
final long nodesExpanded;
512504
final long nodesExpandedBaseLayer;
505+
final long runtime;
513506

514-
ResultSummary(int topKFound, long nodesVisited, long nodesExpanded, long nodesExpandedBaseLayer) {
515-
this.topKFound = topKFound;
507+
ResultSummary(double recall, long nodesVisited, long nodesExpanded, long nodesExpandedBaseLayer, long runtime) {
508+
this.recall = recall;
516509
this.nodesVisited = nodesVisited;
517510
this.nodesExpanded = nodesExpanded;
518511
this.nodesExpandedBaseLayer = nodesExpandedBaseLayer;
512+
this.runtime = runtime;
519513
}
520514
}
521515
}

0 commit comments

Comments
 (0)