Skip to content

Commit 200b08a

Browse files
authored
Allow semantic queries to gather inference results on remote clusters (#134956)
1 parent 78858be commit 200b08a

12 files changed

+418
-127
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -333,6 +333,7 @@ static TransportVersion def(int id) {
333333
public static final TransportVersion ESQL_LOOKUP_JOIN_ON_EXPRESSION = def(9_163_0_00);
334334
public static final TransportVersion INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING_REMOVED = def(9_164_0_00);
335335
public static final TransportVersion SEARCH_SOURCE_EXCLUDE_INFERENCE_FIELDS_PARAM = def(9_165_0_00);
336+
public static final TransportVersion INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS = def(9_166_0_00);
336337

337338
/*
338339
* STOP! READ THIS FIRST! No, really,
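x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/FullyQualifiedInferenceId.java

Lines changed: 32 additions & 0 deletions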
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,32 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.queries;
9+
10+
import org.elasticsearch.common.io.stream.StreamInput;
11+
import org.elasticsearch.common.io.stream.StreamOutput;
12+
import org.elasticsearch.common.io.stream.Writeable;
13+
14+
import java.io.IOException;
15+
import java.util.Objects;
16+
17+
public record FullyQualifiedInferenceId(String clusterAlias, String inferenceId) implements Writeable {
18+
public FullyQualifiedInferenceId(String clusterAlias, String inferenceId) {
19+
this.clusterAlias = Objects.requireNonNull(clusterAlias);
20+
this.inferenceId = Objects.requireNonNull(inferenceId);
21+
}
22+
23+
public FullyQualifiedInferenceId(StreamInput in) throws IOException {
24+
this(in.readString(), in.readString());
25+
}
26+
27+
@Override
28+
public void writeTo(StreamOutput out) throws IOException {
29+
out.writeString(clusterAlias);
30+
out.writeString(inferenceId);
31+
}
32+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@ public InterceptedInferenceKnnVectorQueryBuilder(StreamInput in) throws IOExcept
4848
super(in);
4949
}
5050

51-
public InterceptedInferenceKnnVectorQueryBuilder(
51+
InterceptedInferenceKnnVectorQueryBuilder(
5252
InterceptedInferenceQueryBuilder<KnnVectorQueryBuilder> other,
53-
Map<String, InferenceResults> inferenceResultsMap
53+
Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap
5454
) {
5555
super(other, inferenceResultsMap);
5656
}
@@ -114,7 +114,7 @@ protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) {
114114
}
115115

116116
@Override
117-
protected QueryBuilder copy(Map<String, InferenceResults> inferenceResultsMap) {
117+
protected QueryBuilder copy(Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap) {
118118
return new InterceptedInferenceKnnVectorQueryBuilder(this, inferenceResultsMap);
119119
}
120120

@@ -129,9 +129,9 @@ protected QueryBuilder queryFields(
129129
if (fieldType == null) {
130130
rewritten = new MatchNoneQueryBuilder();
131131
} else if (fieldType instanceof SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType) {
132-
rewritten = querySemanticTextField(semanticTextFieldType);
132+
rewritten = querySemanticTextField(indexMetadataContext.getLocalClusterAlias(), semanticTextFieldType);
133133
} else {
134-
rewritten = queryNonSemanticTextField();
134+
rewritten = queryNonSemanticTextField(indexMetadataContext.getLocalClusterAlias());
135135
}
136136

137137
return rewritten;
@@ -166,7 +166,7 @@ private String getQueryVectorBuilderModelId() {
166166
return modelId;
167167
}
168168

169-
private QueryBuilder querySemanticTextField(SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType) {
169+
private QueryBuilder querySemanticTextField(String clusterAlias, SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType) {
170170
MinimalServiceSettings modelSettings = semanticTextFieldType.getModelSettings();
171171
if (modelSettings == null) {
172172
// No inference results have been indexed yet
@@ -182,7 +182,7 @@ private QueryBuilder querySemanticTextField(SemanticTextFieldMapper.SemanticText
182182
inferenceId = semanticTextFieldType.getSearchInferenceId();
183183
}
184184

185-
MlTextEmbeddingResults textEmbeddingResults = getTextEmbeddingResults(inferenceId);
185+
MlTextEmbeddingResults textEmbeddingResults = getTextEmbeddingResults(clusterAlias, inferenceId);
186186
queryVector = new VectorData(textEmbeddingResults.getInferenceAsFloat());
187187
}
188188

@@ -202,7 +202,7 @@ private QueryBuilder querySemanticTextField(SemanticTextFieldMapper.SemanticText
202202
.queryName(originalQuery.queryName());
203203
}
204204

205-
private QueryBuilder queryNonSemanticTextField() {
205+
private QueryBuilder queryNonSemanticTextField(String clusterAlias) {
206206
VectorData queryVector = originalQuery.queryVector();
207207
if (queryVector == null) {
208208
String modelId = getQueryVectorBuilderModelId();
@@ -213,7 +213,7 @@ private QueryBuilder queryNonSemanticTextField() {
213213
throw new IllegalStateException("No query vector or query vector builder model ID specified");
214214
}
215215

216-
MlTextEmbeddingResults textEmbeddingResults = getTextEmbeddingResults(modelId);
216+
MlTextEmbeddingResults textEmbeddingResults = getTextEmbeddingResults(clusterAlias, modelId);
217217
queryVector = new VectorData(textEmbeddingResults.getInferenceAsFloat());
218218
}
219219

@@ -231,8 +231,8 @@ private QueryBuilder queryNonSemanticTextField() {
231231
return knnQuery;
232232
}
233233

234-
private MlTextEmbeddingResults getTextEmbeddingResults(String inferenceId) {
235-
InferenceResults inferenceResults = inferenceResultsMap.get(inferenceId);
234+
private MlTextEmbeddingResults getTextEmbeddingResults(String clusterAlias, String inferenceId) {
235+
InferenceResults inferenceResults = inferenceResultsMap.get(new FullyQualifiedInferenceId(clusterAlias, inferenceId));
236236
if (inferenceResults == null) {
237237
throw new IllegalStateException("Could not find inference results from inference endpoint [" + inferenceId + "]");
238238
} else if (inferenceResults instanceof MlTextEmbeddingResults == false) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilder.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ public InterceptedInferenceMatchQueryBuilder(StreamInput in) throws IOException
3535
super(in);
3636
}
3737

38-
private InterceptedInferenceMatchQueryBuilder(
38+
InterceptedInferenceMatchQueryBuilder(
3939
InterceptedInferenceQueryBuilder<MatchQueryBuilder> other,
40-
Map<String, InferenceResults> inferenceResultsMap
40+
Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap
4141
) {
4242
super(other, inferenceResultsMap);
4343
}
@@ -63,7 +63,7 @@ protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) {
6363
}
6464

6565
@Override
66-
protected QueryBuilder copy(Map<String, InferenceResults> inferenceResultsMap) {
66+
protected QueryBuilder copy(Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap) {
6767
return new InterceptedInferenceMatchQueryBuilder(this, inferenceResultsMap);
6868
}
6969

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java

Lines changed: 48 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,11 @@
3737
import java.util.Map;
3838
import java.util.Objects;
3939
import java.util.Set;
40-
import java.util.concurrent.ConcurrentHashMap;
4140

41+
import static org.elasticsearch.TransportVersions.INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS;
4242
import static org.elasticsearch.index.IndexSettings.DEFAULT_FIELD_SETTING;
43+
import static org.elasticsearch.transport.RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY;
44+
import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.convertFromBwcInferenceResultsMap;
4345

4446
/**
4547
* <p>
@@ -60,7 +62,7 @@ public abstract class InterceptedInferenceQueryBuilder<T extends AbstractQueryBu
6062
public static final NodeFeature NEW_SEMANTIC_QUERY_INTERCEPTORS = new NodeFeature("search.new_semantic_query_interceptors");
6163

6264
protected final T originalQuery;
63-
protected final Map<String, InferenceResults> inferenceResultsMap;
65+
protected final Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap;
6466

6567
protected InterceptedInferenceQueryBuilder(T originalQuery) {
6668
Objects.requireNonNull(originalQuery, "original query must not be null");
@@ -72,12 +74,20 @@ protected InterceptedInferenceQueryBuilder(T originalQuery) {
7274
protected InterceptedInferenceQueryBuilder(StreamInput in) throws IOException {
7375
super(in);
7476
this.originalQuery = (T) in.readNamedWriteable(QueryBuilder.class);
75-
this.inferenceResultsMap = in.readOptional(i1 -> i1.readImmutableMap(i2 -> i2.readNamedWriteable(InferenceResults.class)));
77+
if (in.getTransportVersion().supports(INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS)) {
78+
this.inferenceResultsMap = in.readOptional(
79+
i1 -> i1.readImmutableMap(FullyQualifiedInferenceId::new, i2 -> i2.readNamedWriteable(InferenceResults.class))
80+
);
81+
} else {
82+
this.inferenceResultsMap = convertFromBwcInferenceResultsMap(
83+
in.readOptional(i1 -> i1.readImmutableMap(i2 -> i2.readNamedWriteable(InferenceResults.class)))
84+
);
85+
}
7686
}
7787

7888
protected InterceptedInferenceQueryBuilder(
7989
InterceptedInferenceQueryBuilder<T> other,
80-
Map<String, InferenceResults> inferenceResultsMap
90+
Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap
8191
) {
8292
this.originalQuery = other.originalQuery;
8393
this.inferenceResultsMap = inferenceResultsMap;
@@ -122,7 +132,7 @@ protected InterceptedInferenceQueryBuilder(
122132
* @param inferenceResultsMap The inference results map
123133
* @return A copy of {@code this} with the provided inference results map
124134
*/
125-
protected abstract QueryBuilder copy(Map<String, InferenceResults> inferenceResultsMap);
135+
protected abstract QueryBuilder copy(Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap);
126136

127137
/**
128138
* Rewrite to a {@link QueryBuilder} appropriate for a specific index's mappings. The implementation can use
@@ -168,7 +178,19 @@ protected void coordinatorNodeValidate(ResolvedIndices resolvedIndices) {}
168178
@Override
169179
protected void doWriteTo(StreamOutput out) throws IOException {
170180
out.writeNamedWriteable(originalQuery);
171-
out.writeOptional((o, v) -> o.writeMap(v, StreamOutput::writeNamedWriteable), inferenceResultsMap);
181+
if (out.getTransportVersion().supports(INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS)) {
182+
out.writeOptional(
183+
(o, v) -> o.writeMap(v, StreamOutput::writeWriteable, StreamOutput::writeNamedWriteable),
184+
inferenceResultsMap
185+
);
186+
} else {
187+
out.writeOptional((o1, v) -> o1.writeMap(v, (o2, id) -> {
188+
if (id.clusterAlias().equals(LOCAL_CLUSTER_GROUP_KEY) == false) {
189+
throw new IllegalArgumentException("Cannot serialize remote cluster inference results in a mixed-version cluster");
190+
}
191+
o2.writeString(id.inferenceId());
192+
}, StreamOutput::writeNamedWriteable), inferenceResultsMap);
193+
}
172194
}
173195

174196
@Override
@@ -227,11 +249,6 @@ private QueryBuilder doRewriteBuildQuery(QueryRewriteContext indexMetadataContex
227249
}
228250

229251
private QueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewriteContext) {
230-
if (this.inferenceResultsMap != null) {
231-
inferenceResultsErrorCheck(this.inferenceResultsMap);
232-
return this;
233-
}
234-
235252
QueryBuilder rewrittenBwC = doRewriteBwC(queryRewriteContext);
236253
if (rewrittenBwC != this) {
237254
return rewrittenBwC;
@@ -271,17 +288,27 @@ private QueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewri
271288
inferenceIds = Set.of(inferenceIdOverride);
272289
}
273290

274-
// If the query is null, there's nothing to generate inference results for. This can happen if pre-computed inference results are
275-
// provided by the user.
276-
String query = getQuery();
277-
Map<String, InferenceResults> inferenceResultsMap = new ConcurrentHashMap<>();
278-
if (query != null) {
279-
for (String inferenceId : inferenceIds) {
280-
SemanticQueryBuilder.registerInferenceAsyncAction(queryRewriteContext, inferenceResultsMap, query, inferenceId);
291+
QueryBuilder rewritten = this;
292+
if (queryRewriteContext.hasAsyncActions() == false) {
293+
// If the query is null, there's nothing to generate inference results for. This can happen if pre-computed inference results
294+
// are provided by the user. Ensure that we set an empty inference results map in this case so that it is always non-null after
295+
// coordinator node rewrite.
296+
Map<FullyQualifiedInferenceId, InferenceResults> modifiedInferenceResultsMap = SemanticQueryBuilder.getInferenceResults(
297+
queryRewriteContext,
298+
inferenceIds,
299+
this.inferenceResultsMap,
300+
getQuery()
301+
);
302+
303+
if (modifiedInferenceResultsMap == this.inferenceResultsMap) {
304+
// The inference results map is fully populated, so we can perform error checking
305+
inferenceResultsErrorCheck(modifiedInferenceResultsMap);
306+
} else {
307+
rewritten = copy(modifiedInferenceResultsMap);
281308
}
282309
}
283310

284-
return copy(inferenceResultsMap);
311+
return rewritten;
285312
}
286313

287314
private static Set<String> getInferenceIdsForFields(
@@ -360,9 +387,9 @@ private static void addToInferenceFieldsMap(Map<String, Float> inferenceFields,
360387
inferenceFields.compute(field, (k, v) -> v == null ? weight : v * weight);
361388
}
362389

363-
private static void inferenceResultsErrorCheck(Map<String, InferenceResults> inferenceResultsMap) {
390+
private static void inferenceResultsErrorCheck(Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap) {
364391
for (var entry : inferenceResultsMap.entrySet()) {
365-
String inferenceId = entry.getKey();
392+
String inferenceId = entry.getKey().inferenceId();
366393
InferenceResults inferenceResults = entry.getValue();
367394

368395
if (inferenceResults instanceof ErrorInferenceResults errorInferenceResults) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilder.java

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ public InterceptedInferenceSparseVectorQueryBuilder(StreamInput in) throws IOExc
4747
super(in);
4848
}
4949

50-
public InterceptedInferenceSparseVectorQueryBuilder(
50+
InterceptedInferenceSparseVectorQueryBuilder(
5151
InterceptedInferenceQueryBuilder<SparseVectorQueryBuilder> other,
52-
Map<String, InferenceResults> inferenceResultsMap
52+
Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap
5353
) {
5454
super(other, inferenceResultsMap);
5555
}
@@ -96,7 +96,7 @@ protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) {
9696
}
9797

9898
@Override
99-
protected QueryBuilder copy(Map<String, InferenceResults> inferenceResultsMap) {
99+
protected QueryBuilder copy(Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap) {
100100
return new InterceptedInferenceSparseVectorQueryBuilder(this, inferenceResultsMap);
101101
}
102102

@@ -111,9 +111,9 @@ protected QueryBuilder queryFields(
111111
if (fieldType == null) {
112112
rewritten = new MatchNoneQueryBuilder();
113113
} else if (fieldType instanceof SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType) {
114-
rewritten = querySemanticTextField(semanticTextFieldType);
114+
rewritten = querySemanticTextField(indexMetadataContext.getLocalClusterAlias(), semanticTextFieldType);
115115
} else {
116-
rewritten = queryNonSemanticTextField();
116+
rewritten = queryNonSemanticTextField(indexMetadataContext.getLocalClusterAlias());
117117
}
118118

119119
return rewritten;
@@ -138,7 +138,7 @@ private String getField() {
138138
return originalQuery.getFieldName();
139139
}
140140

141-
private QueryBuilder querySemanticTextField(SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType) {
141+
private QueryBuilder querySemanticTextField(String clusterAlias, SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType) {
142142
MinimalServiceSettings modelSettings = semanticTextFieldType.getModelSettings();
143143
if (modelSettings == null) {
144144
// No inference results have been indexed yet
@@ -154,7 +154,7 @@ private QueryBuilder querySemanticTextField(SemanticTextFieldMapper.SemanticText
154154
inferenceId = semanticTextFieldType.getSearchInferenceId();
155155
}
156156

157-
queryVector = getQueryVector(inferenceId);
157+
queryVector = getQueryVector(clusterAlias, inferenceId);
158158
}
159159

160160
SparseVectorQueryBuilder innerSparseVectorQuery = new SparseVectorQueryBuilder(
@@ -171,15 +171,15 @@ private QueryBuilder querySemanticTextField(SemanticTextFieldMapper.SemanticText
171171
.queryName(originalQuery.queryName());
172172
}
173173

174-
private QueryBuilder queryNonSemanticTextField() {
174+
private QueryBuilder queryNonSemanticTextField(String clusterAlias) {
175175
List<WeightedToken> queryVector = originalQuery.getQueryVectors();
176176
if (queryVector == null) {
177177
String inferenceId = originalQuery.getInferenceId();
178178
if (inferenceId == null) {
179179
throw new IllegalArgumentException("Either query vector or inference ID must be specified");
180180
}
181181

182-
queryVector = getQueryVector(inferenceId);
182+
queryVector = getQueryVector(clusterAlias, inferenceId);
183183
}
184184

185185
return new SparseVectorQueryBuilder(
@@ -192,8 +192,8 @@ private QueryBuilder queryNonSemanticTextField() {
192192
).boost(originalQuery.boost()).queryName(originalQuery.queryName());
193193
}
194194

195-
private List<WeightedToken> getQueryVector(String inferenceId) {
196-
InferenceResults inferenceResults = inferenceResultsMap.get(inferenceId);
195+
private List<WeightedToken> getQueryVector(String clusterAlias, String inferenceId) {
196+
InferenceResults inferenceResults = inferenceResultsMap.get(new FullyQualifiedInferenceId(clusterAlias, inferenceId));
197197
if (inferenceResults == null) {
198198
throw new IllegalStateException("Could not find inference results from inference endpoint [" + inferenceId + "]");
199199
} else if (inferenceResults instanceof TextExpansionResults == false) {

0 commit comments

Comments
 (0)