Skip to content

Commit d3937c1

Browse files
committed
Perform inference on remote cluster when necessary
1 parent cd878bd commit d3937c1

File tree

1 file changed

+55
-42
lines changed

1 file changed

+55
-42
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java

Lines changed: 55 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
5353
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
5454

55-
// TODO: Add flag to perform inference again during remote cluster coordinator rewrite
55+
// TODO: Remove the noInferenceResults flag; doRewriteGetInferenceResults no longer sets it to true
5656

5757
public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuilder> {
5858
public static final String NAME = "semantic";
@@ -242,7 +242,9 @@ private QueryBuilder doRewriteBuildSemanticQuery(SearchExecutionContext searchEx
242242
}
243243

244244
private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewriteContext) {
245-
if (embeddingsProvider != null || noInferenceResults) {
245+
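// Only proceed during a coordinator node rewrite, detected by the context being exactly QueryRewriteContext (not a subclass); otherwise return this query unchanged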
246+
// TODO: Clean up how we perform this check
247+
if (queryRewriteContext.getClass() != QueryRewriteContext.class) {
246248
return this;
247249
}
248250

@@ -257,54 +259,65 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu
257259
}
258260
}
259261

260-
Set<String> inferenceIds = getInferenceIdsForForField(resolvedIndices.getConcreteLocalIndicesMetadata().values(), fieldName);
261-
MapEmbeddingsProvider mapEmbeddingsProvider = new MapEmbeddingsProvider();
262-
263-
// The inference ID set can be empty if either the field name or index name(s) are invalid (or both).
264-
// If this happens, we set the "no inference results" flag to true so the rewrite process can continue.
265-
// Invalid index names will be handled in the transport layer, when the query is sent to the shard.
266-
// Invalid field names will be handled when the query is re-written on the shard, where we have access to the index mappings.
267-
boolean noInferenceResults = inferenceIds.isEmpty();
268-
269-
for (String inferenceId : inferenceIds) {
270-
InferenceAction.Request inferenceRequest = new InferenceAction.Request(
271-
TaskType.ANY,
272-
inferenceId,
273-
null,
274-
null,
275-
null,
276-
List.of(query),
277-
Map.of(),
278-
InputType.INTERNAL_SEARCH,
279-
null,
280-
false
281-
);
262+
MapEmbeddingsProvider currentEmbeddingsProvider;
263+
if (embeddingsProvider != null) {
264+
if (embeddingsProvider instanceof MapEmbeddingsProvider mapEmbeddingsProvider) {
265+
currentEmbeddingsProvider = mapEmbeddingsProvider;
266+
} else {
267+
throw new IllegalStateException("Current embeddings provider should be a MapEmbeddingsProvider");
268+
}
269+
} else {
270+
currentEmbeddingsProvider = new MapEmbeddingsProvider();
271+
}
282272

273+
boolean modified = false;
274+
if (queryRewriteContext.hasAsyncActions() == false) {
283275
ModelRegistry modelRegistry = MODEL_REGISTRY_SUPPLIER.get();
284276
if (modelRegistry == null) {
285277
throw new IllegalStateException("Model registry has not been set");
286278
}
287279

288-
MinimalServiceSettings serviceSettings = modelRegistry.getMinimalServiceSettings(inferenceId);
289-
InferenceEndpointKey inferenceEndpointKey = new InferenceEndpointKey(inferenceId, serviceSettings);
290-
queryRewriteContext.registerAsyncAction(
291-
(client, listener) -> executeAsyncWithOrigin(
292-
client,
293-
ML_ORIGIN,
294-
InferenceAction.INSTANCE,
295-
inferenceRequest,
296-
listener.delegateFailureAndWrap((l, inferenceResponse) -> {
297-
mapEmbeddingsProvider.addEmbeddings(
298-
inferenceEndpointKey,
299-
validateAndConvertInferenceResults(inferenceResponse.getResults(), fieldName, inferenceId)
300-
);
301-
l.onResponse(null);
302-
})
303-
)
304-
);
280+
Set<String> inferenceIds = getInferenceIdsForForField(resolvedIndices.getConcreteLocalIndicesMetadata().values(), fieldName);
281+
for (String inferenceId : inferenceIds) {
282+
MinimalServiceSettings serviceSettings = modelRegistry.getMinimalServiceSettings(inferenceId);
283+
InferenceEndpointKey inferenceEndpointKey = new InferenceEndpointKey(inferenceId, serviceSettings);
284+
285+
if (currentEmbeddingsProvider.getEmbeddings(inferenceEndpointKey) == null) {
286+
InferenceAction.Request inferenceRequest = new InferenceAction.Request(
287+
TaskType.ANY,
288+
inferenceId,
289+
null,
290+
null,
291+
null,
292+
List.of(query),
293+
Map.of(),
294+
InputType.INTERNAL_SEARCH,
295+
null,
296+
false
297+
);
298+
299+
queryRewriteContext.registerAsyncAction(
300+
(client, listener) -> executeAsyncWithOrigin(
301+
client,
302+
ML_ORIGIN,
303+
InferenceAction.INSTANCE,
304+
inferenceRequest,
305+
listener.delegateFailureAndWrap((l, inferenceResponse) -> {
306+
currentEmbeddingsProvider.addEmbeddings(
307+
inferenceEndpointKey,
308+
validateAndConvertInferenceResults(inferenceResponse.getResults(), fieldName, inferenceId)
309+
);
310+
l.onResponse(null);
311+
})
312+
)
313+
);
314+
315+
modified = true;
316+
}
317+
}
305318
}
306319

307-
return new SemanticQueryBuilder(this, noInferenceResults ? null : mapEmbeddingsProvider, noInferenceResults);
320+
return modified ? new SemanticQueryBuilder(this, currentEmbeddingsProvider, false) : this;
308321
}
309322

310323
private static InferenceResults validateAndConvertInferenceResults(

0 commit comments

Comments
 (0)