Skip to content

Commit 6d590ad

Browse files
Enable specifying the benchmarks in the yaml file (#515)
* Enable specifying the benchmarks in the yaml file * Use the original feature set
1 parent 1c29821 commit 6d590ad

File tree

5 files changed

+141
-35
lines changed

5 files changed

+141
-35
lines changed

jvector-examples/src/main/java/io/github/jbellis/jvector/example/BenchYAML.java

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import io.github.jbellis.jvector.example.yaml.MultiConfig;
2323

2424
import java.io.IOException;
25+
import java.util.ArrayList;
2526
import java.util.Arrays;
2627
import java.util.List;
2728
import java.util.regex.Pattern;
@@ -46,6 +47,8 @@ public static void main(String[] args) throws IOException {
4647
var datasetCollection = DatasetCollection.load();
4748
var datasetNames = datasetCollection.getAll().stream().filter(dn -> pattern.matcher(dn).find()).collect(Collectors.toList());
4849

50+
List<MultiConfig> allConfigs = new ArrayList<>();
51+
4952
if (!datasetNames.isEmpty()) {
5053
System.out.println("Executing the following datasets: " + datasetNames);
5154

@@ -56,11 +59,7 @@ public static void main(String[] args) throws IOException {
5659
datasetName = datasetName.substring(0, datasetName.length() - ".hdf5".length());
5760
}
5861
MultiConfig config = MultiConfig.getDefaultConfig(datasetName);
59-
60-
Grid.runAll(ds, config.construction.outDegree, config.construction.efConstruction,
61-
config.construction.neighborOverflow, config.construction.addHierarchy, config.construction.refineFinalGraph,
62-
config.construction.getFeatureSets(), config.construction.getCompressorParameters(),
63-
config.search.getCompressorParameters(), config.search.topKOverquery, config.search.useSearchPruning);
62+
allConfigs.add(config);
6463
}
6564
}
6665

@@ -69,16 +68,20 @@ public static void main(String[] args) throws IOException {
6968

7069
if (!configNames.isEmpty()) {
7170
for (var configName : configNames) {
72-
MultiConfig config = MultiConfig.getConfig(configName);
73-
String datasetName = config.dataset;
71+
MultiConfig config = MultiConfig.getDefaultConfig(configName);
72+
allConfigs.add(config);
73+
}
74+
}
7475

75-
DataSet ds = DataSetLoader.loadDataSet(datasetName);
76+
for (var config : allConfigs) {
77+
String datasetName = config.dataset;
7678

77-
Grid.runAll(ds, config.construction.outDegree, config.construction.efConstruction,
78-
config.construction.neighborOverflow, config.construction.addHierarchy, config.construction.refineFinalGraph,
79-
config.construction.getFeatureSets(), config.construction.getCompressorParameters(),
80-
config.search.getCompressorParameters(), config.search.topKOverquery, config.search.useSearchPruning);
81-
}
79+
DataSet ds = DataSetLoader.loadDataSet(datasetName);
80+
81+
Grid.runAll(ds, config.construction.outDegree, config.construction.efConstruction,
82+
config.construction.neighborOverflow, config.construction.addHierarchy, config.construction.refineFinalGraph,
83+
config.construction.getFeatureSets(), config.construction.getCompressorParameters(),
84+
config.search.getCompressorParameters(), config.search.topKOverquery, config.search.useSearchPruning, config.search.benchmarks);
8285
}
8386
}
8487
}

jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java

Lines changed: 111 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,13 @@
1717
package io.github.jbellis.jvector.example;
1818

1919
import io.github.jbellis.jvector.disk.ReaderSupplierFactory;
20-
import io.github.jbellis.jvector.example.benchmarks.*;
21-
import io.github.jbellis.jvector.example.util.AccuracyMetrics;
20+
import io.github.jbellis.jvector.example.benchmarks.AccuracyBenchmark;
21+
import io.github.jbellis.jvector.example.benchmarks.BenchmarkTablePrinter;
22+
import io.github.jbellis.jvector.example.benchmarks.CountBenchmark;
23+
import io.github.jbellis.jvector.example.benchmarks.LatencyBenchmark;
24+
import io.github.jbellis.jvector.example.benchmarks.QueryBenchmark;
25+
import io.github.jbellis.jvector.example.benchmarks.QueryTester;
26+
import io.github.jbellis.jvector.example.benchmarks.ThroughputBenchmark;
2227
import io.github.jbellis.jvector.example.util.CompressorParameters;
2328
import io.github.jbellis.jvector.example.util.DataSet;
2429
import io.github.jbellis.jvector.example.util.FilteredForkJoinPool;
@@ -27,7 +32,6 @@
2732
import io.github.jbellis.jvector.graph.GraphSearcher;
2833
import io.github.jbellis.jvector.graph.OnHeapGraphIndex;
2934
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
30-
import io.github.jbellis.jvector.graph.SearchResult;
3135
import io.github.jbellis.jvector.graph.disk.feature.Feature;
3236
import io.github.jbellis.jvector.graph.disk.feature.FeatureId;
3337
import io.github.jbellis.jvector.graph.disk.feature.FusedADC;
@@ -45,7 +49,6 @@
4549
import io.github.jbellis.jvector.quantization.PQVectors;
4650
import io.github.jbellis.jvector.quantization.ProductQuantization;
4751
import io.github.jbellis.jvector.quantization.VectorCompressor;
48-
import io.github.jbellis.jvector.util.Bits;
4952
import io.github.jbellis.jvector.util.ExplicitThreadLocal;
5053
import io.github.jbellis.jvector.util.PhysicalCoreExecutor;
5154
import io.github.jbellis.jvector.vector.types.VectorFloat;
@@ -59,17 +62,15 @@
5962
import java.nio.file.Files;
6063
import java.nio.file.Path;
6164
import java.nio.file.Paths;
65+
import java.util.ArrayList;
6266
import java.util.EnumMap;
6367
import java.util.HashMap;
6468
import java.util.IdentityHashMap;
6569
import java.util.List;
6670
import java.util.Map;
6771
import java.util.Set;
68-
import java.util.concurrent.atomic.LongAdder;
69-
import java.util.concurrent.ForkJoinPool;
7072
import java.util.function.Function;
7173
import java.util.function.IntFunction;
72-
import java.util.stream.Collectors;
7374
import java.util.stream.IntStream;
7475

7576
/**
@@ -91,7 +92,8 @@ static void runAll(DataSet ds,
9192
List<Function<DataSet, CompressorParameters>> buildCompressors,
9293
List<Function<DataSet, CompressorParameters>> compressionGrid,
9394
Map<Integer, List<Double>> topKGrid,
94-
List<Boolean> usePruningGrid) throws IOException
95+
List<Boolean> usePruningGrid,
96+
Map<String, List<String>> benchmarks) throws IOException
9597
{
9698
var testDirectory = Files.createTempDirectory(dirPrefix);
9799
try {
@@ -102,7 +104,7 @@ static void runAll(DataSet ds,
102104
for (int efC : efConstructionGrid) {
103105
for (var bc : buildCompressors) {
104106
var compressor = getCompressor(bc, ds);
105-
runOneGraph(featureSets, M, efC, neighborOverflow, addHierarchy, refineFinalGraph, compressor, compressionGrid, topKGrid, usePruningGrid, ds, testDirectory);
107+
runOneGraph(featureSets, M, efC, neighborOverflow, addHierarchy, refineFinalGraph, compressor, compressionGrid, topKGrid, usePruningGrid, benchmarks,ds, testDirectory);
106108
}
107109
}
108110
}
@@ -122,6 +124,21 @@ static void runAll(DataSet ds,
122124
}
123125
}
124126

127+
static void runAll(DataSet ds,
128+
List<Integer> mGrid,
129+
List<Integer> efConstructionGrid,
130+
List<Float> neighborOverflowGrid,
131+
List<Boolean> addHierarchyGrid,
132+
List<Boolean> refineFinalGraphGrid,
133+
List<? extends Set<FeatureId>> featureSets,
134+
List<Function<DataSet, CompressorParameters>> buildCompressors,
135+
List<Function<DataSet, CompressorParameters>> compressionGrid,
136+
Map<Integer, List<Double>> topKGrid,
137+
List<Boolean> usePruningGrid) throws IOException
138+
{
139+
runAll(ds, mGrid, efConstructionGrid, neighborOverflowGrid, addHierarchyGrid, refineFinalGraphGrid, featureSets, buildCompressors, compressionGrid, topKGrid, usePruningGrid, null);
140+
}
141+
125142
static void runOneGraph(List<? extends Set<FeatureId>> featureSets,
126143
int M,
127144
int efConstruction,
@@ -132,6 +149,7 @@ static void runOneGraph(List<? extends Set<FeatureId>> featureSets,
132149
List<Function<DataSet, CompressorParameters>> compressionGrid,
133150
Map<Integer, List<Double>> topKGrid,
134151
List<Boolean> usePruningGrid,
152+
Map<String, List<String>> benchmarks,
135153
DataSet ds,
136154
Path testDirectory) throws IOException
137155
{
@@ -158,7 +176,7 @@ static void runOneGraph(List<? extends Set<FeatureId>> featureSets,
158176
indexes.forEach((features, index) -> {
159177
try (var cs = new ConfiguredSystem(ds, index, cv,
160178
index instanceof OnDiskGraphIndex ? ((OnDiskGraphIndex) index).getFeatureSet() : Set.of())) {
161-
testConfiguration(cs, topKGrid, usePruningGrid, M, efConstruction, neighborOverflow, addHierarchy);
179+
testConfiguration(cs, topKGrid, usePruningGrid, M, efConstruction, neighborOverflow, addHierarchy, benchmarks);
162180
} catch (Exception e) {
163181
throw new RuntimeException(e);
164182
}
@@ -379,17 +397,13 @@ private static void testConfiguration(ConfiguredSystem cs,
379397
int M,
380398
int efConstruction,
381399
float neighborOverflow,
382-
boolean addHierarchy) {
400+
boolean addHierarchy,
401+
Map<String, List<String>> benchmarkSpec) {
383402
int queryRuns = 2;
384403
System.out.format("Using %s:%n", cs.index);
385404
// 1) Select benchmarks to run. Use .createDefault or .createEmpty (for other options)
386-
List<QueryBenchmark> benchmarks = List.of(
387-
ThroughputBenchmark.createEmpty(3, 3)
388-
.displayAvgQps(),
389-
LatencyBenchmark.createDefault(),
390-
CountBenchmark.createDefault(),
391-
AccuracyBenchmark.createDefault()
392-
);
405+
406+
var benchmarks = setupBenchmarks(benchmarkSpec);
393407
QueryTester tester = new QueryTester(benchmarks);
394408

395409
// 2) Setup benchmark table for printing
@@ -414,6 +428,85 @@ private static void testConfiguration(ConfiguredSystem cs,
414428
}
415429
}
416430

431+
private static List<QueryBenchmark> setupBenchmarks(Map<String, List<String>> benchmarkSpec) {
432+
if (benchmarkSpec == null || benchmarkSpec.isEmpty()) {
433+
return List.of(
434+
ThroughputBenchmark.createEmpty(3, 3)
435+
.displayAvgQps(),
436+
LatencyBenchmark.createDefault(),
437+
CountBenchmark.createDefault(),
438+
AccuracyBenchmark.createDefault()
439+
);
440+
}
441+
442+
List<QueryBenchmark> benchmarks = new ArrayList<>();
443+
444+
for (var benchType : benchmarkSpec.keySet()) {
445+
if (benchType.equals("throughput")) {
446+
var bench = ThroughputBenchmark.createEmpty(3, 3);
447+
for (var stat : benchmarkSpec.get(benchType)) {
448+
if (stat.equals("AVG")) {
449+
bench = bench.displayAvgQps();
450+
}
451+
if (stat.equals("MEDIAN")) {
452+
bench = bench.displayMedianQps();
453+
}
454+
if (stat.equals("MAX")) {
455+
bench = bench.displayMaxQps();
456+
}
457+
}
458+
benchmarks.add(bench);
459+
}
460+
461+
if (benchType.equals("latency")) {
462+
var bench = LatencyBenchmark.createEmpty();
463+
for (var stat : benchmarkSpec.get(benchType)) {
464+
if (stat.equals("AVG")) {
465+
bench = bench.displayAvgLatency();
466+
}
467+
if (stat.equals("STD")) {
468+
bench = bench.displayLatencySTD();
469+
}
470+
if (stat.equals("P999")) {
471+
bench = bench.displayP999Latency();
472+
}
473+
}
474+
benchmarks.add(bench);
475+
}
476+
477+
if (benchType.equals("count")) {
478+
var bench = CountBenchmark.createEmpty();
479+
for (var stat : benchmarkSpec.get(benchType)) {
480+
if (stat.equals("visited")) {
481+
bench = bench.displayAvgNodesVisited();
482+
}
483+
if (stat.equals("expanded")) {
484+
bench = bench.displayAvgNodesExpanded();
485+
}
486+
if (stat.equals("expanded base layer")) {
487+
bench = bench.displayAvgNodesExpandedBaseLayer();
488+
}
489+
}
490+
benchmarks.add(bench);
491+
}
492+
493+
if (benchType.equals("accuracy")) {
494+
var bench = AccuracyBenchmark.createEmpty();
495+
for (var stat : benchmarkSpec.get(benchType)) {
496+
if (stat.equals("recall")) {
497+
bench = bench.displayRecall();
498+
}
499+
if (stat.equals("MAP")) {
500+
bench = bench.displayMAP();
501+
}
502+
}
503+
benchmarks.add(bench);
504+
}
505+
}
506+
507+
return benchmarks;
508+
}
509+
417510
private static VectorCompressor<?> getCompressor(Function<DataSet, CompressorParameters> cpSupplier, DataSet ds) {
418511
var cp = cpSupplier.apply(ds);
419512
if (!cp.supportsCaching()) {

jvector-examples/src/main/java/io/github/jbellis/jvector/example/yaml/MultiConfig.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,11 @@ public class MultiConfig {
3333
public SearchParameters search;
3434

3535
public static MultiConfig getDefaultConfig(String datasetName) throws FileNotFoundException {
36-
File configFile = new File(defaultDirectory + datasetName + ".yml");
36+
var name = defaultDirectory + datasetName;
37+
if (!name.endsWith(".yml")) {
38+
name += ".yml";
39+
}
40+
File configFile = new File(name);
3741
boolean useDefault = !configFile.exists();
3842
if (useDefault) {
3943
configFile = new File(defaultDirectory + "default.yml");
@@ -46,8 +50,8 @@ public static MultiConfig getDefaultConfig(String datasetName) throws FileNotFou
4650
return config;
4751
}
4852

49-
public static MultiConfig getConfig(String datasetName) throws FileNotFoundException {
50-
File configFile = new File(datasetName);
53+
public static MultiConfig getConfig(String configName) throws FileNotFoundException {
54+
File configFile = new File(configName);
5155
return getConfig(configFile);
5256
}
5357

jvector-examples/src/main/java/io/github/jbellis/jvector/example/yaml/SearchParameters.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,5 @@
2222
/**
 * Search-related settings of a benchmark configuration (the {@code search} field of MultiConfig).
 *
 * Fields are public and non-final; presumably they are populated directly by the YAML
 * loader -- TODO confirm against MultiConfig.getConfig(File).
 *
 * - topKOverquery:    handed to Grid.runAll as its topKGrid argument.
 * - useSearchPruning: handed to Grid.runAll as its usePruningGrid argument.
 * - benchmarks:       handed to Grid.runAll as its benchmarks argument; maps a benchmark type
 *                     ("throughput", "latency", "count", "accuracy") to the statistics to
 *                     display. When null or empty, Grid falls back to its default benchmark list.
 */
public class SearchParameters extends CommonParameters {
2323
public Map<Integer, List<Double>> topKOverquery;
2424
public List<Boolean> useSearchPruning;
25+
public Map<String, List<String>> benchmarks;
2526
}

jvector-examples/yaml-configs/default.yml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,9 @@ search:
3131
m: 192
3232
# k: 256 # optional parameter. By default, k=256
3333
centerData: No
34-
anisotropicThreshold: -1.0 # optional parameter. By default, anisotropicThreshold=-1 (i.e., no anisotropy)
34+
anisotropicThreshold: -1.0 # optional parameter. By default, anisotropicThreshold=-1 (i.e., no anisotropy)
35+
benchmarks: # optional section: the available options are listed on each line below; if the whole "benchmarks" section is omitted, a default set of benchmarks is used
36+
throughput: [ AVG ] # available options: [ AVG, MEDIAN, MAX ]
37+
latency: [ AVG ] # available options: [ AVG, STD, P999 ]
38+
count: [ visited ] # available options: [ visited, expanded, expanded base layer ]
39+
accuracy: [ recall ] # available options: [ recall, MAP ]

0 commit comments

Comments
 (0)