Skip to content

Commit b6a53a9

Browse files
authored
Add 'profile' support for knn query on HNSW with early termination (#135342)
1 parent 0bcd6ae commit b6a53a9

File tree

10 files changed

+250
-93
lines changed

10 files changed

+250
-93
lines changed

docs/changelog/135342.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 135342
2+
summary: Add 'profile' support for knn query on HNSW with early termination
3+
area: Vector Search
4+
type: enhancement
5+
issues: []

qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,6 @@
3838
import org.apache.lucene.search.ConstantScoreScorer;
3939
import org.apache.lucene.search.ConstantScoreWeight;
4040
import org.apache.lucene.search.IndexSearcher;
41-
import org.apache.lucene.search.KnnByteVectorQuery;
42-
import org.apache.lucene.search.KnnFloatVectorQuery;
43-
import org.apache.lucene.search.PatienceKnnVectorQuery;
4441
import org.apache.lucene.search.Query;
4542
import org.apache.lucene.search.QueryVisitor;
4643
import org.apache.lucene.search.ScoreDoc;
@@ -401,14 +398,13 @@ TopDocs doVectorQuery(byte[] vector, IndexSearcher searcher, Query filterQuery,
401398
topK,
402399
efSearch,
403400
filterQuery,
404-
DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy()
401+
DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy(),
402+
indexType == KnnIndexTester.IndexType.HNSW && earlyTermination
405403
);
406-
if (indexType == KnnIndexTester.IndexType.HNSW && earlyTermination) {
407-
knnQuery = PatienceKnnVectorQuery.fromByteQuery((KnnByteVectorQuery) knnQuery);
408-
}
409404
}
410405
QueryProfiler profiler = new QueryProfiler();
411406
TopDocs docs = searcher.search(knnQuery, this.topK);
407+
assert knnQuery instanceof QueryProfilerProvider : "this knnQuery doesn't support profiling";
412408
QueryProfilerProvider queryProfilerProvider = (QueryProfilerProvider) knnQuery;
413409
queryProfilerProvider.profile(profiler);
414410
return new TopDocs(new TotalHits(profiler.getVectorOpsCount(), docs.totalHits.relation()), docs.scoreDocs);
@@ -432,24 +428,20 @@ TopDocs doVectorQuery(float[] vector, IndexSearcher searcher, Query filterQuery,
432428
topK,
433429
efSearch,
434430
filterQuery,
435-
DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy()
431+
DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy(),
432+
indexType == KnnIndexTester.IndexType.HNSW && earlyTermination
436433
);
437-
if (indexType == KnnIndexTester.IndexType.HNSW && earlyTermination) {
438-
knnQuery = PatienceKnnVectorQuery.fromFloatQuery((KnnFloatVectorQuery) knnQuery);
439-
}
440434
}
441435
if (overSamplingFactor > 1f) {
442436
// oversample the topK results to get more candidates for the final result
443437
knnQuery = RescoreKnnVectorQuery.fromInnerQuery(VECTOR_FIELD, vector, similarityFunction, this.topK, topK, knnQuery);
444438
}
445439
QueryProfiler profiler = new QueryProfiler();
446440
TopDocs docs = searcher.search(knnQuery, this.topK);
447-
if (knnQuery instanceof QueryProfilerProvider queryProfilerProvider) {
448-
queryProfilerProvider.profile(profiler);
449-
return new TopDocs(new TotalHits(profiler.getVectorOpsCount(), docs.totalHits.relation()), docs.scoreDocs);
450-
} else {
451-
return docs;
452-
}
441+
assert knnQuery instanceof QueryProfilerProvider : "this knnQuery doesn't support profiling";
442+
QueryProfilerProvider queryProfilerProvider = (QueryProfilerProvider) knnQuery;
443+
queryProfilerProvider.profile(profiler);
444+
return new TopDocs(new TotalHits(profiler.getVectorOpsCount(), docs.totalHits.relation()), docs.scoreDocs);
453445
}
454446

455447
private static float checkResults(int[][] results, int[][] nn, int topK) {

server/src/internalClusterTest/java/org/elasticsearch/search/query/VectorIT.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,10 +170,8 @@ public void testHnswEarlyTerminationQuery() {
170170
)
171171
.sum();
172172
assertTrue(
173-
"earlyTerminationVectorOps [" + earlyTerminationVectorOpsSum + "] is not lt vectorOps [" + vectorOpsSum + "]",
174-
earlyTerminationVectorOpsSum < vectorOpsSum
175-
// if both switch to brute-force due to excessive exploration, they will both equal to upperLimit
176-
|| (earlyTerminationVectorOpsSum == vectorOpsSum && vectorOpsSum == upperLimit + 1)
173+
"earlyTerminationVectorOps [" + earlyTerminationVectorOpsSum + "] is not lte vectorOps [" + vectorOpsSum + "]",
174+
earlyTerminationVectorOpsSum <= vectorOpsSum
177175
);
178176
}
179177
);

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 31 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,7 @@
3333
import org.apache.lucene.search.BooleanClause;
3434
import org.apache.lucene.search.BooleanQuery;
3535
import org.apache.lucene.search.FieldExistsQuery;
36-
import org.apache.lucene.search.KnnByteVectorQuery;
37-
import org.apache.lucene.search.KnnFloatVectorQuery;
3836
import org.apache.lucene.search.MatchNoDocsQuery;
39-
import org.apache.lucene.search.PatienceKnnVectorQuery;
4037
import org.apache.lucene.search.Query;
4138
import org.apache.lucene.search.join.BitSetProducer;
4239
import org.apache.lucene.search.knn.KnnSearchStrategy;
@@ -2366,6 +2363,7 @@ public Query createKnnQuery(
23662363
return new MatchNoDocsQuery("No data has been indexed for field [" + name() + "]");
23672364
}
23682365
KnnSearchStrategy knnSearchStrategy = heuristic.getKnnSearchStrategy();
2366+
hnswEarlyTermination &= canApplyPatienceQuery();
23692367
return switch (getElementType()) {
23702368
case BYTE -> createKnnByteQuery(
23712369
queryVector.asByteVector(),
@@ -2410,6 +2408,13 @@ private boolean isQuantized() {
24102408
return indexOptions != null && indexOptions.type != null && indexOptions.type.isQuantized();
24112409
}
24122410

2411+
private boolean canApplyPatienceQuery() {
2412+
return indexOptions instanceof HnswIndexOptions
2413+
|| indexOptions instanceof Int8HnswIndexOptions
2414+
|| indexOptions instanceof Int4HnswIndexOptions
2415+
|| indexOptions instanceof BBQHnswIndexOptions;
2416+
}
2417+
24132418
private Query createKnnBitQuery(
24142419
byte[] queryVector,
24152420
int k,
@@ -2433,11 +2438,17 @@ private Query createKnnBitQuery(
24332438
.build();
24342439
} else {
24352440
knnQuery = parentFilter != null
2436-
? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy)
2437-
: new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy);
2438-
if (hnswEarlyTermination) {
2439-
knnQuery = maybeWrapPatience(knnQuery);
2440-
}
2441+
? new ESDiversifyingChildrenByteKnnVectorQuery(
2442+
name(),
2443+
queryVector,
2444+
filter,
2445+
k,
2446+
numCands,
2447+
parentFilter,
2448+
searchStrategy,
2449+
hnswEarlyTermination
2450+
)
2451+
: new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy, hnswEarlyTermination);
24412452
}
24422453
if (similarityThreshold != null) {
24432454
knnQuery = new VectorSimilarityQuery(
@@ -2477,11 +2488,17 @@ private Query createKnnByteQuery(
24772488
.build();
24782489
} else {
24792490
knnQuery = parentFilter != null
2480-
? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy)
2481-
: new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy);
2482-
if (hnswEarlyTermination) {
2483-
knnQuery = maybeWrapPatience(knnQuery);
2484-
}
2491+
? new ESDiversifyingChildrenByteKnnVectorQuery(
2492+
name(),
2493+
queryVector,
2494+
filter,
2495+
k,
2496+
numCands,
2497+
parentFilter,
2498+
searchStrategy,
2499+
hnswEarlyTermination
2500+
)
2501+
: new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy, hnswEarlyTermination);
24852502
}
24862503
if (similarityThreshold != null) {
24872504
knnQuery = new VectorSimilarityQuery(
@@ -2493,23 +2510,6 @@ private Query createKnnByteQuery(
24932510
return knnQuery;
24942511
}
24952512

2496-
private Query maybeWrapPatience(Query knnQuery) {
2497-
Query finalQuery = knnQuery;
2498-
if (knnQuery instanceof KnnByteVectorQuery knnByteVectorQuery && canApplyPatienceQuery()) {
2499-
finalQuery = PatienceKnnVectorQuery.fromByteQuery(knnByteVectorQuery);
2500-
} else if (knnQuery instanceof KnnFloatVectorQuery knnFloatVectorQuery && canApplyPatienceQuery()) {
2501-
finalQuery = PatienceKnnVectorQuery.fromFloatQuery(knnFloatVectorQuery);
2502-
}
2503-
return finalQuery;
2504-
}
2505-
2506-
private boolean canApplyPatienceQuery() {
2507-
return indexOptions instanceof HnswIndexOptions
2508-
|| indexOptions instanceof Int8HnswIndexOptions
2509-
|| indexOptions instanceof Int4HnswIndexOptions
2510-
|| indexOptions instanceof BBQHnswIndexOptions;
2511-
}
2512-
25132513
private Query createKnnFloatQuery(
25142514
float[] queryVector,
25152515
int k,
@@ -2586,10 +2586,7 @@ private Query createKnnFloatQuery(
25862586
parentFilter,
25872587
knnSearchStrategy
25882588
)
2589-
: new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter, knnSearchStrategy);
2590-
if (hnswEarlyTermination) {
2591-
knnQuery = maybeWrapPatience(knnQuery);
2592-
}
2589+
: new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter, knnSearchStrategy, hnswEarlyTermination);
25932590
}
25942591
if (rescore) {
25952592
knnQuery = RescoreKnnVectorQuery.fromInnerQuery(

server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,19 @@
99

1010
package org.elasticsearch.search.vectors;
1111

12+
import org.apache.lucene.search.IndexSearcher;
1213
import org.apache.lucene.search.Query;
1314
import org.apache.lucene.search.TopDocs;
1415
import org.apache.lucene.search.join.BitSetProducer;
1516
import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery;
17+
import org.apache.lucene.search.knn.KnnCollectorManager;
1618
import org.apache.lucene.search.knn.KnnSearchStrategy;
1719
import org.elasticsearch.search.profile.query.QueryProfiler;
1820

1921
public class ESDiversifyingChildrenByteKnnVectorQuery extends DiversifyingChildrenByteKnnVectorQuery implements QueryProfilerProvider {
2022
private final int kParam;
2123
private long vectorOpsCount;
24+
private final boolean earlyTermination;
2225

2326
public ESDiversifyingChildrenByteKnnVectorQuery(
2427
String field,
@@ -28,9 +31,23 @@ public ESDiversifyingChildrenByteKnnVectorQuery(
2831
int numCands,
2932
BitSetProducer parentsFilter,
3033
KnnSearchStrategy strategy
34+
) {
35+
this(field, query, childFilter, k, numCands, parentsFilter, strategy, false);
36+
}
37+
38+
public ESDiversifyingChildrenByteKnnVectorQuery(
39+
String field,
40+
byte[] query,
41+
Query childFilter,
42+
int k,
43+
int numCands,
44+
BitSetProducer parentsFilter,
45+
KnnSearchStrategy strategy,
46+
boolean earlyTermination
3147
) {
3248
super(field, query, childFilter, numCands, parentsFilter, strategy);
3349
this.kParam = k;
50+
this.earlyTermination = earlyTermination;
3451
}
3552

3653
@Override
@@ -48,4 +65,10 @@ public void profile(QueryProfiler queryProfiler) {
4865
public KnnSearchStrategy getStrategy() {
4966
return searchStrategy;
5067
}
68+
69+
@Override
70+
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
71+
KnnCollectorManager knnCollectorManager = super.getKnnCollectorManager(k, searcher);
72+
return earlyTermination ? PatienceCollectorManager.wrap(knnCollectorManager, k) : knnCollectorManager;
73+
}
5174
}

server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,35 @@
99

1010
package org.elasticsearch.search.vectors;
1111

12+
import org.apache.lucene.search.IndexSearcher;
1213
import org.apache.lucene.search.KnnByteVectorQuery;
1314
import org.apache.lucene.search.Query;
1415
import org.apache.lucene.search.TopDocs;
16+
import org.apache.lucene.search.knn.KnnCollectorManager;
1517
import org.apache.lucene.search.knn.KnnSearchStrategy;
1618
import org.elasticsearch.search.profile.query.QueryProfiler;
1719

1820
public class ESKnnByteVectorQuery extends KnnByteVectorQuery implements QueryProfilerProvider {
1921
private final int kParam;
2022
private long vectorOpsCount;
23+
private final boolean earlyTermination;
2124

2225
public ESKnnByteVectorQuery(String field, byte[] target, int k, int numCands, Query filter, KnnSearchStrategy strategy) {
26+
this(field, target, k, numCands, filter, strategy, false);
27+
}
28+
29+
public ESKnnByteVectorQuery(
30+
String field,
31+
byte[] target,
32+
int k,
33+
int numCands,
34+
Query filter,
35+
KnnSearchStrategy strategy,
36+
boolean earlyTermination
37+
) {
2338
super(field, target, numCands, filter, strategy);
2439
this.kParam = k;
40+
this.earlyTermination = earlyTermination;
2541
}
2642

2743
@Override
@@ -44,4 +60,10 @@ public Integer kParam() {
4460
public KnnSearchStrategy getStrategy() {
4561
return searchStrategy;
4662
}
63+
64+
@Override
65+
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
66+
KnnCollectorManager knnCollectorManager = super.getKnnCollectorManager(k, searcher);
67+
return earlyTermination ? PatienceCollectorManager.wrap(knnCollectorManager, k) : knnCollectorManager;
68+
}
4769
}

server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,35 @@
99

1010
package org.elasticsearch.search.vectors;
1111

12+
import org.apache.lucene.search.IndexSearcher;
1213
import org.apache.lucene.search.KnnFloatVectorQuery;
1314
import org.apache.lucene.search.Query;
1415
import org.apache.lucene.search.TopDocs;
16+
import org.apache.lucene.search.knn.KnnCollectorManager;
1517
import org.apache.lucene.search.knn.KnnSearchStrategy;
1618
import org.elasticsearch.search.profile.query.QueryProfiler;
1719

1820
public class ESKnnFloatVectorQuery extends KnnFloatVectorQuery implements QueryProfilerProvider {
1921
private final int kParam;
2022
private long vectorOpsCount;
23+
private final boolean earlyTermination;
2124

2225
public ESKnnFloatVectorQuery(String field, float[] target, int k, int numCands, Query filter, KnnSearchStrategy strategy) {
26+
this(field, target, k, numCands, filter, strategy, false);
27+
}
28+
29+
public ESKnnFloatVectorQuery(
30+
String field,
31+
float[] target,
32+
int k,
33+
int numCands,
34+
Query filter,
35+
KnnSearchStrategy strategy,
36+
boolean earlyTermination
37+
) {
2338
super(field, target, numCands, filter, strategy);
2439
this.kParam = k;
40+
this.earlyTermination = earlyTermination;
2541
}
2642

2743
@Override
@@ -44,4 +60,10 @@ public int kParam() {
4460
public KnnSearchStrategy getStrategy() {
4561
return searchStrategy;
4662
}
63+
64+
@Override
65+
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
66+
KnnCollectorManager knnCollectorManager = super.getKnnCollectorManager(k, searcher);
67+
return earlyTermination ? PatienceCollectorManager.wrap(knnCollectorManager, k) : knnCollectorManager;
68+
}
4769
}

0 commit comments

Comments
 (0)