Skip to content

Commit ace1dbe

Browse files
committed
Add nProbe to :qa:vector:checkVec and allow multiple nProbes
1 parent 23cd462 commit ace1dbe

File tree

3 files changed, with 69 additions (+69) and 53 deletions (-53).

qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121

2222
import java.io.IOException;
2323
import java.nio.file.Path;
24+
import java.util.List;
2425
import java.util.Locale;
2526

2627
/**
@@ -35,7 +36,7 @@ record CmdLineArgs(
3536
KnnIndexTester.IndexType indexType,
3637
int numCandidates,
3738
int k,
38-
int nProbe,
39+
int[] nProbes,
3940
int ivfClusterSize,
4041
int overSamplingFactor,
4142
int hnswM,
@@ -86,7 +87,7 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException {
8687
PARSER.declareString(Builder::setIndexType, INDEX_TYPE_FIELD);
8788
PARSER.declareInt(Builder::setNumCandidates, NUM_CANDIDATES_FIELD);
8889
PARSER.declareInt(Builder::setK, K_FIELD);
89-
PARSER.declareInt(Builder::setNProbe, N_PROBE_FIELD);
90+
PARSER.declareIntArray(Builder::setNProbe, N_PROBE_FIELD);
9091
PARSER.declareInt(Builder::setIvfClusterSize, IVF_CLUSTER_SIZE_FIELD);
9192
PARSER.declareInt(Builder::setOverSamplingFactor, OVER_SAMPLING_FACTOR_FIELD);
9293
PARSER.declareInt(Builder::setHnswM, HNSW_M_FIELD);
@@ -115,7 +116,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
115116
builder.field(INDEX_TYPE_FIELD.getPreferredName(), indexType.name().toLowerCase(Locale.ROOT));
116117
builder.field(NUM_CANDIDATES_FIELD.getPreferredName(), numCandidates);
117118
builder.field(K_FIELD.getPreferredName(), k);
118-
builder.field(N_PROBE_FIELD.getPreferredName(), nProbe);
119+
builder.field(N_PROBE_FIELD.getPreferredName(), nProbes);
119120
builder.field(IVF_CLUSTER_SIZE_FIELD.getPreferredName(), ivfClusterSize);
120121
builder.field(OVER_SAMPLING_FACTOR_FIELD.getPreferredName(), overSamplingFactor);
121122
builder.field(HNSW_M_FIELD.getPreferredName(), hnswM);
@@ -144,7 +145,7 @@ static class Builder {
144145
private KnnIndexTester.IndexType indexType = KnnIndexTester.IndexType.HNSW;
145146
private int numCandidates = 1000;
146147
private int k = 10;
147-
private int nProbe = 10;
148+
private int[] nProbes = new int[] { 10 };
148149
private int ivfClusterSize = 1000;
149150
private int overSamplingFactor = 1;
150151
private int hnswM = 16;
@@ -193,8 +194,8 @@ public Builder setK(int k) {
193194
return this;
194195
}
195196

196-
public Builder setNProbe(int nProbe) {
197-
this.nProbe = nProbe;
197+
public Builder setNProbe(List<Integer> nProbes) {
198+
this.nProbes = nProbes.stream().mapToInt(Integer::intValue).toArray();
198199
return this;
199200
}
200201

@@ -275,7 +276,7 @@ public CmdLineArgs build() {
275276
indexType,
276277
numCandidates,
277278
k,
278-
nProbe,
279+
nProbes,
279280
ivfClusterSize,
280281
overSamplingFactor,
281282
hnswM,

qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java

Lines changed: 58 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -172,8 +172,15 @@ public static void main(String[] args) throws Exception {
172172
}
173173
}
174174
FormattedResults formattedResults = new FormattedResults();
175+
175176
for (CmdLineArgs cmdLineArgs : cmdLineArgsList) {
176-
Results result = new Results(cmdLineArgs.indexType().name().toLowerCase(Locale.ROOT), cmdLineArgs.numDocs());
177+
int[] nProbes = cmdLineArgs.indexType().equals(IndexType.IVF) && cmdLineArgs.numQueries() > 0
178+
? cmdLineArgs.nProbes()
179+
: new int[] { 0 };
180+
Results[] results = new Results[nProbes.length];
181+
for (int i = 0; i < nProbes.length; i++) {
182+
results[i] = new Results(cmdLineArgs.indexType().name().toLowerCase(Locale.ROOT), cmdLineArgs.numDocs());
183+
}
177184
logger.info("Running KNN index tester with arguments: " + cmdLineArgs);
178185
Codec codec = createCodec(cmdLineArgs);
179186
Path indexPath = PathUtils.get(formatIndexPath(cmdLineArgs));
@@ -192,19 +199,22 @@ public static void main(String[] args) throws Exception {
192199
throw new IllegalArgumentException("Index path does not exist: " + indexPath);
193200
}
194201
if (cmdLineArgs.reindex()) {
195-
knnIndexer.createIndex(result);
202+
knnIndexer.createIndex(results[0]);
196203
}
197204
if (cmdLineArgs.forceMerge()) {
198-
knnIndexer.forceMerge(result);
205+
knnIndexer.forceMerge(results[0]);
199206
} else {
200-
knnIndexer.numSegments(result);
207+
knnIndexer.numSegments(results[0]);
201208
}
202209
}
203210
if (cmdLineArgs.queryVectors() != null && cmdLineArgs.numQueries() > 0) {
204-
KnnSearcher knnSearcher = new KnnSearcher(indexPath, cmdLineArgs);
205-
knnSearcher.runSearch(result);
211+
for (int i = 0; i < results.length; i++) {
212+
int nProbe = nProbes[i];
213+
KnnSearcher knnSearcher = new KnnSearcher(indexPath, cmdLineArgs, nProbe);
214+
knnSearcher.runSearch(results[i]);
215+
}
206216
}
207-
formattedResults.results.add(result);
217+
formattedResults.results.addAll(List.of(results));
208218
}
209219
logger.info("Results: \n" + formattedResults);
210220
}
@@ -218,13 +228,12 @@ public String toString() {
218228
return "No results available.";
219229
}
220230

231+
String[] indexingHeaders = { "index_type", "num_docs", "index_time(ms)", "force_merge_time(ms)", "num_segments" };
232+
221233
// Define column headers
222-
String[] headers = {
234+
String[] searchHeaders = {
223235
"index_type",
224-
"num_docs",
225-
"index_time(ms)",
226-
"force_merge_time(ms)",
227-
"num_segments",
236+
"n_probe",
228237
"latency(ms)",
229238
"net_cpu_time(ms)",
230239
"avg_cpu_count",
@@ -233,41 +242,58 @@ public String toString() {
233242
"visited" };
234243

235244
// Calculate appropriate column widths based on headers and data
236-
int[] widths = calculateColumnWidths(headers);
237245

238246
StringBuilder sb = new StringBuilder();
239247

240-
// Format and append header
241-
sb.append(formatRow(headers, widths));
242-
sb.append("\n");
248+
Results indexResult = results.get(0); // Assuming all results have the same index type and numDocs
249+
String[] indexData = {
250+
indexResult.indexType,
251+
Integer.toString(indexResult.numDocs),
252+
Long.toString(indexResult.indexTimeMS),
253+
Long.toString(indexResult.forceMergeTimeMS),
254+
Integer.toString(indexResult.numSegments) };
243255

244-
// Add separator line
245-
for (int width : widths) {
246-
sb.append("-".repeat(width)).append(" ");
247-
}
248-
sb.append("\n");
256+
printBlock(sb, indexingHeaders, new String[][] { indexData });
249257

258+
String[][] searchData = new String[results.size()][];
250259
// Format and append each row of data
251-
for (Results result : results) {
252-
String[] rowData = {
260+
for (int i = 0; i < results.size(); i++) {
261+
Results result = results.get(i);
262+
searchData[i] = new String[] {
253263
result.indexType,
254-
Integer.toString(result.numDocs),
255-
Long.toString(result.indexTimeMS),
256-
Long.toString(result.forceMergeTimeMS),
257-
Integer.toString(result.numSegments),
264+
Integer.toString(result.nProbe),
258265
String.format(Locale.ROOT, "%.2f", result.avgLatency),
259266
String.format(Locale.ROOT, "%.2f", result.netCpuTimeMS),
260267
String.format(Locale.ROOT, "%.2f", result.avgCpuCount),
261268
String.format(Locale.ROOT, "%.2f", result.qps),
262269
String.format(Locale.ROOT, "%.2f", result.avgRecall),
263270
String.format(Locale.ROOT, "%.2f", result.averageVisited) };
264-
sb.append(formatRow(rowData, widths));
265-
sb.append("\n");
271+
266272
}
267273

274+
printBlock(sb, searchHeaders, searchData);
275+
268276
return sb.toString();
269277
}
270278

279+
private void printBlock(StringBuilder sb, String[] headers, String[][] rows) {
280+
int[] widths = calculateColumnWidths(headers, rows);
281+
sb.append("\n");
282+
sb.append(formatRow(headers, widths));
283+
sb.append("\n");
284+
285+
// Add separator line
286+
for (int width : widths) {
287+
sb.append("-".repeat(width)).append(" ");
288+
}
289+
sb.append("\n");
290+
291+
for (String[] row : rows) {
292+
sb.append(formatRow(row, widths));
293+
sb.append("\n");
294+
}
295+
}
296+
271297
// Helper method to format a single row with proper column widths
272298
private String formatRow(String[] values, int[] widths) {
273299
StringBuilder row = new StringBuilder();
@@ -285,7 +311,7 @@ private String formatRow(String[] values, int[] widths) {
285311
}
286312

287313
// Calculate appropriate column widths based on headers and data
288-
private int[] calculateColumnWidths(String[] headers) {
314+
private int[] calculateColumnWidths(String[] headers, String[]... data) {
289315
int[] widths = new int[headers.length];
290316

291317
// Initialize widths with header lengths
@@ -294,20 +320,7 @@ private int[] calculateColumnWidths(String[] headers) {
294320
}
295321

296322
// Update widths based on data
297-
for (Results result : results) {
298-
String[] values = {
299-
result.indexType,
300-
Integer.toString(result.numDocs),
301-
Long.toString(result.indexTimeMS),
302-
Long.toString(result.forceMergeTimeMS),
303-
Integer.toString(result.numSegments),
304-
String.format(Locale.ROOT, "%.2f", result.avgLatency),
305-
String.format(Locale.ROOT, "%.2f", result.netCpuTimeMS),
306-
String.format(Locale.ROOT, "%.2f", result.avgCpuCount),
307-
String.format(Locale.ROOT, "%.2f", result.qps),
308-
String.format(Locale.ROOT, "%.2f", result.avgRecall),
309-
String.format(Locale.ROOT, "%.2f", result.averageVisited) };
310-
323+
for (String[] values : data) {
311324
for (int i = 0; i < values.length; i++) {
312325
widths[i] = Math.max(widths[i], values[i].length());
313326
}
@@ -323,6 +336,7 @@ static class Results {
323336
long indexTimeMS;
324337
long forceMergeTimeMS;
325338
int numSegments;
339+
int nProbe;
326340
double avgLatency;
327341
double qps;
328342
double avgRecall;

qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -94,7 +94,7 @@ class KnnSearcher {
9494
private final float overSamplingFactor;
9595
private final int searchThreads;
9696

97-
KnnSearcher(Path indexPath, CmdLineArgs cmdLineArgs) {
97+
KnnSearcher(Path indexPath, CmdLineArgs cmdLineArgs, int nProbe) {
9898
this.docPath = cmdLineArgs.docVectors();
9999
this.indexPath = indexPath;
100100
this.queryPath = cmdLineArgs.queryVectors();
@@ -109,7 +109,7 @@ class KnnSearcher {
109109
throw new IllegalArgumentException("numQueryVectors must be > 0");
110110
}
111111
this.efSearch = cmdLineArgs.numCandidates();
112-
this.nProbe = cmdLineArgs.nProbe();
112+
this.nProbe = nProbe;
113113
this.indexType = cmdLineArgs.indexType();
114114
this.searchThreads = cmdLineArgs.searchThreads();
115115
}
@@ -206,6 +206,7 @@ void runSearch(KnnIndexTester.Results finalResults) throws IOException {
206206
}
207207
logger.info("checking results");
208208
int[][] nn = getOrCalculateExactNN(offsetByteSize);
209+
finalResults.nProbe = indexType == KnnIndexTester.IndexType.IVF ? nProbe : 0;
209210
finalResults.avgRecall = checkResults(resultIds, nn, topK);
210211
finalResults.qps = (1000f * numQueryVectors) / elapsed;
211212
finalResults.avgLatency = (float) elapsed / numQueryVectors;

0 commit comments

Comments
 (0)