Skip to content

Commit a9ef822

Browse files
authored
Fix filtered knn vector search when query timeouts are enabled (#129440)
Turns out when we have query cancellation checks turned on, we wrap the filter bitset, meaning we cannot actually see that the inner Bits is a bitset. This is important for the hnsw knn format readers, see: https://github.com/apache/lucene/blob/1584c05b27ac31fbccb0ab328bf9f8eb6a7de414/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java#L335 Related: #126876
1 parent fc77640 commit a9ef822

File tree

4 files changed

+228
-3
lines changed

4 files changed

+228
-3
lines changed

docs/changelog/129440.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 129440
2+
summary: Fix filtered knn vector search when query timeouts are enabled
3+
area: Vector Search
4+
type: bug
5+
issues: []
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.query;
11+
12+
import org.elasticsearch.cluster.metadata.IndexMetadata;
13+
import org.elasticsearch.common.settings.Settings;
14+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
15+
import org.elasticsearch.index.query.QueryBuilders;
16+
import org.elasticsearch.search.vectors.KnnSearchBuilder;
17+
import org.elasticsearch.test.ESIntegTestCase;
18+
import org.elasticsearch.xcontent.XContentBuilder;
19+
import org.elasticsearch.xcontent.XContentFactory;
20+
import org.junit.Before;
21+
22+
import java.io.IOException;
23+
import java.util.List;
24+
25+
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse;
26+
27+
public class VectorIT extends ESIntegTestCase {
28+
29+
private static final String INDEX_NAME = "test";
30+
private static final String VECTOR_FIELD = "vector";
31+
private static final String NUM_ID_FIELD = "num_id";
32+
33+
private static void randomVector(float[] vector) {
34+
for (int i = 0; i < vector.length; i++) {
35+
vector[i] = randomFloat();
36+
}
37+
}
38+
39+
@Before
40+
public void setup() throws IOException {
41+
XContentBuilder mapping = XContentFactory.jsonBuilder()
42+
.startObject()
43+
.startObject("properties")
44+
.startObject(VECTOR_FIELD)
45+
.field("type", "dense_vector")
46+
.startObject("index_options")
47+
.field("type", "hnsw")
48+
.endObject()
49+
.endObject()
50+
.startObject(NUM_ID_FIELD)
51+
.field("type", "long")
52+
.endObject()
53+
.endObject()
54+
.endObject();
55+
56+
Settings settings = Settings.builder()
57+
.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
58+
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1)
59+
.build();
60+
prepareCreate(INDEX_NAME).setMapping(mapping).setSettings(settings).get();
61+
ensureGreen(INDEX_NAME);
62+
for (int i = 0; i < 150; i++) {
63+
float[] vector = new float[8];
64+
randomVector(vector);
65+
prepareIndex(INDEX_NAME).setId(Integer.toString(i)).setSource(VECTOR_FIELD, vector, NUM_ID_FIELD, i).get();
66+
}
67+
forceMerge(true);
68+
refresh(INDEX_NAME);
69+
}
70+
71+
public void testFilteredQueryStrategy() {
72+
float[] vector = new float[8];
73+
randomVector(vector);
74+
var query = new KnnSearchBuilder(VECTOR_FIELD, vector, 1, 1, null, null).addFilterQuery(
75+
QueryBuilders.rangeQuery(NUM_ID_FIELD).lte(30)
76+
);
77+
assertResponse(client().prepareSearch(INDEX_NAME).setKnnSearch(List.of(query)).setSize(1).setProfile(true), acornResponse -> {
78+
assertNotEquals(0, acornResponse.getHits().getHits().length);
79+
var profileResults = acornResponse.getProfileResults();
80+
long vectorOpsSum = profileResults.values()
81+
.stream()
82+
.mapToLong(
83+
pr -> pr.getQueryPhase()
84+
.getSearchProfileDfsPhaseResult()
85+
.getQueryProfileShardResult()
86+
.stream()
87+
.mapToLong(qpr -> qpr.getVectorOperationsCount().longValue())
88+
.sum()
89+
)
90+
.sum();
91+
client().admin()
92+
.indices()
93+
.prepareUpdateSettings(INDEX_NAME)
94+
.setSettings(
95+
Settings.builder()
96+
.put(
97+
DenseVectorFieldMapper.HNSW_FILTER_HEURISTIC.getKey(),
98+
DenseVectorFieldMapper.FilterHeuristic.FANOUT.toString()
99+
)
100+
)
101+
.get();
102+
assertResponse(client().prepareSearch(INDEX_NAME).setKnnSearch(List.of(query)).setSize(1).setProfile(true), fanoutResponse -> {
103+
assertNotEquals(0, fanoutResponse.getHits().getHits().length);
104+
var fanoutProfileResults = fanoutResponse.getProfileResults();
105+
long fanoutVectorOpsSum = fanoutProfileResults.values()
106+
.stream()
107+
.mapToLong(
108+
pr -> pr.getQueryPhase()
109+
.getSearchProfileDfsPhaseResult()
110+
.getQueryProfileShardResult()
111+
.stream()
112+
.mapToLong(qpr -> qpr.getVectorOperationsCount().longValue())
113+
.sum()
114+
)
115+
.sum();
116+
assertTrue(
117+
"fanoutVectorOps [" + fanoutVectorOpsSum + "] is not gt acornVectorOps [" + vectorOpsSum + "]",
118+
fanoutVectorOpsSum > vectorOpsSum
119+
);
120+
});
121+
});
122+
}
123+
124+
}

server/src/main/java/org/elasticsearch/search/internal/ExitableDirectoryReader.java

Lines changed: 95 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.apache.lucene.search.KnnCollector;
2727
import org.apache.lucene.search.VectorScorer;
2828
import org.apache.lucene.search.suggest.document.CompletionTerms;
29+
import org.apache.lucene.util.BitSet;
2930
import org.apache.lucene.util.Bits;
3031
import org.apache.lucene.util.BytesRef;
3132
import org.apache.lucene.util.automaton.CompiledAutomaton;
@@ -145,7 +146,7 @@ public void searchNearestVectors(String field, byte[] target, KnnCollector colle
145146
in.searchNearestVectors(field, target, collector, acceptDocs);
146147
return;
147148
}
148-
in.searchNearestVectors(field, target, collector, new TimeOutCheckingBits(acceptDocs));
149+
in.searchNearestVectors(field, target, collector, createTimeOutCheckingBits(acceptDocs));
149150
}
150151

151152
@Override
@@ -163,15 +164,106 @@ public void searchNearestVectors(String field, float[] target, KnnCollector coll
163164
in.searchNearestVectors(field, target, collector, acceptDocs);
164165
return;
165166
}
166-
in.searchNearestVectors(field, target, collector, new TimeOutCheckingBits(acceptDocs));
167+
in.searchNearestVectors(field, target, collector, createTimeOutCheckingBits(acceptDocs));
168+
}
169+
170+
private Bits createTimeOutCheckingBits(Bits acceptDocs) {
171+
if (acceptDocs == null || acceptDocs instanceof BitSet) {
172+
return new TimeOutCheckingBitSet((BitSet) acceptDocs);
173+
}
174+
return new TimeOutCheckingBits(acceptDocs);
175+
}
176+
177+
private class TimeOutCheckingBitSet extends BitSet {
178+
private static final int MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK = 10;
179+
private int calls;
180+
private final BitSet inner;
181+
private final int maxDoc;
182+
183+
private TimeOutCheckingBitSet(BitSet inner) {
184+
this.inner = inner;
185+
this.maxDoc = maxDoc();
186+
}
187+
188+
@Override
189+
public void set(int i) {
190+
throw new UnsupportedOperationException("not supported on TimeOutCheckingBitSet");
191+
}
192+
193+
@Override
194+
public boolean getAndSet(int i) {
195+
throw new UnsupportedOperationException("not supported on TimeOutCheckingBitSet");
196+
}
197+
198+
@Override
199+
public void clear(int i) {
200+
throw new UnsupportedOperationException("not supported on TimeOutCheckingBitSet");
201+
}
202+
203+
@Override
204+
public void clear(int startIndex, int endIndex) {
205+
throw new UnsupportedOperationException("not supported on TimeOutCheckingBitSet");
206+
}
207+
208+
@Override
209+
public int cardinality() {
210+
if (inner == null) {
211+
return maxDoc;
212+
}
213+
return inner.cardinality();
214+
}
215+
216+
@Override
217+
public int approximateCardinality() {
218+
if (inner == null) {
219+
return maxDoc;
220+
}
221+
return inner.approximateCardinality();
222+
}
223+
224+
@Override
225+
public int prevSetBit(int index) {
226+
throw new UnsupportedOperationException("not supported on TimeOutCheckingBitSet");
227+
}
228+
229+
@Override
230+
public int nextSetBit(int start, int end) {
231+
throw new UnsupportedOperationException("not supported on TimeOutCheckingBitSet");
232+
}
233+
234+
@Override
235+
public long ramBytesUsed() {
236+
throw new UnsupportedOperationException("not supported on TimeOutCheckingBitSet");
237+
}
238+
239+
@Override
240+
public boolean get(int index) {
241+
if (calls++ % MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK == 0) {
242+
queryCancellation.checkCancelled();
243+
}
244+
if (inner == null) {
245+
// if acceptDocs is null, we assume all docs are accepted
246+
return index >= 0 && index < maxDoc;
247+
}
248+
return inner.get(index);
249+
}
250+
251+
@Override
252+
public int length() {
253+
if (inner == null) {
254+
// if acceptDocs is null, we assume all docs are accepted
255+
return maxDoc;
256+
}
257+
return 0;
258+
}
167259
}
168260

169261
private class TimeOutCheckingBits implements Bits {
170262
private static final int MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK = 10;
171263
private final Bits updatedAcceptDocs;
172264
private int calls;
173265

174-
TimeOutCheckingBits(Bits acceptDocs) {
266+
private TimeOutCheckingBits(Bits acceptDocs) {
175267
// when acceptDocs is null due to no doc deleted, we will instantiate a new one that would
176268
// match all docs to allow timeout checking.
177269
this.updatedAcceptDocs = acceptDocs == null ? new Bits.MatchAllBits(maxDoc()) : acceptDocs;

server/src/main/java/org/elasticsearch/search/profile/query/QueryProfileShardResult.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,4 +137,8 @@ public int hashCode() {
137137
public String toString() {
138138
return Strings.toString(this);
139139
}
140+
141+
public Long getVectorOperationsCount() {
142+
return vectorOperationsCount;
143+
}
140144
}

0 commit comments

Comments
 (0)