Skip to content

Commit bf4e660

Browse files
committed
refactored and simplified, added option to knn tester
1 parent 0dcb0b2 commit bf4e660

File tree

8 files changed

+144
-132
lines changed

8 files changed

+144
-132
lines changed

docs/reference/query-languages/esql/kibana/definition/functions/knn.json

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/reference/query-languages/esql/kibana/docs/functions/knn.md

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ record CmdLineArgs(
4747
VectorSimilarityFunction vectorSpace,
4848
int quantizeBits,
4949
VectorEncoding vectorEncoding,
50-
int dimensions
50+
int dimensions,
51+
boolean earlyTermination
5152
) implements ToXContentObject {
5253

5354
static final ParseField DOC_VECTORS_FIELD = new ParseField("doc_vectors");
@@ -70,6 +71,7 @@ record CmdLineArgs(
7071
static final ParseField QUANTIZE_BITS_FIELD = new ParseField("quantize_bits");
7172
static final ParseField VECTOR_ENCODING_FIELD = new ParseField("vector_encoding");
7273
static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions");
74+
static final ParseField EARLY_TERMINATION_FIELD = new ParseField("early_termination");
7375

7476
static CmdLineArgs fromXContent(XContentParser parser) throws IOException {
7577
Builder builder = PARSER.apply(parser, null);
@@ -99,6 +101,7 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException {
99101
PARSER.declareInt(Builder::setQuantizeBits, QUANTIZE_BITS_FIELD);
100102
PARSER.declareString(Builder::setVectorEncoding, VECTOR_ENCODING_FIELD);
101103
PARSER.declareInt(Builder::setDimensions, DIMENSIONS_FIELD);
104+
PARSER.declareBoolean(Builder::setEarlyTermination, EARLY_TERMINATION_FIELD);
102105
}
103106

104107
@Override
@@ -157,6 +160,7 @@ static class Builder {
157160
private int quantizeBits = 8;
158161
private VectorEncoding vectorEncoding = VectorEncoding.FLOAT32;
159162
private int dimensions;
163+
private boolean earlyTermination;
160164

161165
public Builder setDocVectors(String docVectors) {
162166
this.docVectors = PathUtils.get(docVectors);
@@ -258,6 +262,11 @@ public Builder setDimensions(int dimensions) {
258262
return this;
259263
}
260264

265+
public Builder setEarlyTermination(Boolean patience) {
266+
this.earlyTermination = patience;
267+
return this;
268+
}
269+
261270
public CmdLineArgs build() {
262271
if (docVectors == null) {
263272
throw new IllegalArgumentException("Document vectors path must be provided");
@@ -285,7 +294,8 @@ public CmdLineArgs build() {
285294
vectorSpace,
286295
quantizeBits,
287296
vectorEncoding,
288-
dimensions
297+
dimensions,
298+
earlyTermination
289299
);
290300
}
291301
}

qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ public static void main(String[] args) throws Exception {
202202
}
203203
if (cmdLineArgs.queryVectors() != null) {
204204
KnnSearcher knnSearcher = new KnnSearcher(indexPath, cmdLineArgs);
205-
knnSearcher.runSearch(result);
205+
knnSearcher.runSearch(result, cmdLineArgs.earlyTermination());
206206
}
207207
formattedResults.results.add(result);
208208
}

qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@
3333
import org.apache.lucene.queries.function.valuesource.FloatKnnVectorFieldSource;
3434
import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction;
3535
import org.apache.lucene.search.IndexSearcher;
36+
import org.apache.lucene.search.KnnByteVectorQuery;
37+
import org.apache.lucene.search.KnnFloatVectorQuery;
38+
import org.apache.lucene.search.PatienceKnnVectorQuery;
3639
import org.apache.lucene.search.Query;
3740
import org.apache.lucene.search.ScoreDoc;
3841
import org.apache.lucene.search.TopDocs;
@@ -113,7 +116,7 @@ class KnnSearcher {
113116
this.searchThreads = cmdLineArgs.searchThreads();
114117
}
115118

116-
void runSearch(KnnIndexTester.Results finalResults) throws IOException {
119+
void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) throws IOException {
117120
TopDocs[] results = new TopDocs[numQueryVectors];
118121
int[][] resultIds = new int[numQueryVectors][];
119122
long elapsed, totalCpuTimeMS, totalVisited = 0;
@@ -139,10 +142,10 @@ void runSearch(KnnIndexTester.Results finalResults) throws IOException {
139142
for (int i = 0; i < numQueryVectors; i++) {
140143
if (vectorEncoding.equals(VectorEncoding.BYTE)) {
141144
targetReader.next(targetBytes);
142-
doVectorQuery(targetBytes, searcher);
145+
doVectorQuery(targetBytes, searcher, earlyTermination);
143146
} else {
144147
targetReader.next(target);
145-
doVectorQuery(target, searcher);
148+
doVectorQuery(target, searcher, earlyTermination);
146149
}
147150
}
148151
targetReader.reset();
@@ -151,10 +154,10 @@ void runSearch(KnnIndexTester.Results finalResults) throws IOException {
151154
for (int i = 0; i < numQueryVectors; i++) {
152155
if (vectorEncoding.equals(VectorEncoding.BYTE)) {
153156
targetReader.next(targetBytes);
154-
results[i] = doVectorQuery(targetBytes, searcher);
157+
results[i] = doVectorQuery(targetBytes, searcher, earlyTermination);
155158
} else {
156159
targetReader.next(target);
157-
results[i] = doVectorQuery(target, searcher);
160+
results[i] = doVectorQuery(target, searcher, earlyTermination);
158161
}
159162
}
160163
KnnIndexTester.ThreadDetails endThreadDetails = new KnnIndexTester.ThreadDetails();
@@ -249,7 +252,7 @@ private boolean isNewer(Path path, Path... others) throws IOException {
249252
return true;
250253
}
251254

252-
TopDocs doVectorQuery(byte[] vector, IndexSearcher searcher) throws IOException {
255+
TopDocs doVectorQuery(byte[] vector, IndexSearcher searcher, boolean earlyTermination) throws IOException {
253256
Query knnQuery;
254257
if (overSamplingFactor > 1f) {
255258
throw new IllegalArgumentException("oversampling factor > 1 is not supported for byte vectors");
@@ -265,6 +268,9 @@ TopDocs doVectorQuery(byte[] vector, IndexSearcher searcher) throws IOException
265268
null,
266269
DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy()
267270
);
271+
if (indexType == KnnIndexTester.IndexType.HNSW && earlyTermination) {
272+
knnQuery = PatienceKnnVectorQuery.fromByteQuery((KnnByteVectorQuery) knnQuery);
273+
}
268274
}
269275
QueryProfiler profiler = new QueryProfiler();
270276
TopDocs docs = searcher.search(knnQuery, this.topK);
@@ -273,7 +279,7 @@ TopDocs doVectorQuery(byte[] vector, IndexSearcher searcher) throws IOException
273279
return new TopDocs(new TotalHits(profiler.getVectorOpsCount(), docs.totalHits.relation()), docs.scoreDocs);
274280
}
275281

276-
TopDocs doVectorQuery(float[] vector, IndexSearcher searcher) throws IOException {
282+
TopDocs doVectorQuery(float[] vector, IndexSearcher searcher, boolean earlyTermination) throws IOException {
277283
Query knnQuery;
278284
int topK = this.topK;
279285
if (overSamplingFactor > 1f) {
@@ -292,6 +298,9 @@ TopDocs doVectorQuery(float[] vector, IndexSearcher searcher) throws IOException
292298
null,
293299
DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy()
294300
);
301+
if (indexType == KnnIndexTester.IndexType.HNSW && earlyTermination) {
302+
knnQuery = PatienceKnnVectorQuery.fromFloatQuery((KnnFloatVectorQuery) knnQuery);
303+
}
295304
}
296305
if (overSamplingFactor > 1f) {
297306
// oversample the topK results to get more candidates for the final result

0 commit comments

Comments
 (0)