From ace1dbe82348f948f4b2ebde75289c82a3d7234f Mon Sep 17 00:00:00 2001 From: Ignacio Vera Date: Mon, 30 Jun 2025 11:43:24 +0100 Subject: [PATCH] Add nProbe to `:qa:vector:checkVec` and allow multiple nProbes --- .../elasticsearch/test/knn/CmdLineArgs.java | 15 +-- .../test/knn/KnnIndexTester.java | 102 ++++++++++-------- .../elasticsearch/test/knn/KnnSearcher.java | 5 +- 3 files changed, 69 insertions(+), 53 deletions(-) diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java index 61113866c9f56..037438069dade 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.nio.file.Path; +import java.util.List; import java.util.Locale; /** @@ -35,7 +36,7 @@ record CmdLineArgs( KnnIndexTester.IndexType indexType, int numCandidates, int k, - int nProbe, + int[] nProbes, int ivfClusterSize, int overSamplingFactor, int hnswM, @@ -86,7 +87,7 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException { PARSER.declareString(Builder::setIndexType, INDEX_TYPE_FIELD); PARSER.declareInt(Builder::setNumCandidates, NUM_CANDIDATES_FIELD); PARSER.declareInt(Builder::setK, K_FIELD); - PARSER.declareInt(Builder::setNProbe, N_PROBE_FIELD); + PARSER.declareIntArray(Builder::setNProbe, N_PROBE_FIELD); PARSER.declareInt(Builder::setIvfClusterSize, IVF_CLUSTER_SIZE_FIELD); PARSER.declareInt(Builder::setOverSamplingFactor, OVER_SAMPLING_FACTOR_FIELD); PARSER.declareInt(Builder::setHnswM, HNSW_M_FIELD); @@ -115,7 +116,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(INDEX_TYPE_FIELD.getPreferredName(), indexType.name().toLowerCase(Locale.ROOT)); builder.field(NUM_CANDIDATES_FIELD.getPreferredName(), numCandidates); builder.field(K_FIELD.getPreferredName(), k); - builder.field(N_PROBE_FIELD.getPreferredName(), nProbe); + builder.field(N_PROBE_FIELD.getPreferredName(), nProbes); builder.field(IVF_CLUSTER_SIZE_FIELD.getPreferredName(), ivfClusterSize); builder.field(OVER_SAMPLING_FACTOR_FIELD.getPreferredName(), overSamplingFactor); builder.field(HNSW_M_FIELD.getPreferredName(), hnswM); @@ -144,7 +145,7 @@ static class Builder { private KnnIndexTester.IndexType indexType = KnnIndexTester.IndexType.HNSW; private int numCandidates = 1000; private int k = 10; - private int nProbe = 10; + private int[] nProbes = new int[] { 10 }; private int ivfClusterSize = 1000; private int overSamplingFactor = 1; private int hnswM = 16; @@ -193,8 +194,8 @@ public Builder setK(int k) { return this; } - public Builder setNProbe(int nProbe) { - this.nProbe = nProbe; + public Builder setNProbe(List nProbes) { + this.nProbes = nProbes.stream().mapToInt(Integer::intValue).toArray(); return this; } @@ -275,7 +276,7 @@ public CmdLineArgs build() { indexType, numCandidates, k, - nProbe, + nProbes, ivfClusterSize, overSamplingFactor, hnswM, diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java index 685f88372c709..25525fe40f92c 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java @@ -172,8 +172,15 @@ public static void main(String[] args) throws Exception { } } FormattedResults formattedResults = new FormattedResults(); + for (CmdLineArgs cmdLineArgs : cmdLineArgsList) { - Results result = new Results(cmdLineArgs.indexType().name().toLowerCase(Locale.ROOT), cmdLineArgs.numDocs()); + int[] nProbes = cmdLineArgs.indexType().equals(IndexType.IVF) && cmdLineArgs.numQueries() > 0 + ? cmdLineArgs.nProbes() + : new int[] { 0 }; + Results[] results = new Results[nProbes.length]; + for (int i = 0; i < nProbes.length; i++) { + results[i] = new Results(cmdLineArgs.indexType().name().toLowerCase(Locale.ROOT), cmdLineArgs.numDocs()); + } logger.info("Running KNN index tester with arguments: " + cmdLineArgs); Codec codec = createCodec(cmdLineArgs); Path indexPath = PathUtils.get(formatIndexPath(cmdLineArgs)); @@ -192,19 +199,22 @@ public static void main(String[] args) throws Exception { throw new IllegalArgumentException("Index path does not exist: " + indexPath); } if (cmdLineArgs.reindex()) { - knnIndexer.createIndex(result); + knnIndexer.createIndex(results[0]); } if (cmdLineArgs.forceMerge()) { - knnIndexer.forceMerge(result); + knnIndexer.forceMerge(results[0]); } else { - knnIndexer.numSegments(result); + knnIndexer.numSegments(results[0]); } } if (cmdLineArgs.queryVectors() != null && cmdLineArgs.numQueries() > 0) { - KnnSearcher knnSearcher = new KnnSearcher(indexPath, cmdLineArgs); - knnSearcher.runSearch(result); + for (int i = 0; i < results.length; i++) { + int nProbe = nProbes[i]; + KnnSearcher knnSearcher = new KnnSearcher(indexPath, cmdLineArgs, nProbe); + knnSearcher.runSearch(results[i]); + } } - formattedResults.results.add(result); + formattedResults.results.addAll(List.of(results)); } logger.info("Results: \n" + formattedResults); } @@ -218,13 +228,12 @@ public String toString() { return "No results available."; } + String[] indexingHeaders = { "index_type", "num_docs", "index_time(ms)", "force_merge_time(ms)", "num_segments" }; + // Define column headers - String[] headers = { + String[] searchHeaders = { "index_type", - "num_docs", - "index_time(ms)", - "force_merge_time(ms)", - "num_segments", + "n_probe", "latency(ms)", "net_cpu_time(ms)", "avg_cpu_count", @@ -233,41 +242,58 @@ public String toString() { "visited" }; // Calculate appropriate column widths based on headers and data - int[] widths = calculateColumnWidths(headers); StringBuilder sb = new StringBuilder(); - // Format and append header - sb.append(formatRow(headers, widths)); - sb.append("\n"); + Results indexResult = results.get(0); // Assuming all results have the same index type and numDocs + String[] indexData = { + indexResult.indexType, + Integer.toString(indexResult.numDocs), + Long.toString(indexResult.indexTimeMS), + Long.toString(indexResult.forceMergeTimeMS), + Integer.toString(indexResult.numSegments) }; - // Add separator line - for (int width : widths) { - sb.append("-".repeat(width)).append(" "); - } - sb.append("\n"); + printBlock(sb, indexingHeaders, new String[][] { indexData }); + String[][] searchData = new String[results.size()][]; // Format and append each row of data - for (Results result : results) { - String[] rowData = { + for (int i = 0; i < results.size(); i++) { + Results result = results.get(i); + searchData[i] = new String[] { result.indexType, - Integer.toString(result.numDocs), - Long.toString(result.indexTimeMS), - Long.toString(result.forceMergeTimeMS), - Integer.toString(result.numSegments), + Integer.toString(result.nProbe), String.format(Locale.ROOT, "%.2f", result.avgLatency), String.format(Locale.ROOT, "%.2f", result.netCpuTimeMS), String.format(Locale.ROOT, "%.2f", result.avgCpuCount), String.format(Locale.ROOT, "%.2f", result.qps), String.format(Locale.ROOT, "%.2f", result.avgRecall), String.format(Locale.ROOT, "%.2f", result.averageVisited) }; - sb.append(formatRow(rowData, widths)); - sb.append("\n"); + } + printBlock(sb, searchHeaders, searchData); + return sb.toString(); } + private void printBlock(StringBuilder sb, String[] headers, String[][] rows) { + int[] widths = calculateColumnWidths(headers, rows); + sb.append("\n"); + sb.append(formatRow(headers, widths)); + sb.append("\n"); + + // Add separator line + for (int width : widths) { + sb.append("-".repeat(width)).append(" "); + } + sb.append("\n"); + + for (String[] row : rows) { + sb.append(formatRow(row, widths)); + sb.append("\n"); + } + } + // Helper method to format a single row with proper column widths private String formatRow(String[] values, int[] widths) { StringBuilder row = new StringBuilder(); @@ -285,7 +311,7 @@ private String formatRow(String[] values, int[] widths) { } // Calculate appropriate column widths based on headers and data - private int[] calculateColumnWidths(String[] headers) { + private int[] calculateColumnWidths(String[] headers, String[]... data) { int[] widths = new int[headers.length]; // Initialize widths with header lengths @@ -294,20 +320,7 @@ private int[] calculateColumnWidths(String[] headers) { } // Update widths based on data - for (Results result : results) { - String[] values = { - result.indexType, - Integer.toString(result.numDocs), - Long.toString(result.indexTimeMS), - Long.toString(result.forceMergeTimeMS), - Integer.toString(result.numSegments), - String.format(Locale.ROOT, "%.2f", result.avgLatency), - String.format(Locale.ROOT, "%.2f", result.netCpuTimeMS), - String.format(Locale.ROOT, "%.2f", result.avgCpuCount), - String.format(Locale.ROOT, "%.2f", result.qps), - String.format(Locale.ROOT, "%.2f", result.avgRecall), - String.format(Locale.ROOT, "%.2f", result.averageVisited) }; - + for (String[] values : data) { for (int i = 0; i < values.length; i++) { widths[i] = Math.max(widths[i], values[i].length()); } @@ -323,6 +336,7 @@ static class Results { long indexTimeMS; long forceMergeTimeMS; int numSegments; + int nProbe; double avgLatency; double qps; double avgRecall; diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java index 7dd6f2894a20a..1f6eac89331b3 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java @@ -94,7 +94,7 @@ class KnnSearcher { private final float overSamplingFactor; private final int searchThreads; - KnnSearcher(Path indexPath, CmdLineArgs cmdLineArgs) { + KnnSearcher(Path indexPath, CmdLineArgs cmdLineArgs, int nProbe) { this.docPath = cmdLineArgs.docVectors(); this.indexPath = indexPath; this.queryPath = cmdLineArgs.queryVectors(); @@ -109,7 +109,7 @@ class KnnSearcher { throw new IllegalArgumentException("numQueryVectors must be > 0"); } this.efSearch = cmdLineArgs.numCandidates(); - this.nProbe = cmdLineArgs.nProbe(); + this.nProbe = nProbe; this.indexType = cmdLineArgs.indexType(); this.searchThreads = cmdLineArgs.searchThreads(); } @@ -206,6 +206,7 @@ void runSearch(KnnIndexTester.Results finalResults) throws IOException { } logger.info("checking results"); int[][] nn = getOrCalculateExactNN(offsetByteSize); + finalResults.nProbe = indexType == KnnIndexTester.IndexType.IVF ? nProbe : 0; finalResults.avgRecall = checkResults(resultIds, nn, topK); finalResults.qps = (1000f * numQueryVectors) / elapsed; finalResults.avgLatency = (float) elapsed / numQueryVectors;