Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import java.io.IOException;
import java.nio.file.Path;
import java.util.List;
import java.util.Locale;

/**
Expand All @@ -35,7 +36,7 @@ record CmdLineArgs(
KnnIndexTester.IndexType indexType,
int numCandidates,
int k,
int nProbe,
int[] nProbes,
int ivfClusterSize,
int overSamplingFactor,
int hnswM,
Expand Down Expand Up @@ -86,7 +87,7 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException {
PARSER.declareString(Builder::setIndexType, INDEX_TYPE_FIELD);
PARSER.declareInt(Builder::setNumCandidates, NUM_CANDIDATES_FIELD);
PARSER.declareInt(Builder::setK, K_FIELD);
PARSER.declareInt(Builder::setNProbe, N_PROBE_FIELD);
PARSER.declareIntArray(Builder::setNProbe, N_PROBE_FIELD);
PARSER.declareInt(Builder::setIvfClusterSize, IVF_CLUSTER_SIZE_FIELD);
PARSER.declareInt(Builder::setOverSamplingFactor, OVER_SAMPLING_FACTOR_FIELD);
PARSER.declareInt(Builder::setHnswM, HNSW_M_FIELD);
Expand Down Expand Up @@ -115,7 +116,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(INDEX_TYPE_FIELD.getPreferredName(), indexType.name().toLowerCase(Locale.ROOT));
builder.field(NUM_CANDIDATES_FIELD.getPreferredName(), numCandidates);
builder.field(K_FIELD.getPreferredName(), k);
builder.field(N_PROBE_FIELD.getPreferredName(), nProbe);
builder.field(N_PROBE_FIELD.getPreferredName(), nProbes);
builder.field(IVF_CLUSTER_SIZE_FIELD.getPreferredName(), ivfClusterSize);
builder.field(OVER_SAMPLING_FACTOR_FIELD.getPreferredName(), overSamplingFactor);
builder.field(HNSW_M_FIELD.getPreferredName(), hnswM);
Expand Down Expand Up @@ -144,7 +145,7 @@ static class Builder {
private KnnIndexTester.IndexType indexType = KnnIndexTester.IndexType.HNSW;
private int numCandidates = 1000;
private int k = 10;
private int nProbe = 10;
private int[] nProbes = new int[] { 10 };
private int ivfClusterSize = 1000;
private int overSamplingFactor = 1;
private int hnswM = 16;
Expand Down Expand Up @@ -193,8 +194,8 @@ public Builder setK(int k) {
return this;
}

public Builder setNProbe(int nProbe) {
this.nProbe = nProbe;
public Builder setNProbe(List<Integer> nProbes) {
this.nProbes = nProbes.stream().mapToInt(Integer::intValue).toArray();
return this;
}

Expand Down Expand Up @@ -275,7 +276,7 @@ public CmdLineArgs build() {
indexType,
numCandidates,
k,
nProbe,
nProbes,
ivfClusterSize,
overSamplingFactor,
hnswM,
Expand Down
102 changes: 58 additions & 44 deletions qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,15 @@ public static void main(String[] args) throws Exception {
}
}
FormattedResults formattedResults = new FormattedResults();

for (CmdLineArgs cmdLineArgs : cmdLineArgsList) {
Results result = new Results(cmdLineArgs.indexType().name().toLowerCase(Locale.ROOT), cmdLineArgs.numDocs());
int[] nProbes = cmdLineArgs.indexType().equals(IndexType.IVF) && cmdLineArgs.numQueries() > 0
? cmdLineArgs.nProbes()
: new int[] { 0 };
Results[] results = new Results[nProbes.length];
for (int i = 0; i < nProbes.length; i++) {
results[i] = new Results(cmdLineArgs.indexType().name().toLowerCase(Locale.ROOT), cmdLineArgs.numDocs());
}
logger.info("Running KNN index tester with arguments: " + cmdLineArgs);
Codec codec = createCodec(cmdLineArgs);
Path indexPath = PathUtils.get(formatIndexPath(cmdLineArgs));
Expand All @@ -192,19 +199,22 @@ public static void main(String[] args) throws Exception {
throw new IllegalArgumentException("Index path does not exist: " + indexPath);
}
if (cmdLineArgs.reindex()) {
knnIndexer.createIndex(result);
knnIndexer.createIndex(results[0]);
}
if (cmdLineArgs.forceMerge()) {
knnIndexer.forceMerge(result);
knnIndexer.forceMerge(results[0]);
} else {
knnIndexer.numSegments(result);
knnIndexer.numSegments(results[0]);
}
}
if (cmdLineArgs.queryVectors() != null && cmdLineArgs.numQueries() > 0) {
KnnSearcher knnSearcher = new KnnSearcher(indexPath, cmdLineArgs);
knnSearcher.runSearch(result);
for (int i = 0; i < results.length; i++) {
int nProbe = nProbes[i];
KnnSearcher knnSearcher = new KnnSearcher(indexPath, cmdLineArgs, nProbe);
knnSearcher.runSearch(results[i]);
}
}
formattedResults.results.add(result);
formattedResults.results.addAll(List.of(results));
}
logger.info("Results: \n" + formattedResults);
}
Expand All @@ -218,13 +228,12 @@ public String toString() {
return "No results available.";
}

String[] indexingHeaders = { "index_type", "num_docs", "index_time(ms)", "force_merge_time(ms)", "num_segments" };

// Define column headers
String[] headers = {
String[] searchHeaders = {
"index_type",
"num_docs",
"index_time(ms)",
"force_merge_time(ms)",
"num_segments",
"n_probe",
"latency(ms)",
"net_cpu_time(ms)",
"avg_cpu_count",
Expand All @@ -233,41 +242,58 @@ public String toString() {
"visited" };

// Calculate appropriate column widths based on headers and data
int[] widths = calculateColumnWidths(headers);

StringBuilder sb = new StringBuilder();

// Format and append header
sb.append(formatRow(headers, widths));
sb.append("\n");
Results indexResult = results.get(0); // Assuming all results have the same index type and numDocs
String[] indexData = {
indexResult.indexType,
Integer.toString(indexResult.numDocs),
Long.toString(indexResult.indexTimeMS),
Long.toString(indexResult.forceMergeTimeMS),
Integer.toString(indexResult.numSegments) };

// Add separator line
for (int width : widths) {
sb.append("-".repeat(width)).append(" ");
}
sb.append("\n");
printBlock(sb, indexingHeaders, new String[][] { indexData });

String[][] searchData = new String[results.size()][];
// Format and append each row of data
for (Results result : results) {
String[] rowData = {
for (int i = 0; i < results.size(); i++) {
Results result = results.get(i);
searchData[i] = new String[] {
result.indexType,
Integer.toString(result.numDocs),
Long.toString(result.indexTimeMS),
Long.toString(result.forceMergeTimeMS),
Integer.toString(result.numSegments),
Integer.toString(result.nProbe),
String.format(Locale.ROOT, "%.2f", result.avgLatency),
String.format(Locale.ROOT, "%.2f", result.netCpuTimeMS),
String.format(Locale.ROOT, "%.2f", result.avgCpuCount),
String.format(Locale.ROOT, "%.2f", result.qps),
String.format(Locale.ROOT, "%.2f", result.avgRecall),
String.format(Locale.ROOT, "%.2f", result.averageVisited) };
sb.append(formatRow(rowData, widths));
sb.append("\n");

}

printBlock(sb, searchHeaders, searchData);

return sb.toString();
}

/**
 * Appends one text table to {@code sb}.
 *
 * <p>The output is, in order: a newline, the header row, a newline, a dashed
 * separator line, a newline, and then each data row followed by a newline.
 * Column widths come from {@code calculateColumnWidths(headers, rows)} and every
 * row (header included) is rendered through {@code formatRow}.
 *
 * @param sb      buffer the table is appended to
 * @param headers column titles for this table
 * @param rows    table body; each inner array is one row of cell values
 */
private void printBlock(StringBuilder sb, String[] headers, String[][] rows) {
    final int[] columnWidths = calculateColumnWidths(headers, rows);

    // Leading newline, then the header line.
    sb.append("\n").append(formatRow(headers, columnWidths)).append("\n");

    // Separator: one run of dashes per column, each followed by a single space.
    for (int col = 0; col < columnWidths.length; col++) {
        sb.append("-".repeat(columnWidths[col]));
        sb.append(" ");
    }
    sb.append("\n");

    // Body: one formatted line per data row.
    for (int r = 0; r < rows.length; r++) {
        sb.append(formatRow(rows[r], columnWidths)).append("\n");
    }
}

// Helper method to format a single row with proper column widths
private String formatRow(String[] values, int[] widths) {
StringBuilder row = new StringBuilder();
Expand All @@ -285,7 +311,7 @@ private String formatRow(String[] values, int[] widths) {
}

// Calculate appropriate column widths based on headers and data
private int[] calculateColumnWidths(String[] headers) {
private int[] calculateColumnWidths(String[] headers, String[]... data) {
int[] widths = new int[headers.length];

// Initialize widths with header lengths
Expand All @@ -294,20 +320,7 @@ private int[] calculateColumnWidths(String[] headers) {
}

// Update widths based on data
for (Results result : results) {
String[] values = {
result.indexType,
Integer.toString(result.numDocs),
Long.toString(result.indexTimeMS),
Long.toString(result.forceMergeTimeMS),
Integer.toString(result.numSegments),
String.format(Locale.ROOT, "%.2f", result.avgLatency),
String.format(Locale.ROOT, "%.2f", result.netCpuTimeMS),
String.format(Locale.ROOT, "%.2f", result.avgCpuCount),
String.format(Locale.ROOT, "%.2f", result.qps),
String.format(Locale.ROOT, "%.2f", result.avgRecall),
String.format(Locale.ROOT, "%.2f", result.averageVisited) };

for (String[] values : data) {
for (int i = 0; i < values.length; i++) {
widths[i] = Math.max(widths[i], values[i].length());
}
Expand All @@ -323,6 +336,7 @@ static class Results {
long indexTimeMS;
long forceMergeTimeMS;
int numSegments;
int nProbe;
double avgLatency;
double qps;
double avgRecall;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class KnnSearcher {
private final float overSamplingFactor;
private final int searchThreads;

KnnSearcher(Path indexPath, CmdLineArgs cmdLineArgs) {
KnnSearcher(Path indexPath, CmdLineArgs cmdLineArgs, int nProbe) {
this.docPath = cmdLineArgs.docVectors();
this.indexPath = indexPath;
this.queryPath = cmdLineArgs.queryVectors();
Expand All @@ -109,7 +109,7 @@ class KnnSearcher {
throw new IllegalArgumentException("numQueryVectors must be > 0");
}
this.efSearch = cmdLineArgs.numCandidates();
this.nProbe = cmdLineArgs.nProbe();
this.nProbe = nProbe;
this.indexType = cmdLineArgs.indexType();
this.searchThreads = cmdLineArgs.searchThreads();
}
Expand Down Expand Up @@ -206,6 +206,7 @@ void runSearch(KnnIndexTester.Results finalResults) throws IOException {
}
logger.info("checking results");
int[][] nn = getOrCalculateExactNN(offsetByteSize);
finalResults.nProbe = indexType == KnnIndexTester.IndexType.IVF ? nProbe : 0;
finalResults.avgRecall = checkResults(resultIds, nn, topK);
finalResults.qps = (1000f * numQueryVectors) / elapsed;
finalResults.avgLatency = (float) elapsed / numQueryVectors;
Expand Down