5252import static org .elasticsearch .xpack .core .ClientHelper .ML_ORIGIN ;
5353import static org .elasticsearch .xpack .core .ClientHelper .executeAsyncWithOrigin ;
5454
55- // TODO: Add flag to perform inference again during remote cluster coordinator rewrite
55+ // TODO: Remove noInferenceResults
5656
5757public class SemanticQueryBuilder extends AbstractQueryBuilder <SemanticQueryBuilder > {
5858 public static final String NAME = "semantic" ;
@@ -242,7 +242,9 @@ private QueryBuilder doRewriteBuildSemanticQuery(SearchExecutionContext searchEx
242242 }
243243
244244 private SemanticQueryBuilder doRewriteGetInferenceResults (QueryRewriteContext queryRewriteContext ) {
245- if (embeddingsProvider != null || noInferenceResults ) {
245+ // Check that we are performing a coordinator node rewrite
246+ // TODO: Clean up how we perform this check
247+ if (queryRewriteContext .getClass () != QueryRewriteContext .class ) {
246248 return this ;
247249 }
248250
@@ -257,54 +259,65 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu
257259 }
258260 }
259261
260- Set <String > inferenceIds = getInferenceIdsForForField (resolvedIndices .getConcreteLocalIndicesMetadata ().values (), fieldName );
261- MapEmbeddingsProvider mapEmbeddingsProvider = new MapEmbeddingsProvider ();
262-
263- // The inference ID set can be empty if either the field name or index name(s) are invalid (or both).
264- // If this happens, we set the "no inference results" flag to true so the rewrite process can continue.
265- // Invalid index names will be handled in the transport layer, when the query is sent to the shard.
266- // Invalid field names will be handled when the query is re-written on the shard, where we have access to the index mappings.
267- boolean noInferenceResults = inferenceIds .isEmpty ();
268-
269- for (String inferenceId : inferenceIds ) {
270- InferenceAction .Request inferenceRequest = new InferenceAction .Request (
271- TaskType .ANY ,
272- inferenceId ,
273- null ,
274- null ,
275- null ,
276- List .of (query ),
277- Map .of (),
278- InputType .INTERNAL_SEARCH ,
279- null ,
280- false
281- );
262+ MapEmbeddingsProvider currentEmbeddingsProvider ;
263+ if (embeddingsProvider != null ) {
264+ if (embeddingsProvider instanceof MapEmbeddingsProvider mapEmbeddingsProvider ) {
265+ currentEmbeddingsProvider = mapEmbeddingsProvider ;
266+ } else {
267+ throw new IllegalStateException ("Current embeddings provider should be a MapEmbeddingsProvider" );
268+ }
269+ } else {
270+ currentEmbeddingsProvider = new MapEmbeddingsProvider ();
271+ }
282272
273+ boolean modified = false ;
274+ if (queryRewriteContext .hasAsyncActions () == false ) {
283275 ModelRegistry modelRegistry = MODEL_REGISTRY_SUPPLIER .get ();
284276 if (modelRegistry == null ) {
285277 throw new IllegalStateException ("Model registry has not been set" );
286278 }
287279
288- MinimalServiceSettings serviceSettings = modelRegistry .getMinimalServiceSettings (inferenceId );
289- InferenceEndpointKey inferenceEndpointKey = new InferenceEndpointKey (inferenceId , serviceSettings );
290- queryRewriteContext .registerAsyncAction (
291- (client , listener ) -> executeAsyncWithOrigin (
292- client ,
293- ML_ORIGIN ,
294- InferenceAction .INSTANCE ,
295- inferenceRequest ,
296- listener .delegateFailureAndWrap ((l , inferenceResponse ) -> {
297- mapEmbeddingsProvider .addEmbeddings (
298- inferenceEndpointKey ,
299- validateAndConvertInferenceResults (inferenceResponse .getResults (), fieldName , inferenceId )
300- );
301- l .onResponse (null );
302- })
303- )
304- );
280+ Set <String > inferenceIds = getInferenceIdsForForField (resolvedIndices .getConcreteLocalIndicesMetadata ().values (), fieldName );
281+ for (String inferenceId : inferenceIds ) {
282+ MinimalServiceSettings serviceSettings = modelRegistry .getMinimalServiceSettings (inferenceId );
283+ InferenceEndpointKey inferenceEndpointKey = new InferenceEndpointKey (inferenceId , serviceSettings );
284+
285+ if (currentEmbeddingsProvider .getEmbeddings (inferenceEndpointKey ) == null ) {
286+ InferenceAction .Request inferenceRequest = new InferenceAction .Request (
287+ TaskType .ANY ,
288+ inferenceId ,
289+ null ,
290+ null ,
291+ null ,
292+ List .of (query ),
293+ Map .of (),
294+ InputType .INTERNAL_SEARCH ,
295+ null ,
296+ false
297+ );
298+
299+ queryRewriteContext .registerAsyncAction (
300+ (client , listener ) -> executeAsyncWithOrigin (
301+ client ,
302+ ML_ORIGIN ,
303+ InferenceAction .INSTANCE ,
304+ inferenceRequest ,
305+ listener .delegateFailureAndWrap ((l , inferenceResponse ) -> {
306+ currentEmbeddingsProvider .addEmbeddings (
307+ inferenceEndpointKey ,
308+ validateAndConvertInferenceResults (inferenceResponse .getResults (), fieldName , inferenceId )
309+ );
310+ l .onResponse (null );
311+ })
312+ )
313+ );
314+
315+ modified = true ;
316+ }
317+ }
305318 }
306319
307- return new SemanticQueryBuilder (this , noInferenceResults ? null : mapEmbeddingsProvider , noInferenceResults ) ;
320+ return modified ? new SemanticQueryBuilder (this , currentEmbeddingsProvider , false ) : this ;
308321 }
309322
310323 private static InferenceResults validateAndConvertInferenceResults (
0 commit comments