Skip to content

Commit 9edfa66

Browse files
authored
Wrap ES KNN queries with PatienceKNN query (#127223)
1 parent 3153195 commit 9edfa66

File tree

14 files changed

+279
-92
lines changed

14 files changed

+279
-92
lines changed

docs/changelog/127223.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 127223
2+
summary: Wrap ES KNN queries with PatienceKNN query
3+
area: Vector Search
4+
type: feature
5+
issues: []

docs/reference/elasticsearch/index-settings/index-modules.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,3 +259,6 @@ $$$index-esql-stored-fields-sequential-proportion$$$
259259

260260
`index.esql.stored_fields_sequential_proportion`
261261
: Tuning parameter for deciding when {{esql}} will load [Stored fields](/reference/elasticsearch/rest-apis/retrieve-selected-fields.md#stored-fields) using a strategy tuned for loading dense sequence of documents. Allows values between 0.0 and 1.0 and defaults to 0.2. Indices with documents smaller than 10kb may see speed improvements loading `text` fields by setting this lower.
262+
263+
$$$index-dense-vector-hnsw-early-termination$$$ `index.dense_vector.hnsw_early_termination`
264+
: Whether to apply _patience_ based early termination strategy to knn queries over HNSW graphs (see [paper](https://cs.uwaterloo.ca/~jimmylin/publications/Teofili_Lin_ECIR2025.pdf)). This is only applicable to `dense_vector` fields with `hnsw`, `int8_hnsw`, `int4_hnsw` and `bbq_hnsw` index types. Defaults to `false`.

qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ record CmdLineArgs(
4848
VectorSimilarityFunction vectorSpace,
4949
int quantizeBits,
5050
VectorEncoding vectorEncoding,
51-
int dimensions
51+
int dimensions,
52+
boolean earlyTermination
5253
) implements ToXContentObject {
5354

5455
static final ParseField DOC_VECTORS_FIELD = new ParseField("doc_vectors");
@@ -71,6 +72,7 @@ record CmdLineArgs(
7172
static final ParseField QUANTIZE_BITS_FIELD = new ParseField("quantize_bits");
7273
static final ParseField VECTOR_ENCODING_FIELD = new ParseField("vector_encoding");
7374
static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions");
75+
static final ParseField EARLY_TERMINATION_FIELD = new ParseField("early_termination");
7476

7577
static CmdLineArgs fromXContent(XContentParser parser) throws IOException {
7678
Builder builder = PARSER.apply(parser, null);
@@ -100,6 +102,7 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException {
100102
PARSER.declareInt(Builder::setQuantizeBits, QUANTIZE_BITS_FIELD);
101103
PARSER.declareString(Builder::setVectorEncoding, VECTOR_ENCODING_FIELD);
102104
PARSER.declareInt(Builder::setDimensions, DIMENSIONS_FIELD);
105+
PARSER.declareBoolean(Builder::setEarlyTermination, EARLY_TERMINATION_FIELD);
103106
}
104107

105108
@Override
@@ -158,6 +161,7 @@ static class Builder {
158161
private int quantizeBits = 8;
159162
private VectorEncoding vectorEncoding = VectorEncoding.FLOAT32;
160163
private int dimensions;
164+
private boolean earlyTermination;
161165

162166
public Builder setDocVectors(String docVectors) {
163167
this.docVectors = PathUtils.get(docVectors);
@@ -259,6 +263,11 @@ public Builder setDimensions(int dimensions) {
259263
return this;
260264
}
261265

266+
public Builder setEarlyTermination(Boolean patience) {
267+
this.earlyTermination = patience;
268+
return this;
269+
}
270+
262271
public CmdLineArgs build() {
263272
if (docVectors == null) {
264273
throw new IllegalArgumentException("Document vectors path must be provided");
@@ -288,7 +297,8 @@ public CmdLineArgs build() {
288297
vectorSpace,
289298
quantizeBits,
290299
vectorEncoding,
291-
dimensions
300+
dimensions,
301+
earlyTermination
292302
);
293303
}
294304
}

qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ public static void main(String[] args) throws Exception {
211211
for (int i = 0; i < results.length; i++) {
212212
int nProbe = nProbes[i];
213213
KnnSearcher knnSearcher = new KnnSearcher(indexPath, cmdLineArgs, nProbe);
214-
knnSearcher.runSearch(results[i]);
214+
knnSearcher.runSearch(results[i], cmdLineArgs.earlyTermination());
215215
}
216216
}
217217
formattedResults.results.addAll(List.of(results));

qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@
3333
import org.apache.lucene.queries.function.valuesource.FloatKnnVectorFieldSource;
3434
import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction;
3535
import org.apache.lucene.search.IndexSearcher;
36+
import org.apache.lucene.search.KnnByteVectorQuery;
37+
import org.apache.lucene.search.KnnFloatVectorQuery;
38+
import org.apache.lucene.search.PatienceKnnVectorQuery;
3639
import org.apache.lucene.search.Query;
3740
import org.apache.lucene.search.ScoreDoc;
3841
import org.apache.lucene.search.TopDocs;
@@ -114,7 +117,7 @@ class KnnSearcher {
114117
this.searchThreads = cmdLineArgs.searchThreads();
115118
}
116119

117-
void runSearch(KnnIndexTester.Results finalResults) throws IOException {
120+
void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) throws IOException {
118121
TopDocs[] results = new TopDocs[numQueryVectors];
119122
int[][] resultIds = new int[numQueryVectors][];
120123
long elapsed, totalCpuTimeMS, totalVisited = 0;
@@ -153,10 +156,10 @@ void runSearch(KnnIndexTester.Results finalResults) throws IOException {
153156
for (int i = 0; i < numQueryVectors; i++) {
154157
if (vectorEncoding.equals(VectorEncoding.BYTE)) {
155158
targetReader.next(targetBytes);
156-
doVectorQuery(targetBytes, searcher);
159+
doVectorQuery(targetBytes, searcher, earlyTermination);
157160
} else {
158161
targetReader.next(target);
159-
doVectorQuery(target, searcher);
162+
doVectorQuery(target, searcher, earlyTermination);
160163
}
161164
}
162165
targetReader.reset();
@@ -165,10 +168,10 @@ void runSearch(KnnIndexTester.Results finalResults) throws IOException {
165168
for (int i = 0; i < numQueryVectors; i++) {
166169
if (vectorEncoding.equals(VectorEncoding.BYTE)) {
167170
targetReader.next(targetBytes);
168-
results[i] = doVectorQuery(targetBytes, searcher);
171+
results[i] = doVectorQuery(targetBytes, searcher, earlyTermination);
169172
} else {
170173
targetReader.next(target);
171-
results[i] = doVectorQuery(target, searcher);
174+
results[i] = doVectorQuery(target, searcher, earlyTermination);
172175
}
173176
}
174177
KnnIndexTester.ThreadDetails endThreadDetails = new KnnIndexTester.ThreadDetails();
@@ -264,7 +267,7 @@ private boolean isNewer(Path path, Path... others) throws IOException {
264267
return true;
265268
}
266269

267-
TopDocs doVectorQuery(byte[] vector, IndexSearcher searcher) throws IOException {
270+
TopDocs doVectorQuery(byte[] vector, IndexSearcher searcher, boolean earlyTermination) throws IOException {
268271
Query knnQuery;
269272
if (overSamplingFactor > 1f) {
270273
throw new IllegalArgumentException("oversampling factor > 1 is not supported for byte vectors");
@@ -280,6 +283,9 @@ TopDocs doVectorQuery(byte[] vector, IndexSearcher searcher) throws IOException
280283
null,
281284
DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy()
282285
);
286+
if (indexType == KnnIndexTester.IndexType.HNSW && earlyTermination) {
287+
knnQuery = PatienceKnnVectorQuery.fromByteQuery((KnnByteVectorQuery) knnQuery);
288+
}
283289
}
284290
QueryProfiler profiler = new QueryProfiler();
285291
TopDocs docs = searcher.search(knnQuery, this.topK);
@@ -288,7 +294,7 @@ TopDocs doVectorQuery(byte[] vector, IndexSearcher searcher) throws IOException
288294
return new TopDocs(new TotalHits(profiler.getVectorOpsCount(), docs.totalHits.relation()), docs.scoreDocs);
289295
}
290296

291-
TopDocs doVectorQuery(float[] vector, IndexSearcher searcher) throws IOException {
297+
TopDocs doVectorQuery(float[] vector, IndexSearcher searcher, boolean earlyTermination) throws IOException {
292298
Query knnQuery;
293299
int topK = this.topK;
294300
if (overSamplingFactor > 1f) {
@@ -307,16 +313,22 @@ TopDocs doVectorQuery(float[] vector, IndexSearcher searcher) throws IOException
307313
null,
308314
DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy()
309315
);
316+
if (indexType == KnnIndexTester.IndexType.HNSW && earlyTermination) {
317+
knnQuery = PatienceKnnVectorQuery.fromFloatQuery((KnnFloatVectorQuery) knnQuery);
318+
}
310319
}
311320
if (overSamplingFactor > 1f) {
312321
// oversample the topK results to get more candidates for the final result
313322
knnQuery = RescoreKnnVectorQuery.fromInnerQuery(VECTOR_FIELD, vector, similarityFunction, this.topK, topK, knnQuery);
314323
}
315324
QueryProfiler profiler = new QueryProfiler();
316325
TopDocs docs = searcher.search(knnQuery, this.topK);
317-
QueryProfilerProvider queryProfilerProvider = (QueryProfilerProvider) knnQuery;
318-
queryProfilerProvider.profile(profiler);
319-
return new TopDocs(new TotalHits(profiler.getVectorOpsCount(), docs.totalHits.relation()), docs.scoreDocs);
326+
if (knnQuery instanceof QueryProfilerProvider queryProfilerProvider) {
327+
queryProfilerProvider.profile(profiler);
328+
return new TopDocs(new TotalHits(profiler.getVectorOpsCount(), docs.totalHits.relation()), docs.scoreDocs);
329+
} else {
330+
return docs;
331+
}
320332
}
321333

322334
private static float checkResults(int[][] results, int[][] nn, int topK) {

server/src/internalClusterTest/java/org/elasticsearch/search/query/VectorIT.java

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,4 +127,55 @@ public void testFilteredQueryStrategy() {
127127
});
128128
}
129129

130+
public void testHnswEarlyTerminationQuery() {
131+
float[] vector = new float[16];
132+
randomVector(vector, 25);
133+
int upperLimit = 35;
134+
var query = new KnnSearchBuilder(VECTOR_FIELD, vector, 1, 1, null, null);
135+
assertResponse(client().prepareSearch(INDEX_NAME).setKnnSearch(List.of(query)).setSize(1).setProfile(true), response -> {
136+
assertNotEquals(0, response.getHits().getHits().length);
137+
var profileResults = response.getProfileResults();
138+
long vectorOpsSum = profileResults.values()
139+
.stream()
140+
.mapToLong(
141+
pr -> pr.getQueryPhase()
142+
.getSearchProfileDfsPhaseResult()
143+
.getQueryProfileShardResult()
144+
.stream()
145+
.mapToLong(qpr -> qpr.getVectorOperationsCount().longValue())
146+
.sum()
147+
)
148+
.sum();
149+
client().admin()
150+
.indices()
151+
.prepareUpdateSettings(INDEX_NAME)
152+
.setSettings(Settings.builder().put(DenseVectorFieldMapper.HNSW_EARLY_TERMINATION.getKey(), true))
153+
.get();
154+
assertResponse(
155+
client().prepareSearch(INDEX_NAME).setKnnSearch(List.of(query)).setSize(1).setProfile(true),
156+
earlyTerminationResponse -> {
157+
assertNotEquals(0, earlyTerminationResponse.getHits().getHits().length);
158+
var earlyTerminationResults = earlyTerminationResponse.getProfileResults();
159+
long earlyTerminationVectorOpsSum = earlyTerminationResults.values()
160+
.stream()
161+
.mapToLong(
162+
pr -> pr.getQueryPhase()
163+
.getSearchProfileDfsPhaseResult()
164+
.getQueryProfileShardResult()
165+
.stream()
166+
.mapToLong(qpr -> qpr.getVectorOperationsCount().longValue())
167+
.sum()
168+
)
169+
.sum();
170+
assertTrue(
171+
"earlyTerminationVectorOps [" + earlyTerminationVectorOpsSum + "] is not lt vectorOps [" + vectorOpsSum + "]",
172+
earlyTerminationVectorOpsSum < vectorOpsSum
173+
// if both switch to brute-force due to excessive exploration, they will both equal to upperLimit
174+
|| (earlyTerminationVectorOpsSum == vectorOpsSum && vectorOpsSum == upperLimit + 1)
175+
);
176+
}
177+
);
178+
});
179+
}
180+
130181
}

server/src/main/java/org/elasticsearch/common/settings/IndexScopedSettings.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ public final class IndexScopedSettings extends AbstractScopedSettings {
159159
IndexSettings.INDEX_TRANSLOG_RETENTION_SIZE_SETTING,
160160
IndexSettings.INDEX_SEARCH_IDLE_AFTER,
161161
DenseVectorFieldMapper.HNSW_FILTER_HEURISTIC,
162+
DenseVectorFieldMapper.HNSW_EARLY_TERMINATION,
162163
IndexFieldDataService.INDEX_FIELDDATA_CACHE_KEY,
163164
IndexSettings.IGNORE_ABOVE_SETTING,
164165
FieldMapper.IGNORE_MALFORMED_SETTING,

server/src/main/java/org/elasticsearch/index/IndexSettings.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -916,6 +916,7 @@ private void setRetentionLeaseMillis(final TimeValue retentionLease) {
916916
private volatile int maxNgramDiff;
917917
private volatile int maxShingleDiff;
918918
private volatile DenseVectorFieldMapper.FilterHeuristic hnswFilterHeuristic;
919+
private volatile boolean earlyTermination;
919920
private volatile TimeValue searchIdleAfter;
920921
private volatile int maxAnalyzedOffset;
921922
private volatile boolean weightMatchesEnabled;
@@ -1113,6 +1114,7 @@ public IndexSettings(final IndexMetadata indexMetadata, final Settings nodeSetti
11131114
skipIgnoredSourceWrite = scopedSettings.get(IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_WRITE_SETTING);
11141115
skipIgnoredSourceRead = scopedSettings.get(IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_READ_SETTING);
11151116
hnswFilterHeuristic = scopedSettings.get(DenseVectorFieldMapper.HNSW_FILTER_HEURISTIC);
1117+
earlyTermination = scopedSettings.get(DenseVectorFieldMapper.HNSW_EARLY_TERMINATION);
11161118
indexMappingSourceMode = scopedSettings.get(INDEX_MAPPER_SOURCE_MODE_SETTING);
11171119
recoverySourceEnabled = RecoverySettings.INDICES_RECOVERY_SOURCE_ENABLED_SETTING.get(nodeSettings);
11181120
recoverySourceSyntheticEnabled = DiscoveryNode.isStateless(nodeSettings) == false
@@ -1227,6 +1229,7 @@ public IndexSettings(final IndexMetadata indexMetadata, final Settings nodeSetti
12271229
);
12281230
scopedSettings.addSettingsUpdateConsumer(IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_READ_SETTING, this::setSkipIgnoredSourceRead);
12291231
scopedSettings.addSettingsUpdateConsumer(DenseVectorFieldMapper.HNSW_FILTER_HEURISTIC, this::setHnswFilterHeuristic);
1232+
scopedSettings.addSettingsUpdateConsumer(DenseVectorFieldMapper.HNSW_EARLY_TERMINATION, this::setHnswEarlyTermination);
12301233
}
12311234

12321235
private void setSearchIdleAfter(TimeValue searchIdleAfter) {
@@ -1858,6 +1861,14 @@ private void setHnswFilterHeuristic(DenseVectorFieldMapper.FilterHeuristic heuri
18581861
this.hnswFilterHeuristic = heuristic;
18591862
}
18601863

1864+
public boolean getHnswEarlyTermination() {
1865+
return this.earlyTermination;
1866+
}
1867+
1868+
private void setHnswEarlyTermination(boolean earlyTermination) {
1869+
this.earlyTermination = earlyTermination;
1870+
}
1871+
18611872
public SeqNoFieldMapper.SeqNoIndexOptions seqNoIndexOptions() {
18621873
return seqNoIndexOptions;
18631874
}

0 commit comments

Comments
 (0)