Skip to content

Commit 644bbc5

Browse files
Merge branch 'main' into fix/es-11525
2 parents f8627ba + e437163 commit 644bbc5

File tree

24 files changed

+1691
-53
lines changed

24 files changed

+1691
-53
lines changed

docs/changelog/129200.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 129200
2+
summary: Simplified Linear Retriever
3+
area: Search
4+
type: enhancement
5+
issues: []

docs/changelog/129440.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 129440
2+
summary: Fix filtered knn vector search when query timeouts are enabled
3+
area: Vector Search
4+
type: bug
5+
issues: []

muted-tests.yml

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -271,9 +271,6 @@ tests:
271271
- class: org.elasticsearch.xpack.test.rest.XPackRestIT
272272
method: test {p0=transform/transforms_start_stop/Test schedule_now on an already started transform}
273273
issue: https://github.com/elastic/elasticsearch/issues/120720
274-
- class: org.elasticsearch.index.engine.ThreadPoolMergeExecutorServiceTests
275-
method: testIORateIsAdjustedForRunningMergeTasks
276-
issue: https://github.com/elastic/elasticsearch/issues/125842
277274
- class: org.elasticsearch.xpack.test.rest.XPackRestIT
278275
method: test {p0=transform/transforms_start_stop/Verify start transform creates destination index with appropriate mapping}
279276
issue: https://github.com/elastic/elasticsearch/issues/125854
@@ -301,9 +298,6 @@ tests:
301298
- class: org.elasticsearch.index.engine.ThreadPoolMergeSchedulerTests
302299
method: testSchedulerCloseWaitsForRunningMerge
303300
issue: https://github.com/elastic/elasticsearch/issues/125236
304-
- class: org.elasticsearch.xpack.security.SecurityRolesMultiProjectIT
305-
method: testUpdatingFileBasedRoleAffectsAllProjects
306-
issue: https://github.com/elastic/elasticsearch/issues/126223
307301
- class: org.elasticsearch.packaging.test.DockerTests
308302
method: test020PluginsListWithNoPlugins
309303
issue: https://github.com/elastic/elasticsearch/issues/126232
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.query;
11+
12+
import org.elasticsearch.cluster.metadata.IndexMetadata;
13+
import org.elasticsearch.common.settings.Settings;
14+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
15+
import org.elasticsearch.index.query.QueryBuilders;
16+
import org.elasticsearch.search.vectors.KnnSearchBuilder;
17+
import org.elasticsearch.test.ESIntegTestCase;
18+
import org.elasticsearch.xcontent.XContentBuilder;
19+
import org.elasticsearch.xcontent.XContentFactory;
20+
import org.junit.Before;
21+
22+
import java.io.IOException;
23+
import java.util.List;
24+
25+
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse;
26+
27+
public class VectorIT extends ESIntegTestCase {
28+
29+
private static final String INDEX_NAME = "test";
30+
private static final String VECTOR_FIELD = "vector";
31+
private static final String NUM_ID_FIELD = "num_id";
32+
33+
private static void randomVector(float[] vector) {
34+
for (int i = 0; i < vector.length; i++) {
35+
vector[i] = randomFloat();
36+
}
37+
}
38+
39+
@Before
40+
public void setup() throws IOException {
41+
XContentBuilder mapping = XContentFactory.jsonBuilder()
42+
.startObject()
43+
.startObject("properties")
44+
.startObject(VECTOR_FIELD)
45+
.field("type", "dense_vector")
46+
.startObject("index_options")
47+
.field("type", "hnsw")
48+
.endObject()
49+
.endObject()
50+
.startObject(NUM_ID_FIELD)
51+
.field("type", "long")
52+
.endObject()
53+
.endObject()
54+
.endObject();
55+
56+
Settings settings = Settings.builder()
57+
.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
58+
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1)
59+
.build();
60+
prepareCreate(INDEX_NAME).setMapping(mapping).setSettings(settings).get();
61+
ensureGreen(INDEX_NAME);
62+
for (int i = 0; i < 150; i++) {
63+
float[] vector = new float[8];
64+
randomVector(vector);
65+
prepareIndex(INDEX_NAME).setId(Integer.toString(i)).setSource(VECTOR_FIELD, vector, NUM_ID_FIELD, i).get();
66+
}
67+
forceMerge(true);
68+
refresh(INDEX_NAME);
69+
}
70+
71+
public void testFilteredQueryStrategy() {
72+
float[] vector = new float[8];
73+
randomVector(vector);
74+
var query = new KnnSearchBuilder(VECTOR_FIELD, vector, 1, 1, null, null).addFilterQuery(
75+
QueryBuilders.rangeQuery(NUM_ID_FIELD).lte(30)
76+
);
77+
assertResponse(client().prepareSearch(INDEX_NAME).setKnnSearch(List.of(query)).setSize(1).setProfile(true), acornResponse -> {
78+
assertNotEquals(0, acornResponse.getHits().getHits().length);
79+
var profileResults = acornResponse.getProfileResults();
80+
long vectorOpsSum = profileResults.values()
81+
.stream()
82+
.mapToLong(
83+
pr -> pr.getQueryPhase()
84+
.getSearchProfileDfsPhaseResult()
85+
.getQueryProfileShardResult()
86+
.stream()
87+
.mapToLong(qpr -> qpr.getVectorOperationsCount().longValue())
88+
.sum()
89+
)
90+
.sum();
91+
client().admin()
92+
.indices()
93+
.prepareUpdateSettings(INDEX_NAME)
94+
.setSettings(
95+
Settings.builder()
96+
.put(
97+
DenseVectorFieldMapper.HNSW_FILTER_HEURISTIC.getKey(),
98+
DenseVectorFieldMapper.FilterHeuristic.FANOUT.toString()
99+
)
100+
)
101+
.get();
102+
assertResponse(client().prepareSearch(INDEX_NAME).setKnnSearch(List.of(query)).setSize(1).setProfile(true), fanoutResponse -> {
103+
assertNotEquals(0, fanoutResponse.getHits().getHits().length);
104+
var fanoutProfileResults = fanoutResponse.getProfileResults();
105+
long fanoutVectorOpsSum = fanoutProfileResults.values()
106+
.stream()
107+
.mapToLong(
108+
pr -> pr.getQueryPhase()
109+
.getSearchProfileDfsPhaseResult()
110+
.getQueryProfileShardResult()
111+
.stream()
112+
.mapToLong(qpr -> qpr.getVectorOperationsCount().longValue())
113+
.sum()
114+
)
115+
.sum();
116+
assertTrue(
117+
"fanoutVectorOps [" + fanoutVectorOpsSum + "] is not gt acornVectorOps [" + vectorOpsSum + "]",
118+
fanoutVectorOpsSum > vectorOpsSum
119+
);
120+
});
121+
});
122+
}
123+
124+
}

server/src/main/java/org/elasticsearch/search/internal/ExitableDirectoryReader.java

Lines changed: 95 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.apache.lucene.search.KnnCollector;
2727
import org.apache.lucene.search.VectorScorer;
2828
import org.apache.lucene.search.suggest.document.CompletionTerms;
29+
import org.apache.lucene.util.BitSet;
2930
import org.apache.lucene.util.Bits;
3031
import org.apache.lucene.util.BytesRef;
3132
import org.apache.lucene.util.automaton.CompiledAutomaton;
@@ -145,7 +146,7 @@ public void searchNearestVectors(String field, byte[] target, KnnCollector colle
145146
in.searchNearestVectors(field, target, collector, acceptDocs);
146147
return;
147148
}
148-
in.searchNearestVectors(field, target, collector, new TimeOutCheckingBits(acceptDocs));
149+
in.searchNearestVectors(field, target, collector, createTimeOutCheckingBits(acceptDocs));
149150
}
150151

151152
@Override
@@ -163,15 +164,106 @@ public void searchNearestVectors(String field, float[] target, KnnCollector coll
163164
in.searchNearestVectors(field, target, collector, acceptDocs);
164165
return;
165166
}
166-
in.searchNearestVectors(field, target, collector, new TimeOutCheckingBits(acceptDocs));
167+
in.searchNearestVectors(field, target, collector, createTimeOutCheckingBits(acceptDocs));
168+
}
169+
170+
private Bits createTimeOutCheckingBits(Bits acceptDocs) {
171+
if (acceptDocs == null || acceptDocs instanceof BitSet) {
172+
return new TimeOutCheckingBitSet((BitSet) acceptDocs);
173+
}
174+
return new TimeOutCheckingBits(acceptDocs);
175+
}
176+
177+
private class TimeOutCheckingBitSet extends BitSet {
178+
private static final int MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK = 10;
179+
private int calls;
180+
private final BitSet inner;
181+
private final int maxDoc;
182+
183+
private TimeOutCheckingBitSet(BitSet inner) {
184+
this.inner = inner;
185+
this.maxDoc = maxDoc();
186+
}
187+
188+
@Override
189+
public void set(int i) {
190+
throw new UnsupportedOperationException("not supported on TimeOutCheckingBitSet");
191+
}
192+
193+
@Override
194+
public boolean getAndSet(int i) {
195+
throw new UnsupportedOperationException("not supported on TimeOutCheckingBitSet");
196+
}
197+
198+
@Override
199+
public void clear(int i) {
200+
throw new UnsupportedOperationException("not supported on TimeOutCheckingBitSet");
201+
}
202+
203+
@Override
204+
public void clear(int startIndex, int endIndex) {
205+
throw new UnsupportedOperationException("not supported on TimeOutCheckingBitSet");
206+
}
207+
208+
@Override
209+
public int cardinality() {
210+
if (inner == null) {
211+
return maxDoc;
212+
}
213+
return inner.cardinality();
214+
}
215+
216+
@Override
217+
public int approximateCardinality() {
218+
if (inner == null) {
219+
return maxDoc;
220+
}
221+
return inner.approximateCardinality();
222+
}
223+
224+
@Override
225+
public int prevSetBit(int index) {
226+
throw new UnsupportedOperationException("not supported on TimeOutCheckingBitSet");
227+
}
228+
229+
@Override
230+
public int nextSetBit(int start, int end) {
231+
throw new UnsupportedOperationException("not supported on TimeOutCheckingBitSet");
232+
}
233+
234+
@Override
235+
public long ramBytesUsed() {
236+
throw new UnsupportedOperationException("not supported on TimeOutCheckingBitSet");
237+
}
238+
239+
@Override
240+
public boolean get(int index) {
241+
if (calls++ % MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK == 0) {
242+
queryCancellation.checkCancelled();
243+
}
244+
if (inner == null) {
245+
// if acceptDocs is null, we assume all docs are accepted
246+
return index >= 0 && index < maxDoc;
247+
}
248+
return inner.get(index);
249+
}
250+
251+
@Override
252+
public int length() {
253+
if (inner == null) {
254+
// if acceptDocs is null, we assume all docs are accepted
255+
return maxDoc;
256+
}
257+
return 0;
258+
}
167259
}
168260

169261
private class TimeOutCheckingBits implements Bits {
170262
private static final int MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK = 10;
171263
private final Bits updatedAcceptDocs;
172264
private int calls;
173265

174-
TimeOutCheckingBits(Bits acceptDocs) {
266+
private TimeOutCheckingBits(Bits acceptDocs) {
175267
// when acceptDocs is null due to no doc deleted, we will instantiate a new one that would
176268
// match all docs to allow timeout checking.
177269
this.updatedAcceptDocs = acceptDocs == null ? new Bits.MatchAllBits(maxDoc()) : acceptDocs;

server/src/main/java/org/elasticsearch/search/profile/query/QueryProfileShardResult.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,4 +137,8 @@ public int hashCode() {
137137
public String toString() {
138138
return Strings.toString(this);
139139
}
140+
141+
public Long getVectorOperationsCount() {
142+
return vectorOperationsCount;
143+
}
140144
}

server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636

3737
import java.io.IOException;
3838
import java.util.ArrayList;
39+
import java.util.Collections;
3940
import java.util.List;
4041
import java.util.Locale;
4142
import java.util.Objects;
@@ -53,7 +54,11 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
5354

5455
public static final ParseField RANK_WINDOW_SIZE_FIELD = new ParseField("rank_window_size");
5556

56-
public record RetrieverSource(RetrieverBuilder retriever, SearchSourceBuilder source) {}
57+
public record RetrieverSource(RetrieverBuilder retriever, SearchSourceBuilder source) {
58+
public static RetrieverSource from(RetrieverBuilder retriever) {
59+
return new RetrieverSource(retriever, null);
60+
}
61+
}
5762

5863
protected final int rankWindowSize;
5964
protected final List<RetrieverSource> innerRetrievers;
@@ -65,7 +70,7 @@ protected CompoundRetrieverBuilder(List<RetrieverSource> innerRetrievers, int ra
6570

6671
@SuppressWarnings("unchecked")
6772
public T addChild(RetrieverBuilder retrieverBuilder) {
68-
innerRetrievers.add(new RetrieverSource(retrieverBuilder, null));
73+
innerRetrievers.add(RetrieverSource.from(retrieverBuilder));
6974
return (T) this;
7075
}
7176

@@ -99,6 +104,11 @@ public final RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOExceptio
99104
throw new IllegalStateException("PIT is required");
100105
}
101106

107+
RetrieverBuilder rewritten = doRewrite(ctx);
108+
if (rewritten != this) {
109+
return rewritten;
110+
}
111+
102112
// Rewrite prefilters
103113
// We eagerly rewrite prefilters, because some of the innerRetrievers
104114
// could be compound too, so we want to propagate all the necessary filter information to them
@@ -121,7 +131,7 @@ public final RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOExceptio
121131
}
122132
RetrieverBuilder newRetriever = entry.retriever.rewrite(ctx);
123133
if (newRetriever != entry.retriever) {
124-
newRetrievers.add(new RetrieverSource(newRetriever, null));
134+
newRetrievers.add(RetrieverSource.from(newRetriever));
125135
hasChanged |= true;
126136
} else {
127137
var sourceBuilder = entry.source != null
@@ -291,6 +301,10 @@ public int rankWindowSize() {
291301
return rankWindowSize;
292302
}
293303

304+
public List<RetrieverSource> innerRetrievers() {
305+
return Collections.unmodifiableList(innerRetrievers);
306+
}
307+
294308
protected final SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) {
295309
var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit)
296310
.trackTotalHits(false)
@@ -317,6 +331,16 @@ protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBu
317331
return sourceBuilder;
318332
}
319333

334+
/**
335+
* Perform any custom rewrite logic necessary
336+
*
337+
* @param ctx The query rewrite context
338+
* @return RetrieverBuilder the rewritten retriever
339+
*/
340+
protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
341+
return this;
342+
}
343+
320344
private RankDoc[] getRankDocs(SearchResponse searchResponse) {
321345
int size = searchResponse.getHits().getHits().length;
322346
RankDoc[] docs = new RankDoc[size];

server/src/main/java/org/elasticsearch/search/retriever/RescorerRetrieverBuilder.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ public static RescorerRetrieverBuilder fromXContent(XContentParser parser, Retri
7878
private final List<RescorerBuilder<?>> rescorers;
7979

8080
public RescorerRetrieverBuilder(RetrieverBuilder retriever, List<RescorerBuilder<?>> rescorers) {
81-
super(List.of(new RetrieverSource(retriever, null)), extractMinWindowSize(rescorers));
81+
super(List.of(RetrieverSource.from(retriever)), extractMinWindowSize(rescorers));
8282
if (rescorers.isEmpty()) {
8383
throw new IllegalArgumentException("Missing rescore definition");
8484
}

0 commit comments

Comments
 (0)