Skip to content

Commit 637e498

Browse files
Intercept filters to knn queries (#138457)
`knn` queries can have filter queries. Those filters may contain semantic queries (e.g. `knn`, `sparse_vector`, or `match` queries targetting `semantic_text` fields). We also need to intercept those in order to perform inference for their `query` field during the coordinator node rewrite. This commit achieves this with the following changes: - `SemanticKnnVectorQueryRewriteInterceptor` attempts to rewrite filter queries. This means we rely on a rewrite cycle to attempt to intercept every query down the tree of each filter query. - `InterceptedInferenceQueryBuilder` now has a `customCoordinatorNodeRewrite` method that subclasses can implement to implement additional rewriting logic needed, e.g. rewrite inner queries. - `InterceptedInferenceKnnVectorQueryBuilder` implements `customCoordinatorNodeRewrite` so that the filter queries are rewritten. This commit fixes the exceptions throws in #138410. However, searches that contain semantic text queries as filters to a semantic text knn query will return `0` hits due to another issue that is captured in #138184. Closes #138410
1 parent ecd5f9a commit 637e498

File tree

15 files changed

+373
-38
lines changed

15 files changed

+373
-38
lines changed

docs/changelog/138457.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 138457
2+
summary: Intercept filters to knn queries
3+
area: Vector Search
4+
type: bug
5+
issues:
6+
- 138410

server/src/main/java/org/elasticsearch/plugins/internal/rewriter/QueryRewriteInterceptor.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.elasticsearch.index.query.QueryBuilder;
1313
import org.elasticsearch.index.query.QueryRewriteContext;
1414

15+
import java.io.IOException;
1516
import java.util.Map;
1617

1718
/**
@@ -27,7 +28,7 @@ public interface QueryRewriteInterceptor {
2728
* @param queryBuilder the original {@link QueryBuilder} to potentially rewrite
2829
* @return the rewritten {@link QueryBuilder}, or the original instance if no rewrite was needed
2930
*/
30-
QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder);
31+
QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) throws IOException;
3132

3233
/**
3334
* Name of the query to be intercepted and rewritten.
@@ -52,7 +53,7 @@ public String getQueryName() {
5253
}
5354

5455
@Override
55-
public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) {
56+
public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) throws IOException {
5657
QueryRewriteInterceptor interceptor = interceptors.get(queryBuilder.getName());
5758
if (interceptor != null) {
5859
return interceptor.interceptAndRewrite(context, queryBuilder);

server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,13 @@ public KnnVectorQueryBuilder addFilterQueries(List<QueryBuilder> filterQueries)
357357
return this;
358358
}
359359

360+
public KnnVectorQueryBuilder setFilterQueries(List<QueryBuilder> filterQueries) {
361+
Objects.requireNonNull(filterQueries);
362+
this.filterQueries.clear();
363+
this.filterQueries.addAll(filterQueries);
364+
return this;
365+
}
366+
360367
@Override
361368
protected void doWriteTo(StreamOutput out) throws IOException {
362369
if (queryVectorSupplier != null) {

server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.elasticsearch.index.query.QueryBuilders;
3333
import org.elasticsearch.index.query.QueryRewriteContext;
3434
import org.elasticsearch.index.query.QueryShardException;
35+
import org.elasticsearch.index.query.RandomQueryBuilder;
3536
import org.elasticsearch.index.query.Rewriteable;
3637
import org.elasticsearch.index.query.SearchExecutionContext;
3738
import org.elasticsearch.index.query.TermQueryBuilder;
@@ -565,4 +566,11 @@ public void testRewriteWithQueryVectorBuilder() throws Exception {
565566
assertThat(rewritten.filterQueries(), hasSize(numFilters));
566567
assertThat(rewritten.filterQueries(), equalTo(filters));
567568
}
569+
570+
public void testSetFilterQueries() {
571+
KnnVectorQueryBuilder knnQueryBuilder = doCreateTestQueryBuilder();
572+
List<QueryBuilder> newFilters = randomList(5, () -> RandomQueryBuilder.createQuery(random()));
573+
knnQueryBuilder.setFilterQueries(newFilters);
574+
assertThat(knnQueryBuilder.filterQueries(), equalTo(newFilters));
575+
}
568576
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.elasticsearch.xpack.inference.mapper.SemanticInferenceMetadataFieldsMapper;
1414
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
1515
import org.elasticsearch.xpack.inference.queries.InterceptedInferenceQueryBuilder;
16+
import org.elasticsearch.xpack.inference.queries.SemanticKnnVectorQueryRewriteInterceptor;
1617
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
1718
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder;
1819

@@ -109,6 +110,7 @@ public Set<NodeFeature> getTestFeatures() {
109110
SemanticQueryBuilder.SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS,
110111
SemanticQueryBuilder.SEMANTIC_QUERY_FILTER_FIELD_CAPS_FIX,
111112
InterceptedInferenceQueryBuilder.NEW_SEMANTIC_QUERY_INTERCEPTORS,
113+
SemanticKnnVectorQueryRewriteInterceptor.SEMANTIC_KNN_VECTOR_QUERY_FILTERS_REWRITE_INTERCEPTION_SUPPORTED,
112114
TEXT_SIMILARITY_RERANKER_SNIPPETS,
113115
ModelStats.SEMANTIC_TEXT_USAGE,
114116
SEARCH_USAGE_EXTENDED_DATA,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@
3232
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
3333

3434
import java.io.IOException;
35+
import java.util.ArrayList;
3536
import java.util.Collection;
37+
import java.util.List;
3638
import java.util.Map;
3739

3840
import static org.elasticsearch.transport.RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY;
@@ -91,6 +93,33 @@ protected FullyQualifiedInferenceId getInferenceIdOverride() {
9193
return modelId != null ? new FullyQualifiedInferenceId(LOCAL_CLUSTER_GROUP_KEY, modelId) : null;
9294
}
9395

96+
@Override
97+
protected InterceptedInferenceQueryBuilder<KnnVectorQueryBuilder> customDoRewriteGetInferenceResults(
98+
QueryRewriteContext queryRewriteContext
99+
) throws IOException {
100+
// knn query may contain filters that are also intercepted.
101+
// We need to rewrite those here so that we can get inference results for them too.
102+
return rewriteFilterQueries(queryRewriteContext);
103+
}
104+
105+
private InterceptedInferenceQueryBuilder<KnnVectorQueryBuilder> rewriteFilterQueries(QueryRewriteContext queryRewriteContext)
106+
throws IOException {
107+
boolean filtersChanged = false;
108+
List<QueryBuilder> rewrittenFilters = new ArrayList<>(originalQuery.filterQueries().size());
109+
for (QueryBuilder filter : originalQuery.filterQueries()) {
110+
QueryBuilder rewrittenFilter = filter.rewrite(queryRewriteContext);
111+
if (rewrittenFilter != filter) {
112+
filtersChanged = true;
113+
}
114+
rewrittenFilters.add(rewrittenFilter);
115+
}
116+
if (filtersChanged) {
117+
originalQuery.setFilterQueries(rewrittenFilters);
118+
return copy(inferenceResultsMap, inferenceResultsMapSupplier, ccsRequest);
119+
}
120+
return this;
121+
}
122+
94123
@Override
95124
protected void coordinatorNodeValidate(ResolvedIndices resolvedIndices) {
96125
if (originalQuery.queryVector() == null && originalQuery.queryVectorBuilder() instanceof TextEmbeddingQueryVectorBuilder == false) {
@@ -119,7 +148,7 @@ protected void coordinatorNodeValidate(ResolvedIndices resolvedIndices) {
119148
}
120149

121150
@Override
122-
protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) {
151+
protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) throws IOException {
123152
QueryBuilder rewritten = this;
124153
if (queryRewriteContext.getMinTransportVersion().supports(NEW_SEMANTIC_QUERY_INTERCEPTORS) == false) {
125154
rewritten = BWC_INTERCEPTOR.interceptAndRewrite(queryRewriteContext, originalQuery);
@@ -129,7 +158,7 @@ protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) {
129158
}
130159

131160
@Override
132-
protected QueryBuilder copy(
161+
protected InterceptedInferenceQueryBuilder<KnnVectorQueryBuilder> copy(
133162
Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap,
134163
SetOnce<Map<FullyQualifiedInferenceId, InferenceResults>> inferenceResultsMapSupplier,
135164
boolean ccsRequest

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilder.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ protected String getQuery() {
6565
}
6666

6767
@Override
68-
protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) {
68+
protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) throws IOException {
6969
QueryBuilder rewritten = this;
7070
if (queryRewriteContext.getMinTransportVersion().supports(NEW_SEMANTIC_QUERY_INTERCEPTORS) == false) {
7171
rewritten = BWC_INTERCEPTOR.interceptAndRewrite(queryRewriteContext, originalQuery);
@@ -75,7 +75,7 @@ protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) {
7575
}
7676

7777
@Override
78-
protected QueryBuilder copy(
78+
protected InterceptedInferenceQueryBuilder<MatchQueryBuilder> copy(
7979
Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap,
8080
SetOnce<Map<FullyQualifiedInferenceId, InferenceResults>> inferenceResultsMapSupplier,
8181
boolean ccsRequest

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -152,14 +152,14 @@ protected InterceptedInferenceQueryBuilder(
152152
* @param queryRewriteContext The query rewrite context
153153
* @return The query builder rewritten to a backwards-compatible form
154154
*/
155-
protected abstract QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext);
155+
protected abstract QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) throws IOException;
156156

157157
/**
158158
* Generate a copy of {@code this}.
159159
*
160-
* @param inferenceResultsMap The inference results map
160+
* @param inferenceResultsMap The inference results map
161161
* @param inferenceResultsMapSupplier The inference results map supplier
162-
* @param ccsRequest Flag indicating if this is a CCS request
162+
* @param ccsRequest Flag indicating if this is a CCS request
163163
* @return A copy of {@code this} with the provided inference results map
164164
*/
165165
protected abstract QueryBuilder copy(
@@ -209,6 +209,15 @@ protected FullyQualifiedInferenceId getInferenceIdOverride() {
209209
*/
210210
protected void coordinatorNodeValidate(ResolvedIndices resolvedIndices) {}
211211

212+
/**
213+
* A hook for subclasses to do additional rewriting and inference result fetching while we are on the coordinator node.
214+
* An example usage is {@link InterceptedInferenceKnnVectorQueryBuilder} which needs to rewrite the knn queries filters.
215+
*/
216+
protected InterceptedInferenceQueryBuilder<T> customDoRewriteGetInferenceResults(QueryRewriteContext queryRewriteContext)
217+
throws IOException {
218+
return this;
219+
}
220+
212221
@Override
213222
protected void doWriteTo(StreamOutput out) throws IOException {
214223
if (inferenceResultsMapSupplier != null) {
@@ -304,7 +313,7 @@ private QueryBuilder doRewriteBuildQuery(QueryRewriteContext indexMetadataContex
304313
return queryFields(inferenceFieldsToQuery, nonInferenceFieldsToQuery, indexMetadataContext);
305314
}
306315

307-
private QueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewriteContext) {
316+
private QueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewriteContext) throws IOException {
308317
QueryBuilder rewrittenBwC = doRewriteBwC(queryRewriteContext);
309318
if (rewrittenBwC != this) {
310319
return rewrittenBwC;
@@ -344,6 +353,15 @@ private QueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewri
344353
);
345354
}
346355

356+
InterceptedInferenceQueryBuilder<T> rewritten = customDoRewriteGetInferenceResults(queryRewriteContext);
357+
return rewritten.doRewriteWaitForInferenceResults(queryRewriteContext, inferenceIds, ccsRequest);
358+
}
359+
360+
private QueryBuilder doRewriteWaitForInferenceResults(
361+
QueryRewriteContext queryRewriteContext,
362+
Set<FullyQualifiedInferenceId> inferenceIds,
363+
boolean ccsRequest
364+
) {
347365
if (inferenceResultsMapSupplier != null) {
348366
// Additional inference results have already been requested, and we are waiting for them to continue the rewrite process
349367
return getNewInferenceResultsFromSupplier(inferenceResultsMapSupplier, this, m -> copy(m, null, ccsRequest));
@@ -376,7 +394,6 @@ private QueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewri
376394
} else {
377395
rewritten = copy(inferenceResultsMap, newInferenceResultsMapSupplier, ccsRequest);
378396
}
379-
380397
return rewritten;
381398
}
382399

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilder.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ protected void coordinatorNodeValidate(ResolvedIndices resolvedIndices) {
106106
}
107107

108108
@Override
109-
protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) {
109+
protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) throws IOException {
110110
QueryBuilder rewritten = this;
111111
if (queryRewriteContext.getMinTransportVersion().supports(NEW_SEMANTIC_QUERY_INTERCEPTORS) == false) {
112112
rewritten = BWC_INTERCEPTOR.interceptAndRewrite(queryRewriteContext, originalQuery);
@@ -116,7 +116,7 @@ protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) {
116116
}
117117

118118
@Override
119-
protected QueryBuilder copy(
119+
protected InterceptedInferenceQueryBuilder<SparseVectorQueryBuilder> copy(
120120
Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap,
121121
SetOnce<Map<FullyQualifiedInferenceId, InferenceResults>> inferenceResultsMapSupplier,
122122
boolean ccsRequest

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticKnnVectorQueryRewriteInterceptor.java

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,50 @@
77

88
package org.elasticsearch.xpack.inference.queries;
99

10+
import org.elasticsearch.features.NodeFeature;
1011
import org.elasticsearch.index.query.QueryBuilder;
1112
import org.elasticsearch.index.query.QueryRewriteContext;
1213
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
1314
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
1415

16+
import java.io.IOException;
17+
import java.util.ArrayList;
18+
import java.util.List;
19+
1520
public class SemanticKnnVectorQueryRewriteInterceptor implements QueryRewriteInterceptor {
21+
22+
public static final NodeFeature SEMANTIC_KNN_VECTOR_QUERY_FILTERS_REWRITE_INTERCEPTION_SUPPORTED = new NodeFeature(
23+
"search.semantic_knn_vector_query_filters_rewrite_interception_supported"
24+
);
25+
1626
@Override
17-
public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) {
27+
public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) throws IOException {
1828
if (queryBuilder instanceof KnnVectorQueryBuilder knnVectorQueryBuilder) {
19-
return new InterceptedInferenceKnnVectorQueryBuilder(knnVectorQueryBuilder);
29+
return interceptKnnQuery(context, knnVectorQueryBuilder);
2030
} else {
2131
throw new IllegalStateException("Unexpected query builder type: " + queryBuilder.getClass());
2232
}
2333
}
2434

35+
private static InterceptedInferenceKnnVectorQueryBuilder interceptKnnQuery(
36+
QueryRewriteContext context,
37+
KnnVectorQueryBuilder knnVectorQueryBuilder
38+
) throws IOException {
39+
boolean changed = false;
40+
List<QueryBuilder> rewrittenFilters = new ArrayList<>(knnVectorQueryBuilder.filterQueries().size());
41+
for (QueryBuilder filter : knnVectorQueryBuilder.filterQueries()) {
42+
QueryBuilder rewritten = filter.rewrite(context);
43+
if (rewritten != filter) {
44+
changed = true;
45+
}
46+
rewrittenFilters.add(rewritten);
47+
}
48+
if (changed) {
49+
knnVectorQueryBuilder.setFilterQueries(rewrittenFilters);
50+
}
51+
return new InterceptedInferenceKnnVectorQueryBuilder(knnVectorQueryBuilder);
52+
}
53+
2554
@Override
2655
public String getQueryName() {
2756
return KnnVectorQueryBuilder.NAME;

0 commit comments

Comments
 (0)