Skip to content

Commit 7bc215a

Browse files
authored
Add nProbe to :qa:vector:checkVec and allow multiple nProbes (#130316)
This change adds the n_probe value to the output, which will be 0 in the case of non-ivf runs. In addition, it separates index and search data, so normal output looks like:

```
index_type  num_docs  index_time(ms)  force_merge_time(ms)  num_segments
----------  --------  --------------  --------------------  ------------
ivf         1000000   50382           132819                0

index_type  n_probe  latency(ms)  net_cpu_time(ms)  avg_cpu_count  QPS     recall  visited
----------  -------  -----------  ----------------  -------------  ------  ------  --------
ivf         100      3.69         0.00              0.00           271.00  0.97    58917.00
```

This change also allows defining an array of n_probe values in the configuration file, so that different values can be tested in the same run. For example, defining n_probe like:

```
"n_probe" : [10, 20, 30, 40, 50, 60, 70, 80, 90, 100],
```

will produce the following output:

```
index_type  num_docs  index_time(ms)  force_merge_time(ms)  num_segments
----------  --------  --------------  --------------------  ------------
ivf         1000000   50382           132819                0

index_type  n_probe  latency(ms)  net_cpu_time(ms)  avg_cpu_count  QPS     recall  visited
----------  -------  -----------  ----------------  -------------  ------  ------  --------
ivf         10       1.18         0.00              0.00           847.46  0.82    7244.59
ivf         20       1.36         0.00              0.00           735.29  0.89    13288.69
ivf         30       1.66         0.00              0.00           602.41  0.92    19266.67
ivf         40       1.93         0.00              0.00           518.13  0.94    24995.41
ivf         50       2.21         0.00              0.00           452.49  0.94    30739.60
ivf         60       2.51         0.00              0.00           398.41  0.95    36428.00
ivf         70       2.76         0.00              0.00           362.32  0.96    41952.59
ivf         80       2.99         0.00              0.00           334.45  0.96    47599.64
ivf         90       3.31         0.00              0.00           302.11  0.96    53254.45
ivf         100      3.69         0.00              0.00           271.00  0.97    58917.00
```

This makes it easier to plot the n_probe curve while making changes.
1 parent 5b9b033 commit 7bc215a

File tree

3 files changed

+69
-53
lines changed

3 files changed

+69
-53
lines changed

qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import java.io.IOException;
2323
import java.nio.file.Path;
24+
import java.util.List;
2425
import java.util.Locale;
2526

2627
/**
@@ -35,7 +36,7 @@ record CmdLineArgs(
3536
KnnIndexTester.IndexType indexType,
3637
int numCandidates,
3738
int k,
38-
int nProbe,
39+
int[] nProbes,
3940
int ivfClusterSize,
4041
int overSamplingFactor,
4142
int hnswM,
@@ -86,7 +87,7 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException {
8687
PARSER.declareString(Builder::setIndexType, INDEX_TYPE_FIELD);
8788
PARSER.declareInt(Builder::setNumCandidates, NUM_CANDIDATES_FIELD);
8889
PARSER.declareInt(Builder::setK, K_FIELD);
89-
PARSER.declareInt(Builder::setNProbe, N_PROBE_FIELD);
90+
PARSER.declareIntArray(Builder::setNProbe, N_PROBE_FIELD);
9091
PARSER.declareInt(Builder::setIvfClusterSize, IVF_CLUSTER_SIZE_FIELD);
9192
PARSER.declareInt(Builder::setOverSamplingFactor, OVER_SAMPLING_FACTOR_FIELD);
9293
PARSER.declareInt(Builder::setHnswM, HNSW_M_FIELD);
@@ -115,7 +116,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
115116
builder.field(INDEX_TYPE_FIELD.getPreferredName(), indexType.name().toLowerCase(Locale.ROOT));
116117
builder.field(NUM_CANDIDATES_FIELD.getPreferredName(), numCandidates);
117118
builder.field(K_FIELD.getPreferredName(), k);
118-
builder.field(N_PROBE_FIELD.getPreferredName(), nProbe);
119+
builder.field(N_PROBE_FIELD.getPreferredName(), nProbes);
119120
builder.field(IVF_CLUSTER_SIZE_FIELD.getPreferredName(), ivfClusterSize);
120121
builder.field(OVER_SAMPLING_FACTOR_FIELD.getPreferredName(), overSamplingFactor);
121122
builder.field(HNSW_M_FIELD.getPreferredName(), hnswM);
@@ -144,7 +145,7 @@ static class Builder {
144145
private KnnIndexTester.IndexType indexType = KnnIndexTester.IndexType.HNSW;
145146
private int numCandidates = 1000;
146147
private int k = 10;
147-
private int nProbe = 10;
148+
private int[] nProbes = new int[] { 10 };
148149
private int ivfClusterSize = 1000;
149150
private int overSamplingFactor = 1;
150151
private int hnswM = 16;
@@ -193,8 +194,8 @@ public Builder setK(int k) {
193194
return this;
194195
}
195196

196-
public Builder setNProbe(int nProbe) {
197-
this.nProbe = nProbe;
197+
public Builder setNProbe(List<Integer> nProbes) {
198+
this.nProbes = nProbes.stream().mapToInt(Integer::intValue).toArray();
198199
return this;
199200
}
200201

@@ -275,7 +276,7 @@ public CmdLineArgs build() {
275276
indexType,
276277
numCandidates,
277278
k,
278-
nProbe,
279+
nProbes,
279280
ivfClusterSize,
280281
overSamplingFactor,
281282
hnswM,

qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java

Lines changed: 58 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,15 @@ public static void main(String[] args) throws Exception {
172172
}
173173
}
174174
FormattedResults formattedResults = new FormattedResults();
175+
175176
for (CmdLineArgs cmdLineArgs : cmdLineArgsList) {
176-
Results result = new Results(cmdLineArgs.indexType().name().toLowerCase(Locale.ROOT), cmdLineArgs.numDocs());
177+
int[] nProbes = cmdLineArgs.indexType().equals(IndexType.IVF) && cmdLineArgs.numQueries() > 0
178+
? cmdLineArgs.nProbes()
179+
: new int[] { 0 };
180+
Results[] results = new Results[nProbes.length];
181+
for (int i = 0; i < nProbes.length; i++) {
182+
results[i] = new Results(cmdLineArgs.indexType().name().toLowerCase(Locale.ROOT), cmdLineArgs.numDocs());
183+
}
177184
logger.info("Running KNN index tester with arguments: " + cmdLineArgs);
178185
Codec codec = createCodec(cmdLineArgs);
179186
Path indexPath = PathUtils.get(formatIndexPath(cmdLineArgs));
@@ -192,19 +199,22 @@ public static void main(String[] args) throws Exception {
192199
throw new IllegalArgumentException("Index path does not exist: " + indexPath);
193200
}
194201
if (cmdLineArgs.reindex()) {
195-
knnIndexer.createIndex(result);
202+
knnIndexer.createIndex(results[0]);
196203
}
197204
if (cmdLineArgs.forceMerge()) {
198-
knnIndexer.forceMerge(result);
205+
knnIndexer.forceMerge(results[0]);
199206
} else {
200-
knnIndexer.numSegments(result);
207+
knnIndexer.numSegments(results[0]);
201208
}
202209
}
203210
if (cmdLineArgs.queryVectors() != null && cmdLineArgs.numQueries() > 0) {
204-
KnnSearcher knnSearcher = new KnnSearcher(indexPath, cmdLineArgs);
205-
knnSearcher.runSearch(result);
211+
for (int i = 0; i < results.length; i++) {
212+
int nProbe = nProbes[i];
213+
KnnSearcher knnSearcher = new KnnSearcher(indexPath, cmdLineArgs, nProbe);
214+
knnSearcher.runSearch(results[i]);
215+
}
206216
}
207-
formattedResults.results.add(result);
217+
formattedResults.results.addAll(List.of(results));
208218
}
209219
logger.info("Results: \n" + formattedResults);
210220
}
@@ -218,13 +228,12 @@ public String toString() {
218228
return "No results available.";
219229
}
220230

231+
String[] indexingHeaders = { "index_type", "num_docs", "index_time(ms)", "force_merge_time(ms)", "num_segments" };
232+
221233
// Define column headers
222-
String[] headers = {
234+
String[] searchHeaders = {
223235
"index_type",
224-
"num_docs",
225-
"index_time(ms)",
226-
"force_merge_time(ms)",
227-
"num_segments",
236+
"n_probe",
228237
"latency(ms)",
229238
"net_cpu_time(ms)",
230239
"avg_cpu_count",
@@ -233,41 +242,58 @@ public String toString() {
233242
"visited" };
234243

235244
// Calculate appropriate column widths based on headers and data
236-
int[] widths = calculateColumnWidths(headers);
237245

238246
StringBuilder sb = new StringBuilder();
239247

240-
// Format and append header
241-
sb.append(formatRow(headers, widths));
242-
sb.append("\n");
248+
Results indexResult = results.get(0); // Assuming all results have the same index type and numDocs
249+
String[] indexData = {
250+
indexResult.indexType,
251+
Integer.toString(indexResult.numDocs),
252+
Long.toString(indexResult.indexTimeMS),
253+
Long.toString(indexResult.forceMergeTimeMS),
254+
Integer.toString(indexResult.numSegments) };
243255

244-
// Add separator line
245-
for (int width : widths) {
246-
sb.append("-".repeat(width)).append(" ");
247-
}
248-
sb.append("\n");
256+
printBlock(sb, indexingHeaders, new String[][] { indexData });
249257

258+
String[][] searchData = new String[results.size()][];
250259
// Format and append each row of data
251-
for (Results result : results) {
252-
String[] rowData = {
260+
for (int i = 0; i < results.size(); i++) {
261+
Results result = results.get(i);
262+
searchData[i] = new String[] {
253263
result.indexType,
254-
Integer.toString(result.numDocs),
255-
Long.toString(result.indexTimeMS),
256-
Long.toString(result.forceMergeTimeMS),
257-
Integer.toString(result.numSegments),
264+
Integer.toString(result.nProbe),
258265
String.format(Locale.ROOT, "%.2f", result.avgLatency),
259266
String.format(Locale.ROOT, "%.2f", result.netCpuTimeMS),
260267
String.format(Locale.ROOT, "%.2f", result.avgCpuCount),
261268
String.format(Locale.ROOT, "%.2f", result.qps),
262269
String.format(Locale.ROOT, "%.2f", result.avgRecall),
263270
String.format(Locale.ROOT, "%.2f", result.averageVisited) };
264-
sb.append(formatRow(rowData, widths));
265-
sb.append("\n");
271+
266272
}
267273

274+
printBlock(sb, searchHeaders, searchData);
275+
268276
return sb.toString();
269277
}
270278

279+
private void printBlock(StringBuilder sb, String[] headers, String[][] rows) {
280+
int[] widths = calculateColumnWidths(headers, rows);
281+
sb.append("\n");
282+
sb.append(formatRow(headers, widths));
283+
sb.append("\n");
284+
285+
// Add separator line
286+
for (int width : widths) {
287+
sb.append("-".repeat(width)).append(" ");
288+
}
289+
sb.append("\n");
290+
291+
for (String[] row : rows) {
292+
sb.append(formatRow(row, widths));
293+
sb.append("\n");
294+
}
295+
}
296+
271297
// Helper method to format a single row with proper column widths
272298
private String formatRow(String[] values, int[] widths) {
273299
StringBuilder row = new StringBuilder();
@@ -285,7 +311,7 @@ private String formatRow(String[] values, int[] widths) {
285311
}
286312

287313
// Calculate appropriate column widths based on headers and data
288-
private int[] calculateColumnWidths(String[] headers) {
314+
private int[] calculateColumnWidths(String[] headers, String[]... data) {
289315
int[] widths = new int[headers.length];
290316

291317
// Initialize widths with header lengths
@@ -294,20 +320,7 @@ private int[] calculateColumnWidths(String[] headers) {
294320
}
295321

296322
// Update widths based on data
297-
for (Results result : results) {
298-
String[] values = {
299-
result.indexType,
300-
Integer.toString(result.numDocs),
301-
Long.toString(result.indexTimeMS),
302-
Long.toString(result.forceMergeTimeMS),
303-
Integer.toString(result.numSegments),
304-
String.format(Locale.ROOT, "%.2f", result.avgLatency),
305-
String.format(Locale.ROOT, "%.2f", result.netCpuTimeMS),
306-
String.format(Locale.ROOT, "%.2f", result.avgCpuCount),
307-
String.format(Locale.ROOT, "%.2f", result.qps),
308-
String.format(Locale.ROOT, "%.2f", result.avgRecall),
309-
String.format(Locale.ROOT, "%.2f", result.averageVisited) };
310-
323+
for (String[] values : data) {
311324
for (int i = 0; i < values.length; i++) {
312325
widths[i] = Math.max(widths[i], values[i].length());
313326
}
@@ -323,6 +336,7 @@ static class Results {
323336
long indexTimeMS;
324337
long forceMergeTimeMS;
325338
int numSegments;
339+
int nProbe;
326340
double avgLatency;
327341
double qps;
328342
double avgRecall;

qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ class KnnSearcher {
9494
private final float overSamplingFactor;
9595
private final int searchThreads;
9696

97-
KnnSearcher(Path indexPath, CmdLineArgs cmdLineArgs) {
97+
KnnSearcher(Path indexPath, CmdLineArgs cmdLineArgs, int nProbe) {
9898
this.docPath = cmdLineArgs.docVectors();
9999
this.indexPath = indexPath;
100100
this.queryPath = cmdLineArgs.queryVectors();
@@ -109,7 +109,7 @@ class KnnSearcher {
109109
throw new IllegalArgumentException("numQueryVectors must be > 0");
110110
}
111111
this.efSearch = cmdLineArgs.numCandidates();
112-
this.nProbe = cmdLineArgs.nProbe();
112+
this.nProbe = nProbe;
113113
this.indexType = cmdLineArgs.indexType();
114114
this.searchThreads = cmdLineArgs.searchThreads();
115115
}
@@ -206,6 +206,7 @@ void runSearch(KnnIndexTester.Results finalResults) throws IOException {
206206
}
207207
logger.info("checking results");
208208
int[][] nn = getOrCalculateExactNN(offsetByteSize);
209+
finalResults.nProbe = indexType == KnnIndexTester.IndexType.IVF ? nProbe : 0;
209210
finalResults.avgRecall = checkResults(resultIds, nn, topK);
210211
finalResults.qps = (1000f * numQueryVectors) / elapsed;
211212
finalResults.avgLatency = (float) elapsed / numQueryVectors;

0 commit comments

Comments
 (0)