Skip to content

Commit ad5df9d

Browse files
authored
Enable caching of all filters in knn queries (elastic#134458)
* Enable caching of all filters in `knn` queries This change makes all filters in the `knn` query eligible for query caching. By default, Lucene considers some simple filters (e.g., term queries) too cheap to cache. In the context of vector search, these filters are eagerly materialized as bitsets, which makes them significantly more expensive to evaluate on every request. Forcing them to be cacheable avoids repeated recomputation. This is a stop-gap change to support simple use cases such as a single term query used as a filter in `knn`. The long-term solution is to move this decision logic into the Lucene `knn` codec itself, but that will require more time. ### Benchmark Results Dataset: **20M 128D vectors**, term filter matching \~80% of documents. **With this change:** ``` Precision QPS P50 (ms) P95 (ms) 0.91 632.8 5.763 9.900 ``` **Without this change:** ``` Precision QPS P50 (ms) P95 (ms) 0.91 68.2 82.52 193.92 ```
1 parent 7e5a5bd commit ad5df9d

File tree

6 files changed

+248
-12
lines changed

6 files changed

+248
-12
lines changed

docs/changelog/134458.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 134458
2+
summary: Enable caching of all filters in `knn` queries
3+
area: Vector Search
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1261,7 +1261,7 @@ public final int hashCode() {
12611261
*
12621262
* @return {@code true} if the vector search is flat (exhaustive), {@code false} if it uses ANN structures
12631263
*/
1264-
abstract boolean isFlat();
1264+
public abstract boolean isFlat();
12651265
}
12661266

12671267
abstract static class QuantizedIndexOptions extends DenseVectorIndexOptions {
@@ -1639,7 +1639,7 @@ int doHashCode() {
16391639
}
16401640

16411641
@Override
1642-
boolean isFlat() {
1642+
public boolean isFlat() {
16431643
return true;
16441644
}
16451645

@@ -1694,7 +1694,7 @@ public int doHashCode() {
16941694
}
16951695

16961696
@Override
1697-
boolean isFlat() {
1697+
public boolean isFlat() {
16981698
return true;
16991699
}
17001700
}
@@ -1748,7 +1748,7 @@ public int doHashCode() {
17481748
}
17491749

17501750
@Override
1751-
boolean isFlat() {
1751+
public boolean isFlat() {
17521752
return false;
17531753
}
17541754

@@ -1826,7 +1826,7 @@ public int doHashCode() {
18261826
}
18271827

18281828
@Override
1829-
boolean isFlat() {
1829+
public boolean isFlat() {
18301830
return true;
18311831
}
18321832

@@ -1900,7 +1900,7 @@ public int doHashCode() {
19001900
}
19011901

19021902
@Override
1903-
boolean isFlat() {
1903+
public boolean isFlat() {
19041904
return false;
19051905
}
19061906

@@ -1996,7 +1996,7 @@ public int doHashCode() {
19961996
}
19971997

19981998
@Override
1999-
boolean isFlat() {
1999+
public boolean isFlat() {
20002000
return false;
20012001
}
20022002

@@ -2040,7 +2040,7 @@ int doHashCode() {
20402040
}
20412041

20422042
@Override
2043-
boolean isFlat() {
2043+
public boolean isFlat() {
20442044
return false;
20452045
}
20462046

@@ -2100,7 +2100,7 @@ int doHashCode() {
21002100
}
21012101

21022102
@Override
2103-
boolean isFlat() {
2103+
public boolean isFlat() {
21042104
return true;
21052105
}
21062106

@@ -2163,7 +2163,7 @@ int doHashCode() {
21632163
}
21642164

21652165
@Override
2166-
boolean isFlat() {
2166+
public boolean isFlat() {
21672167
return false;
21682168
}
21692169

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.vectors;
11+
12+
import org.apache.lucene.search.BooleanQuery;
13+
import org.apache.lucene.search.FilterWeight;
14+
import org.apache.lucene.search.IndexSearcher;
15+
import org.apache.lucene.search.MatchAllDocsQuery;
16+
import org.apache.lucene.search.MatchNoDocsQuery;
17+
import org.apache.lucene.search.Query;
18+
import org.apache.lucene.search.QueryVisitor;
19+
import org.apache.lucene.search.ScoreMode;
20+
import org.apache.lucene.search.UsageTrackingQueryCachingPolicy;
21+
import org.apache.lucene.search.Weight;
22+
23+
import java.io.IOException;
24+
import java.util.Objects;
25+
26+
/**
27+
* A query wrapper that ensures the inner query is always eligible for query caching.
28+
* <p>
29+
* Lucene uses heuristics to determine whether a query should be cached ({@link UsageTrackingQueryCachingPolicy}),
30+
* and some queries may be skipped if they are considered too cheap or otherwise uninteresting for caching.
31+
* Wrapping a query in {@link CachingEnableFilterQuery} guarantees that it will be treated as
32+
* cacheable by the query cache.
33+
* </p>
34+
*
35+
* <p>
36+
* This wrapper does not alter the scoring or filtering semantics of the inner query.
37+
* It only changes how the query cache perceives it, by making it always considered
38+
* interesting enough to cache.
39+
* </p>
40+
*
41+
* <p>
42+
* This is particularly useful in cases where the filter is always entirely consumed,
43+
* such as filtered vector search, where the filter is transformed into a bitset eagerly.
44+
* In these scenarios, caching the filter query can significantly improve performance and avoid recomputation.
45+
* </p>
46+
*
47+
* <h2>Example usage:</h2>
48+
* <pre>{@code
49+
* Query inner = new TermQuery(new Term("field", "value"));
50+
* Query cacheable = new CacheWrapperQuery(inner);
51+
* TopDocs results = searcher.search(cacheable, 10);
52+
* }</pre>
53+
*/
54+
public class CachingEnableFilterQuery extends Query {
55+
private final Query in;
56+
57+
public CachingEnableFilterQuery(Query in) {
58+
this.in = in;
59+
}
60+
61+
@Override
62+
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
63+
var inWeight = in.createWeight(searcher, scoreMode, boost);
64+
return new FilterWeight(this, inWeight) {
65+
};
66+
}
67+
68+
@Override
69+
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
70+
var rewrite = in.rewrite(indexSearcher);
71+
if (rewrite instanceof MatchNoDocsQuery || rewrite instanceof MatchAllDocsQuery || rewrite instanceof BooleanQuery) {
72+
// If the query matches all documents, no documents, or rewrites into a compound query
73+
// that is already eligible for caching, we can safely remove this wrapper.
74+
return rewrite;
75+
}
76+
return rewrite != in ? new CachingEnableFilterQuery(rewrite) : this;
77+
}
78+
79+
@Override
80+
public String toString(String field) {
81+
return in.toString(field);
82+
}
83+
84+
@Override
85+
public void visit(QueryVisitor visitor) {
86+
in.visit(visitor);
87+
}
88+
89+
@Override
90+
public boolean equals(Object obj) {
91+
if (obj == this) {
92+
return true;
93+
}
94+
if (obj == null || obj.getClass() != getClass()) {
95+
return false;
96+
}
97+
return in.equals(((CachingEnableFilterQuery) obj).in);
98+
}
99+
100+
@Override
101+
public int hashCode() {
102+
return Objects.hash(getClass(), in.hashCode());
103+
}
104+
}

server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -645,6 +645,14 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
645645
DenseVectorFieldMapper.FilterHeuristic heuristic = context.getIndexSettings().getHnswFilterHeuristic();
646646
boolean hnswEarlyTermination = context.getIndexSettings().getHnswEarlyTermination();
647647
Float oversample = rescoreVectorBuilder() == null ? null : rescoreVectorBuilder.oversample();
648+
if (filterQuery != null && (vectorFieldType.getIndexOptions() == null || vectorFieldType.getIndexOptions().isFlat() == false)) {
649+
// Force the filter to be cacheable because it will be eagerly transformed into a bitset.
650+
// Simple filters (e.g., term queries) are normally considered too cheap to cache by the
651+
// default strategy, but once materialized as a bitset on every execution they become
652+
// significantly more expensive, making caching essential.
653+
filterQuery = new CachingEnableFilterQuery(filterQuery);
654+
}
655+
648656
return vectorFieldType.createKnnQuery(
649657
queryVector,
650658
k,

server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que
225225
}
226226
BooleanQuery booleanQuery = builder.build();
227227
Query filterQuery = booleanQuery.clauses().isEmpty() ? null : booleanQuery;
228+
Query approxFilterQuery = filterQuery != null ? new CachingEnableFilterQuery(filterQuery) : null;
228229
Integer numCands = queryBuilder.numCands();
229230
if (queryBuilder.rescoreVectorBuilder() != null && isQuantizedElementType()) {
230231
float oversample = queryBuilder.rescoreVectorBuilder().oversample();
@@ -244,15 +245,15 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que
244245
queryBuilder.queryVector().asByteVector(),
245246
k,
246247
numCands,
247-
filterQuery,
248+
approxFilterQuery,
248249
expectedStrategy
249250
);
250251
case FLOAT -> new ESKnnFloatVectorQuery(
251252
VECTOR_FIELD,
252253
queryBuilder.queryVector().asFloatVector(),
253254
k,
254255
numCands,
255-
filterQuery,
256+
approxFilterQuery,
256257
expectedStrategy
257258
);
258259
};
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.vectors;
11+
12+
import org.apache.lucene.document.Document;
13+
import org.apache.lucene.document.Field;
14+
import org.apache.lucene.index.IndexReader;
15+
import org.apache.lucene.index.Term;
16+
import org.apache.lucene.search.BooleanQuery;
17+
import org.apache.lucene.search.IndexSearcher;
18+
import org.apache.lucene.search.Query;
19+
import org.apache.lucene.search.Sort;
20+
import org.apache.lucene.search.TermQuery;
21+
import org.apache.lucene.search.UsageTrackingQueryCachingPolicy;
22+
import org.apache.lucene.store.Directory;
23+
import org.apache.lucene.tests.index.RandomIndexWriter;
24+
import org.apache.lucene.tests.search.QueryUtils;
25+
import org.elasticsearch.test.ESTestCase;
26+
27+
import java.io.IOException;
28+
29+
import static org.apache.lucene.search.BooleanClause.Occur.FILTER;
30+
import static org.hamcrest.Matchers.equalTo;
31+
import static org.hamcrest.Matchers.instanceOf;
32+
33+
public class CachingEnableFilterQueryTests extends ESTestCase {
34+
public void testEquals() {
35+
Query c1 = new CachingEnableFilterQuery(new TermQuery(new Term("foo", "bar")));
36+
Query c2 = new CachingEnableFilterQuery(new TermQuery(new Term("foo", "bar")));
37+
QueryUtils.checkEqual(c1, c2);
38+
39+
c1 = new CachingEnableFilterQuery(new TermQuery(new Term("foo", "bar")));
40+
c2 = new CachingEnableFilterQuery(new TermQuery(new Term("foo", "baz")));
41+
QueryUtils.checkUnequal(c1, c2);
42+
43+
c1 = new TermQuery(new Term("foo", "bar"));
44+
c2 = new CachingEnableFilterQuery(new TermQuery(new Term("foo", "baz")));
45+
QueryUtils.checkUnequal(c1, c2);
46+
}
47+
48+
public void testTermQuery() throws IOException {
49+
try (Directory dir = newDirectory(); RandomIndexWriter iw = new RandomIndexWriter(random(), dir)) {
50+
for (int i = 0; i < 10; i++) {
51+
Document doc = new Document();
52+
if (i % 2 == 0) {
53+
doc.add(newStringField("foo", "bar", Field.Store.YES));
54+
}
55+
iw.addDocument(doc);
56+
}
57+
58+
try (IndexReader reader = iw.getReader()) {
59+
assertTrue(reader.leaves().size() == 1 && reader.hasDeletions() == false);
60+
IndexSearcher searcher = newSearcher(reader);
61+
var termQuery = new TermQuery(new Term("foo", "bar"));
62+
Query query = new CachingEnableFilterQuery(termQuery);
63+
assertThat(searcher.rewrite(query), instanceOf(CachingEnableFilterQuery.class));
64+
assertEquals(5, searcher.count(query));
65+
66+
var cachingPolicy = new UsageTrackingQueryCachingPolicy();
67+
searcher.setQueryCachingPolicy(cachingPolicy);
68+
var rewritten = searcher.rewrite(query);
69+
for (int i = 0; i < 5; i++) {
70+
assertThat(searcher.search(rewritten, 10, Sort.INDEXORDER).totalHits.value(), equalTo(5L));
71+
}
72+
assertTrue(cachingPolicy.shouldCache(rewritten));
73+
74+
for (int i = 0; i < 10; i++) {
75+
assertThat(searcher.search(termQuery, 10, Sort.INDEXORDER).totalHits.value(), equalTo(5L));
76+
}
77+
assertFalse(cachingPolicy.shouldCache(termQuery));
78+
}
79+
}
80+
}
81+
82+
public void testBooleanQuery() throws IOException {
83+
try (Directory dir = newDirectory(); RandomIndexWriter iw = new RandomIndexWriter(random(), dir)) {
84+
for (int i = 0; i < 10; i++) {
85+
Document doc = new Document();
86+
if (i % 2 == 0) {
87+
doc.add(newStringField("f1", "bar", Field.Store.YES));
88+
}
89+
doc.add(newStringField("f2", "bar", Field.Store.YES));
90+
iw.addDocument(doc);
91+
}
92+
93+
try (IndexReader reader = iw.getReader()) {
94+
assertTrue(reader.leaves().size() == 1 && reader.hasDeletions() == false);
95+
IndexSearcher searcher = newSearcher(reader);
96+
var filter = new BooleanQuery.Builder().add(new TermQuery(new Term("f1", "bar")), FILTER)
97+
.add(new TermQuery(new Term("f2", "bar")), FILTER)
98+
.build();
99+
Query query = new CachingEnableFilterQuery(filter);
100+
assertThat(searcher.rewrite(query), instanceOf(BooleanQuery.class));
101+
assertEquals(5, searcher.count(query));
102+
103+
filter = new BooleanQuery.Builder().add(new TermQuery(new Term("f1", "bar")), FILTER).build();
104+
query = new CachingEnableFilterQuery(filter);
105+
assertThat(searcher.rewrite(query), instanceOf(CachingEnableFilterQuery.class));
106+
assertEquals(5, searcher.count(query));
107+
108+
var cachingPolicy = new UsageTrackingQueryCachingPolicy();
109+
searcher.setQueryCachingPolicy(cachingPolicy);
110+
var rewritten = searcher.rewrite(query);
111+
for (int i = 0; i < 5; i++) {
112+
assertThat(searcher.search(rewritten, 10, Sort.INDEXORDER).totalHits.value(), equalTo(5L));
113+
}
114+
assertTrue(cachingPolicy.shouldCache(rewritten));
115+
}
116+
}
117+
}
118+
}

0 commit comments

Comments
 (0)