Skip to content

Commit fbd23d5

Browse files
Perf metrics improvement v1.1 (#460)
* Improve configuration of benchmarking suite * Allow to set the precision format and other minor improvements. * Use minimal table by default * Fix javadoc * Fix javadoc * Add missing license * Enable richer setup of the printing format for each metric * Reorder the default table columns
1 parent 1c3c1cd commit fbd23d5

File tree

12 files changed

+265
-351
lines changed

12 files changed

+265
-351
lines changed

jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java

Lines changed: 6 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -376,50 +376,17 @@ private static void testConfiguration(ConfiguredSystem cs,
376376
int queryRuns = 2;
377377
System.out.format("Using %s:%n", cs.index);
378378
// 1) Select benchmarks to run
379-
var benchmarks = List.of(
380-
new ExecutionTimeBenchmark(),
381-
new CountBenchmark(),
382-
new RecallBenchmark(),
379+
List<QueryBenchmark> benchmarks = List.of(
383380
new ThroughputBenchmark(2, 0.1),
384-
new LatencyBenchmark()
381+
new LatencyBenchmark(),
382+
new CountBenchmark(),
383+
new AccuracyBenchmark()
385384
);
386385
QueryTester tester = new QueryTester(benchmarks);
387386

388387
for (var topK : topKGrid) {
389388
for (var usePruning : usePruningGrid) {
390-
// 2) Specify metrics to report. Ensure relevant QueryBenchmark is run first. Required:
391-
// - The String name for each column
392-
// - The relevant BenchmarkSummary.Summary class.
393-
// - The getter for the numerical result
394-
// - The numeric format
395-
List<Metric> metrics = List.of(
396-
Metric.of("QPS",
397-
ThroughputBenchmark.Summary.class,
398-
ThroughputBenchmark.Summary::getQueriesPerSecond,
399-
".1f"),
400-
401-
Metric.of("Avg Visited",
402-
CountBenchmark.Summary.class,
403-
CountBenchmark.Summary::getAvgNodesVisited,
404-
".1f"),
405-
406-
Metric.of("Mean Latency (ms)",
407-
LatencyBenchmark.Summary.class,
408-
LatencyBenchmark.Summary::getAverageLatency,
409-
".3f"),
410-
411-
Metric.of("p999 Latency (ms)",
412-
LatencyBenchmark.Summary.class,
413-
LatencyBenchmark.Summary::getP999Latency,
414-
".3f"),
415-
416-
Metric.of("Recall@" + topK,
417-
RecallBenchmark.Summary.class,
418-
RecallBenchmark.Summary::getAverageRecall,
419-
".2f")
420-
);
421-
422-
BenchmarkTablePrinter printer = new BenchmarkTablePrinter(metrics);
389+
BenchmarkTablePrinter printer = new BenchmarkTablePrinter();
423390
printer.printConfig(Map.of(
424391
"M", M,
425392
"efConstruction", efConstruction,
@@ -430,9 +397,7 @@ private static void testConfiguration(ConfiguredSystem cs,
430397
for (var overquery : efSearchOptions) {
431398
int rerankK = (int) (topK * overquery);
432399

433-
Map<Class<? extends BenchmarkSummary>,BenchmarkSummary> results =
434-
tester.run(cs, topK, rerankK, usePruning, queryRuns);
435-
400+
var results = tester.run(cs, topK, rerankK, usePruning, queryRuns);
436401
printer.printRow(overquery, results);
437402
}
438403
printer.printFooter();

jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/BenchmarkSummary.java renamed to jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/AbstractQueryBenchmark.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,4 @@
1616

1717
package io.github.jbellis.jvector.example.benchmarks;
1818

19-
/**
20-
* Marker interface for all benchmark summaries.
21-
*/
22-
public interface BenchmarkSummary { }
23-
19+
/**
 * Common abstract base for query benchmarks.
 * Declares no members of its own; it only fixes {@link QueryBenchmark} as the
 * contract that concrete benchmarks extending this class must fulfil.
 */
public abstract class AbstractQueryBenchmark implements QueryBenchmark {}
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
/*
2+
* Copyright DataStax, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.github.jbellis.jvector.example.benchmarks;
18+
19+
import java.util.ArrayList;
20+
import java.util.List;
21+
import java.util.stream.Collectors;
22+
import java.util.stream.IntStream;
23+
24+
import io.github.jbellis.jvector.example.Grid.ConfiguredSystem;
25+
import io.github.jbellis.jvector.example.util.AccuracyMetrics;
26+
import io.github.jbellis.jvector.graph.SearchResult;
27+
28+
/**
29+
* Measures average recall and/or the mean average precision.
30+
*/
31+
public class AccuracyBenchmark extends AbstractQueryBenchmark {
32+
// Format spec used for a metric when the caller does not supply one.
static private final String DEFAULT_FORMAT = ".2f";
33+
34+
// Whether runBenchmark emits the "Recall@k" metric.
private final boolean computeRecall;
35+
// Whether runBenchmark emits the "MAP@k" metric.
private final boolean computeMAP;
36+
// Format spec handed to Metric.of for the recall column.
private final String formatRecall;
37+
// Format spec handed to Metric.of for the MAP column.
private final String formatMAP;
38+
39+
/**
 * Creates a benchmark with explicit control over which metrics are emitted.
 *
 * @param computeRecall whether to report recall
 * @param computeMAP whether to report mean average precision
 * @param formatRecall format spec passed to {@code Metric.of} for recall
 * @param formatMAP format spec passed to {@code Metric.of} for MAP
 * @throws IllegalArgumentException if both {@code computeRecall} and {@code computeMAP} are false
 */
public AccuracyBenchmark(boolean computeRecall, boolean computeMAP, String formatRecall, String formatMAP) {
40+
if (!(computeRecall || computeMAP)) {
41+
throw new IllegalArgumentException("At least one parameter must be set to true");
42+
}
43+
this.computeRecall = computeRecall;
44+
this.computeMAP = computeMAP;
45+
this.formatRecall = formatRecall;
46+
this.formatMAP = formatMAP;
47+
}
48+
49+
/**
 * Creates a recall-only benchmark using the default format spec.
 */
public AccuracyBenchmark() {
50+
this(true, false, DEFAULT_FORMAT, DEFAULT_FORMAT);
51+
}
52+
53+
/**
 * Creates a recall-only benchmark with a caller-supplied recall format.
 * MAP stays disabled, so the default MAP format stored here is not used by runBenchmark.
 *
 * @param formatRecall format spec passed to {@code Metric.of} for recall
 */
public AccuracyBenchmark(String formatRecall) {
54+
this(true, false, formatRecall, DEFAULT_FORMAT);
55+
}
56+
57+
/**
 * Creates a benchmark that reports both recall and mean average precision.
 *
 * @param formatRecall format spec passed to {@code Metric.of} for recall
 * @param formatMAP format spec passed to {@code Metric.of} for MAP
 */
public AccuracyBenchmark(String formatRecall, String formatMAP) {
58+
this(true, true, formatRecall, formatMAP);
59+
}
60+
61+
// NOTE(review): the returned literal is "RecallBenchmark" although the class is named
// AccuracyBenchmark -- confirm whether this name is intentional before relying on it.
@Override
62+
public String getBenchmarkName() {
63+
return "RecallBenchmark";
64+
}
65+
66+
/**
 * Executes every query of the configured data set once and derives the enabled accuracy metrics.
 *
 * @param cs the configured system whose data set supplies the query vectors and ground truth
 * @param topK number of results requested per query; also used in the metric headers
 * @param rerankK forwarded unchanged to {@code QueryExecutor.executeQuery}
 * @param usePruning forwarded unchanged to {@code QueryExecutor.executeQuery}
 * @param queryRuns not read by this implementation; each query is executed a single time
 * @return the enabled metrics, recall first ("Recall@topK") and then MAP ("MAP@topK")
 */
@Override
67+
public List<Metric> runBenchmark(
68+
ConfiguredSystem cs,
69+
int topK,
70+
int rerankK,
71+
boolean usePruning,
72+
int queryRuns) {
73+
74+
int totalQueries = cs.getDataSet().queryVectors.size();
75+
76+
// execute all queries in parallel and collect results
77+
List<SearchResult> results = IntStream.range(0, totalQueries)
78+
.parallel()
79+
.mapToObj(i -> QueryExecutor.executeQuery(
80+
cs, topK, rerankK, usePruning, i))
81+
.collect(Collectors.toList());
82+
83+
var list = new ArrayList<Metric>();
84+
if (computeRecall) {
85+
// compute recall for this run; topK is passed for both trailing arguments
// (their individual meaning is defined by AccuracyMetrics -- TODO confirm)
86+
double recall = AccuracyMetrics.recallFromSearchResults(
87+
cs.getDataSet().groundTruth, results, topK, topK
88+
);
89+
list.add(Metric.of("Recall@" + topK, formatRecall, recall));
90+
}
91+
if (computeMAP) {
92+
// compute mean average precision at topK for this run
93+
double map = AccuracyMetrics.meanAveragePrecisionAtK(
94+
cs.getDataSet().groundTruth, results, topK
95+
);
96+
list.add(Metric.of("MAP@" + topK, formatMAP, map));
97+
}
98+
return list;
99+
}
100+
}
101+
102+

jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/BenchmarkTablePrinter.java

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,19 @@ public class BenchmarkTablePrinter {
3030
private static final int MIN_COLUMN_WIDTH = 11;
3131
private static final int MIN_HEADER_PADDING = 3;
3232

33-
private final List<Metric> cols;
34-
private final String headerFmt;
35-
private final String rowFmt;
36-
private boolean headerPrinted = false;
33+
private String headerFmt;
34+
private String rowFmt;
3735

38-
/**
39-
* @param cols the list of Metric definitions, in the order to print columns
40-
*/
41-
public BenchmarkTablePrinter(List<Metric> cols) {
42-
this.cols = cols;
36+
public BenchmarkTablePrinter() {
37+
headerFmt = null;
38+
rowFmt = null;
39+
}
40+
41+
42+
private void initializeHeader(List<Metric> cols) {
43+
if (headerFmt != null) {
44+
return;
45+
}
4346

4447
// Build the format strings for header & rows
4548
StringBuilder hsb = new StringBuilder();
@@ -51,9 +54,9 @@ public BenchmarkTablePrinter(List<Metric> cols) {
5154

5255
// 2) One column per Metric
5356
for (Metric m : cols) {
54-
String hdr = m.getHeader();
55-
String spec = m.getFmtSpec();
56-
int width = Math.max(MIN_COLUMN_WIDTH, hdr.length() + MIN_HEADER_PADDING);
57+
String hdr = m.getHeader();
58+
String spec = m.getFmtSpec();
59+
int width = Math.max(MIN_COLUMN_WIDTH, hdr.length() + MIN_HEADER_PADDING);
5760

5861
// Header: Always a string
5962
hsb.append(" %-").append(width).append("s");
@@ -62,7 +65,10 @@ public BenchmarkTablePrinter(List<Metric> cols) {
6265
}
6366

6467
this.headerFmt = hsb.toString();
65-
this.rowFmt = rsb.append("%n").toString();
68+
this.rowFmt = rsb.append("%n").toString();
69+
70+
System.out.println();
71+
printHeader(cols);
6672
}
6773

6874
/**
@@ -78,7 +84,7 @@ public void printConfig(Map<String, ?> params) {
7884
);
7985
}
8086

81-
private void printHeader() {
87+
private void printHeader(List<Metric> cols) {
8288
// Prepare array: First "Overquery", then each Metric header
8389
Object[] hdrs = new Object[cols.size() + 1];
8490
hdrs[0] = "Overquery";
@@ -99,21 +105,17 @@ private void printHeader() {
99105
* Print a row of data.
100106
*
101107
* @param overquery the first‐column value
102-
* @param results map from Summary.class → summary instance
108+
* @param cols list of metrics to print
103109
*/
104110
public void printRow(double overquery,
105-
Map<Class<? extends BenchmarkSummary>,BenchmarkSummary> results) {
106-
if (!headerPrinted) {
107-
System.out.println();
108-
printHeader();
109-
headerPrinted = true;
110-
}
111+
List<Metric> cols) {
112+
initializeHeader(cols);
111113

112114
// Build argument array: First overquery, then each Metric.getValue()
113115
Object[] vals = new Object[cols.size() + 1];
114116
vals[0] = overquery;
115117
for (int i = 0; i < cols.size(); i++) {
116-
vals[i + 1] = cols.get(i).extract(results);
118+
vals[i + 1] = cols.get(i).getValue();
117119
}
118120

119121
// Print the formatted row

jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/CountBenchmark.java

Lines changed: 42 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -16,54 +16,47 @@
1616

1717
package io.github.jbellis.jvector.example.benchmarks;
1818

19+
import java.util.ArrayList;
20+
import java.util.List;
1921
import java.util.concurrent.atomic.LongAdder;
2022
import java.util.stream.IntStream;
2123

2224
import io.github.jbellis.jvector.example.Grid.ConfiguredSystem;
2325
import io.github.jbellis.jvector.graph.SearchResult;
26+
import org.apache.commons.math3.analysis.function.Abs;
2427

2528
/**
2629
* Measures average node‐visit and node‐expand counts over N runs.
2730
*/
28-
public class CountBenchmark implements QueryBenchmark<CountBenchmark.Summary> {
29-
30-
/**
31-
* Holds the averaged node‐count metrics.
32-
*/
33-
public static class Summary implements BenchmarkSummary {
34-
private final double avgNodesVisited;
35-
private final double avgNodesExpanded;
36-
private final double avgNodesExpandedBaseLayer;
37-
38-
public Summary(double avgNodesVisited,
39-
double avgNodesExpanded,
40-
double avgNodesExpandedBaseLayer) {
41-
this.avgNodesVisited = avgNodesVisited;
42-
this.avgNodesExpanded = avgNodesExpanded;
43-
this.avgNodesExpandedBaseLayer = avgNodesExpandedBaseLayer;
44-
}
45-
46-
@Override
47-
public String toString() {
48-
return String.format(
49-
"CountSummary{%.2f nodes visited (AVG), %.2f nodes expanded, and %.2f nodes expanded in base layer}",
50-
avgNodesVisited,
51-
avgNodesExpanded,
52-
avgNodesExpandedBaseLayer
53-
);
54-
}
55-
56-
public double getAvgNodesVisited() {
57-
return avgNodesVisited;
31+
public class CountBenchmark extends AbstractQueryBenchmark {
32+
static private final String DEFAULT_FORMAT = ".1f";
33+
34+
private final boolean computeAvgNodesVisited;
35+
private final boolean computeAvgNodesExpanded;
36+
private final boolean computeAvgNodesExpandedBaseLayer;
37+
private final String formatAvgNodesVisited;
38+
private final String formatAvgNodesExpanded;
39+
private final String formatAvgNodesExpandedBaseLayer;
40+
41+
public CountBenchmark(boolean computeAvgNodesVisited, boolean computeAvgNodesExpanded, boolean computeAvgNodesExpandedBaseLayer,
42+
String formatAvgNodesVisited, String formatAvgNodesExpanded, String formatAvgNodesExpandedBaseLayer) {
43+
if (!(computeAvgNodesVisited || computeAvgNodesExpanded || computeAvgNodesExpandedBaseLayer)) {
44+
throw new IllegalArgumentException("At least one parameter must be set to true");
5845
}
46+
this.computeAvgNodesVisited = computeAvgNodesVisited;
47+
this.computeAvgNodesExpanded = computeAvgNodesExpanded;
48+
this.computeAvgNodesExpandedBaseLayer = computeAvgNodesExpandedBaseLayer;
49+
this.formatAvgNodesVisited = formatAvgNodesVisited;
50+
this.formatAvgNodesExpanded = formatAvgNodesExpanded;
51+
this.formatAvgNodesExpandedBaseLayer = formatAvgNodesExpandedBaseLayer;
52+
}
5953

60-
public double getAvgNodesExpanded() {
61-
return avgNodesExpanded;
62-
}
54+
public CountBenchmark() {
55+
this(true, false, false, DEFAULT_FORMAT, DEFAULT_FORMAT, DEFAULT_FORMAT);
56+
}
6357

64-
public double getAvgNodesExpandedBaseLayer() {
65-
return avgNodesExpandedBaseLayer;
66-
}
58+
public CountBenchmark(String formatAvgNodesVisited, String formatAvgNodesExpanded, String formatAvgNodesExpandedBaseLayer) {
59+
this(true, true, true, formatAvgNodesVisited, formatAvgNodesExpanded, formatAvgNodesExpandedBaseLayer);
6760
}
6861

6962
@Override
@@ -72,7 +65,7 @@ public String getBenchmarkName() {
7265
}
7366

7467
@Override
75-
public Summary runBenchmark(
68+
public List<Metric> runBenchmark(
7669
ConfiguredSystem cs,
7770
int topK,
7871
int rerankK,
@@ -100,8 +93,16 @@ public Summary runBenchmark(
10093
double avgExpanded = nodesExpanded.sum() / (double) (queryRuns * totalQueries);
10194
double avgBase = nodesExpandedBaseLayer.sum() / (double) (queryRuns * totalQueries);
10295

103-
return new Summary(avgVisited, avgExpanded, avgBase);
96+
var list = new ArrayList<Metric>();
97+
if (computeAvgNodesVisited) {
98+
list.add(Metric.of("Avg Visited", formatAvgNodesVisited, avgVisited));
99+
}
100+
if (computeAvgNodesExpanded) {
101+
list.add(Metric.of("Avg Expanded", formatAvgNodesExpanded, avgExpanded));
102+
}
103+
if (computeAvgNodesExpandedBaseLayer) {
104+
list.add(Metric.of("Avg Expanded Base Layer", formatAvgNodesExpandedBaseLayer, avgBase));
105+
}
106+
return list;
104107
}
105-
}
106-
107-
108+
}

0 commit comments

Comments
 (0)