Skip to content

Commit c66eb2e

Browse files
authored
CCS Compatible Semantic Query Rewrite Interceptors (elastic#134507)
1 parent 5577029 commit c66eb2e

27 files changed

+2901
-437
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,7 @@ static TransportVersion def(int id) {
325325
public static final TransportVersion MAX_HEAP_SIZE_PER_NODE_IN_CLUSTER_INFO = def(9_159_0_00);
326326
public static final TransportVersion TIMESERIES_DEFAULT_LIMIT = def(9_160_0_00);
327327
public static final TransportVersion INFERENCE_API_OPENAI_HEADERS = def(9_161_0_00);
328+
public static final TransportVersion NEW_SEMANTIC_QUERY_INTERCEPTORS = def(9_162_0_00);
328329

329330
/*
330331
* STOP! READ THIS FIRST! No, really,

server/src/main/java/org/elasticsearch/search/vectors/VectorData.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@
3030

3131
public record VectorData(float[] floatVector, byte[] byteVector) implements Writeable, ToXContentFragment {
3232

33-
private VectorData(float[] floatVector) {
33+
public VectorData(float[] floatVector) {
3434
this(floatVector, null);
3535
}
3636

37-
private VectorData(byte[] byteVector) {
37+
public VectorData(byte[] byteVector) {
3838
this(null, byteVector);
3939
}
4040

x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/KnnSemanticTextTestCase.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ public void testKnnQueryOnSparseSemanticTextField() throws IOException {
9191

9292
ResponseException re = expectThrows(ResponseException.class, () -> runEsqlQuery(knnQuery));
9393
assertThat(re.getResponse().getStatusLine().getStatusCode(), is(BAD_REQUEST.getStatus()));
94-
assertThat(re.getMessage(), containsString("[knn] queries are only supported on [dense_vector] fields"));
94+
assertThat(re.getMessage(), containsString("Field [sparse_semantic] does not use a [text_embedding] model"));
9595
}
9696

9797
@Before

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.features.NodeFeature;
1212
import org.elasticsearch.xpack.inference.mapper.SemanticInferenceMetadataFieldsMapper;
1313
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
14+
import org.elasticsearch.xpack.inference.queries.InterceptedInferenceQueryBuilder;
1415
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
1516
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder;
1617

@@ -22,10 +23,10 @@
2223
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.SEMANTIC_TEXT_INDEX_OPTIONS_WITH_DEFAULTS;
2324
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.SEMANTIC_TEXT_SPARSE_VECTOR_INDEX_OPTIONS;
2425
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.SEMANTIC_TEXT_SUPPORT_CHUNKING_CONFIG;
25-
import static org.elasticsearch.xpack.inference.queries.SemanticKnnVectorQueryRewriteInterceptor.SEMANTIC_KNN_FILTER_FIX;
26-
import static org.elasticsearch.xpack.inference.queries.SemanticKnnVectorQueryRewriteInterceptor.SEMANTIC_KNN_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED;
27-
import static org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor.SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED;
28-
import static org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor.SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED;
26+
import static org.elasticsearch.xpack.inference.queries.LegacySemanticKnnVectorQueryRewriteInterceptor.SEMANTIC_KNN_FILTER_FIX;
27+
import static org.elasticsearch.xpack.inference.queries.LegacySemanticKnnVectorQueryRewriteInterceptor.SEMANTIC_KNN_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED;
28+
import static org.elasticsearch.xpack.inference.queries.LegacySemanticMatchQueryRewriteInterceptor.SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED;
29+
import static org.elasticsearch.xpack.inference.queries.LegacySemanticSparseVectorQueryRewriteInterceptor.SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED;
2930
import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.RERANK_SNIPPETS;
3031
import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_SNIPPETS;
3132

@@ -85,7 +86,8 @@ public Set<NodeFeature> getTestFeatures() {
8586
SEMANTIC_TEXT_SPARSE_VECTOR_INDEX_OPTIONS,
8687
SEMANTIC_TEXT_FIELDS_CHUNKS_FORMAT,
8788
SemanticQueryBuilder.SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS,
88-
SemanticQueryBuilder.SEMANTIC_QUERY_FILTER_FIELD_CAPS_FIX
89+
SemanticQueryBuilder.SEMANTIC_QUERY_FILTER_FIELD_CAPS_FIX,
90+
InterceptedInferenceQueryBuilder.NEW_SEMANTIC_QUERY_INTERCEPTORS
8991
)
9092
);
9193
if (RERANK_SNIPPETS.isEnabled()) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.elasticsearch.features.NodeFeature;
2828
import org.elasticsearch.index.mapper.Mapper;
2929
import org.elasticsearch.index.mapper.MetadataFieldMapper;
30+
import org.elasticsearch.index.query.QueryBuilder;
3031
import org.elasticsearch.indices.SystemIndexDescriptor;
3132
import org.elasticsearch.inference.InferenceServiceExtension;
3233
import org.elasticsearch.inference.InferenceServiceRegistry;
@@ -95,6 +96,9 @@
9596
import org.elasticsearch.xpack.inference.mapper.OffsetSourceFieldMapper;
9697
import org.elasticsearch.xpack.inference.mapper.SemanticInferenceMetadataFieldsMapper;
9798
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
99+
import org.elasticsearch.xpack.inference.queries.InterceptedInferenceKnnVectorQueryBuilder;
100+
import org.elasticsearch.xpack.inference.queries.InterceptedInferenceMatchQueryBuilder;
101+
import org.elasticsearch.xpack.inference.queries.InterceptedInferenceSparseVectorQueryBuilder;
98102
import org.elasticsearch.xpack.inference.queries.SemanticKnnVectorQueryRewriteInterceptor;
99103
import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor;
100104
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
@@ -446,6 +450,27 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
446450
entries.add(new NamedWriteableRegistry.Entry(RankDoc.class, TextSimilarityRankDoc.NAME, TextSimilarityRankDoc::new));
447451
entries.add(new NamedWriteableRegistry.Entry(Metadata.ProjectCustom.class, ModelRegistryMetadata.TYPE, ModelRegistryMetadata::new));
448452
entries.add(new NamedWriteableRegistry.Entry(NamedDiff.class, ModelRegistryMetadata.TYPE, ModelRegistryMetadata::readDiffFrom));
453+
entries.add(
454+
new NamedWriteableRegistry.Entry(
455+
QueryBuilder.class,
456+
InterceptedInferenceMatchQueryBuilder.NAME,
457+
InterceptedInferenceMatchQueryBuilder::new
458+
)
459+
);
460+
entries.add(
461+
new NamedWriteableRegistry.Entry(
462+
QueryBuilder.class,
463+
InterceptedInferenceKnnVectorQueryBuilder.NAME,
464+
InterceptedInferenceKnnVectorQueryBuilder::new
465+
)
466+
);
467+
entries.add(
468+
new NamedWriteableRegistry.Entry(
469+
QueryBuilder.class,
470+
InterceptedInferenceSparseVectorQueryBuilder.NAME,
471+
InterceptedInferenceSparseVectorQueryBuilder::new
472+
)
473+
);
449474
return entries;
450475
}
451476

Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.queries;
9+
10+
import org.apache.lucene.search.join.ScoreMode;
11+
import org.elasticsearch.TransportVersions;
12+
import org.elasticsearch.action.ResolvedIndices;
13+
import org.elasticsearch.cluster.metadata.IndexMetadata;
14+
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
15+
import org.elasticsearch.common.io.stream.StreamInput;
16+
import org.elasticsearch.index.mapper.MappedFieldType;
17+
import org.elasticsearch.index.query.MatchNoneQueryBuilder;
18+
import org.elasticsearch.index.query.QueryBuilder;
19+
import org.elasticsearch.index.query.QueryBuilders;
20+
import org.elasticsearch.index.query.QueryRewriteContext;
21+
import org.elasticsearch.inference.InferenceResults;
22+
import org.elasticsearch.inference.MinimalServiceSettings;
23+
import org.elasticsearch.inference.TaskType;
24+
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
25+
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
26+
import org.elasticsearch.search.vectors.QueryVectorBuilder;
27+
import org.elasticsearch.search.vectors.VectorData;
28+
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
29+
import org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder;
30+
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
31+
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
32+
33+
import java.io.IOException;
34+
import java.util.Collection;
35+
import java.util.Map;
36+
37+
public class InterceptedInferenceKnnVectorQueryBuilder extends InterceptedInferenceQueryBuilder<KnnVectorQueryBuilder> {
38+
public static final String NAME = "intercepted_inference_knn";
39+
40+
@SuppressWarnings("deprecation")
41+
private static final QueryRewriteInterceptor BWC_INTERCEPTOR = new LegacySemanticKnnVectorQueryRewriteInterceptor();
42+
43+
public InterceptedInferenceKnnVectorQueryBuilder(KnnVectorQueryBuilder originalQuery) {
44+
super(originalQuery);
45+
}
46+
47+
public InterceptedInferenceKnnVectorQueryBuilder(StreamInput in) throws IOException {
48+
super(in);
49+
}
50+
51+
public InterceptedInferenceKnnVectorQueryBuilder(
52+
InterceptedInferenceQueryBuilder<KnnVectorQueryBuilder> other,
53+
Map<String, InferenceResults> inferenceResultsMap
54+
) {
55+
super(other, inferenceResultsMap);
56+
}
57+
58+
@Override
59+
protected Map<String, Float> getFields() {
60+
return Map.of(getField(), 1.0f);
61+
}
62+
63+
@Override
64+
protected String getQuery() {
65+
String query = null;
66+
QueryVectorBuilder queryVectorBuilder = originalQuery.queryVectorBuilder();
67+
if (queryVectorBuilder instanceof TextEmbeddingQueryVectorBuilder textEmbeddingQueryVectorBuilder) {
68+
query = textEmbeddingQueryVectorBuilder.getModelText();
69+
}
70+
71+
return query;
72+
}
73+
74+
@Override
75+
protected String getInferenceIdOverride() {
76+
return getQueryVectorBuilderModelId();
77+
}
78+
79+
@Override
80+
protected void coordinatorNodeValidate(ResolvedIndices resolvedIndices) {
81+
if (originalQuery.queryVector() == null && originalQuery.queryVectorBuilder() instanceof TextEmbeddingQueryVectorBuilder == false) {
82+
// This should never happen because either query vector or query vector builder must be non-null, which is enforced by the
83+
// KnnVectorQueryBuilder constructor. The only query vector builder used in production is TextEmbeddingQueryVectorBuilder,
84+
// thus if it is not this type it is null.
85+
// We could throw here _if_ we add a new query vector builder type and forget to update this class to support it, which would
86+
// be a server-side error.
87+
throw new IllegalStateException(
88+
"No [" + TextEmbeddingQueryVectorBuilder.NAME + "] query vector builder or query vector specified"
89+
);
90+
}
91+
92+
// Check if we are querying any non-inference fields
93+
Collection<IndexMetadata> indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values();
94+
for (IndexMetadata indexMetadata : indexMetadataCollection) {
95+
InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get(getField());
96+
if (inferenceFieldMetadata == null) {
97+
QueryVectorBuilder queryVectorBuilder = originalQuery.queryVectorBuilder();
98+
if (queryVectorBuilder instanceof TextEmbeddingQueryVectorBuilder textEmbeddingQueryVectorBuilder
99+
&& textEmbeddingQueryVectorBuilder.getModelId() == null) {
100+
throw new IllegalArgumentException("[model_id] must not be null.");
101+
}
102+
}
103+
}
104+
}
105+
106+
@Override
107+
protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) {
108+
QueryBuilder rewritten = this;
109+
if (queryRewriteContext.getMinTransportVersion().before(TransportVersions.NEW_SEMANTIC_QUERY_INTERCEPTORS)) {
110+
rewritten = BWC_INTERCEPTOR.interceptAndRewrite(queryRewriteContext, originalQuery);
111+
}
112+
113+
return rewritten;
114+
}
115+
116+
@Override
117+
protected QueryBuilder copy(Map<String, InferenceResults> inferenceResultsMap) {
118+
return new InterceptedInferenceKnnVectorQueryBuilder(this, inferenceResultsMap);
119+
}
120+
121+
@Override
122+
protected QueryBuilder queryFields(
123+
Map<String, Float> inferenceFields,
124+
Map<String, Float> nonInferenceFields,
125+
QueryRewriteContext indexMetadataContext
126+
) {
127+
QueryBuilder rewritten;
128+
MappedFieldType fieldType = indexMetadataContext.getFieldType(getField());
129+
if (fieldType == null) {
130+
rewritten = new MatchNoneQueryBuilder();
131+
} else if (fieldType instanceof SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType) {
132+
rewritten = querySemanticTextField(semanticTextFieldType);
133+
} else {
134+
rewritten = queryNonSemanticTextField();
135+
}
136+
137+
return rewritten;
138+
}
139+
140+
@Override
141+
protected boolean resolveWildcards() {
142+
return false;
143+
}
144+
145+
@Override
146+
protected boolean useDefaultFields() {
147+
return false;
148+
}
149+
150+
@Override
151+
public String getWriteableName() {
152+
return NAME;
153+
}
154+
155+
private String getField() {
156+
return originalQuery.getFieldName();
157+
}
158+
159+
private String getQueryVectorBuilderModelId() {
160+
String modelId = null;
161+
QueryVectorBuilder queryVectorBuilder = originalQuery.queryVectorBuilder();
162+
if (queryVectorBuilder instanceof TextEmbeddingQueryVectorBuilder textEmbeddingQueryVectorBuilder) {
163+
modelId = textEmbeddingQueryVectorBuilder.getModelId();
164+
}
165+
166+
return modelId;
167+
}
168+
169+
private QueryBuilder querySemanticTextField(SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType) {
170+
MinimalServiceSettings modelSettings = semanticTextFieldType.getModelSettings();
171+
if (modelSettings == null) {
172+
// No inference results have been indexed yet
173+
return new MatchNoneQueryBuilder();
174+
} else if (modelSettings.taskType() != TaskType.TEXT_EMBEDDING) {
175+
throw new IllegalArgumentException("Field [" + getField() + "] does not use a [" + TaskType.TEXT_EMBEDDING + "] model");
176+
}
177+
178+
VectorData queryVector = originalQuery.queryVector();
179+
if (queryVector == null) {
180+
String inferenceId = getQueryVectorBuilderModelId();
181+
if (inferenceId == null) {
182+
inferenceId = semanticTextFieldType.getSearchInferenceId();
183+
}
184+
185+
MlTextEmbeddingResults textEmbeddingResults = getTextEmbeddingResults(inferenceId);
186+
queryVector = new VectorData(textEmbeddingResults.getInferenceAsFloat());
187+
}
188+
189+
KnnVectorQueryBuilder innerKnnQuery = new KnnVectorQueryBuilder(
190+
SemanticTextField.getEmbeddingsFieldName(getField()),
191+
queryVector,
192+
originalQuery.k(),
193+
originalQuery.numCands(),
194+
originalQuery.visitPercentage(),
195+
originalQuery.rescoreVectorBuilder(),
196+
originalQuery.getVectorSimilarity()
197+
);
198+
innerKnnQuery.addFilterQueries(originalQuery.filterQueries());
199+
200+
return QueryBuilders.nestedQuery(SemanticTextField.getChunksFieldName(getField()), innerKnnQuery, ScoreMode.Max)
201+
.boost(originalQuery.boost())
202+
.queryName(originalQuery.queryName());
203+
}
204+
205+
private QueryBuilder queryNonSemanticTextField() {
206+
VectorData queryVector = originalQuery.queryVector();
207+
if (queryVector == null) {
208+
String modelId = getQueryVectorBuilderModelId();
209+
if (modelId == null) {
210+
// This should never happen because we validate that either query vector or a valid query vector builder is specified in:
211+
// - The KnnVectorQueryBuilder constructor
212+
// - coordinatorNodeValidate
213+
throw new IllegalStateException("No query vector or query vector builder model ID specified");
214+
}
215+
216+
MlTextEmbeddingResults textEmbeddingResults = getTextEmbeddingResults(modelId);
217+
queryVector = new VectorData(textEmbeddingResults.getInferenceAsFloat());
218+
}
219+
220+
KnnVectorQueryBuilder knnQuery = new KnnVectorQueryBuilder(
221+
getField(),
222+
queryVector,
223+
originalQuery.k(),
224+
originalQuery.numCands(),
225+
originalQuery.visitPercentage(),
226+
originalQuery.rescoreVectorBuilder(),
227+
originalQuery.getVectorSimilarity()
228+
).boost(originalQuery.boost()).queryName(originalQuery.queryName());
229+
knnQuery.addFilterQueries(originalQuery.filterQueries());
230+
231+
return knnQuery;
232+
}
233+
234+
private MlTextEmbeddingResults getTextEmbeddingResults(String inferenceId) {
235+
InferenceResults inferenceResults = inferenceResultsMap.get(inferenceId);
236+
if (inferenceResults == null) {
237+
throw new IllegalStateException("Could not find inference results from inference endpoint [" + inferenceId + "]");
238+
} else if (inferenceResults instanceof MlTextEmbeddingResults == false) {
239+
throw new IllegalArgumentException(
240+
"Expected query inference results to be of type ["
241+
+ MlTextEmbeddingResults.NAME
242+
+ "], got ["
243+
+ inferenceResults.getWriteableName()
244+
+ "]. Are you specifying a compatible inference endpoint? Has the inference endpoint configuration changed?"
245+
);
246+
}
247+
248+
return (MlTextEmbeddingResults) inferenceResults;
249+
}
250+
}

0 commit comments

Comments
 (0)