3737import java .util .Map ;
3838import java .util .Objects ;
3939import java .util .Set ;
40- import java .util .concurrent .ConcurrentHashMap ;
4140
41+ import static org .elasticsearch .TransportVersions .INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS ;
4242import static org .elasticsearch .index .IndexSettings .DEFAULT_FIELD_SETTING ;
43+ import static org .elasticsearch .transport .RemoteClusterAware .LOCAL_CLUSTER_GROUP_KEY ;
44+ import static org .elasticsearch .xpack .inference .queries .SemanticQueryBuilder .convertFromBwcInferenceResultsMap ;
4345
4446/**
4547 * <p>
@@ -60,7 +62,7 @@ public abstract class InterceptedInferenceQueryBuilder<T extends AbstractQueryBu
6062 public static final NodeFeature NEW_SEMANTIC_QUERY_INTERCEPTORS = new NodeFeature ("search.new_semantic_query_interceptors" );
6163
6264 protected final T originalQuery ;
63- protected final Map <String , InferenceResults > inferenceResultsMap ;
65+ protected final Map <FullyQualifiedInferenceId , InferenceResults > inferenceResultsMap ;
6466
6567 protected InterceptedInferenceQueryBuilder (T originalQuery ) {
6668 Objects .requireNonNull (originalQuery , "original query must not be null" );
@@ -72,12 +74,20 @@ protected InterceptedInferenceQueryBuilder(T originalQuery) {
7274 protected InterceptedInferenceQueryBuilder (StreamInput in ) throws IOException {
7375 super (in );
7476 this .originalQuery = (T ) in .readNamedWriteable (QueryBuilder .class );
75- this .inferenceResultsMap = in .readOptional (i1 -> i1 .readImmutableMap (i2 -> i2 .readNamedWriteable (InferenceResults .class )));
77+ if (in .getTransportVersion ().supports (INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS )) {
78+ this .inferenceResultsMap = in .readOptional (
79+ i1 -> i1 .readImmutableMap (FullyQualifiedInferenceId ::new , i2 -> i2 .readNamedWriteable (InferenceResults .class ))
80+ );
81+ } else {
82+ this .inferenceResultsMap = convertFromBwcInferenceResultsMap (
83+ in .readOptional (i1 -> i1 .readImmutableMap (i2 -> i2 .readNamedWriteable (InferenceResults .class )))
84+ );
85+ }
7686 }
7787
7888 protected InterceptedInferenceQueryBuilder (
7989 InterceptedInferenceQueryBuilder <T > other ,
80- Map <String , InferenceResults > inferenceResultsMap
90+ Map <FullyQualifiedInferenceId , InferenceResults > inferenceResultsMap
8191 ) {
8292 this .originalQuery = other .originalQuery ;
8393 this .inferenceResultsMap = inferenceResultsMap ;
@@ -122,7 +132,7 @@ protected InterceptedInferenceQueryBuilder(
122132 * @param inferenceResultsMap The inference results map
123133 * @return A copy of {@code this} with the provided inference results map
124134 */
125- protected abstract QueryBuilder copy (Map <String , InferenceResults > inferenceResultsMap );
135+ protected abstract QueryBuilder copy (Map <FullyQualifiedInferenceId , InferenceResults > inferenceResultsMap );
126136
127137 /**
128138 * Rewrite to a {@link QueryBuilder} appropriate for a specific index's mappings. The implementation can use
@@ -168,7 +178,19 @@ protected void coordinatorNodeValidate(ResolvedIndices resolvedIndices) {}
168178 @ Override
169179 protected void doWriteTo (StreamOutput out ) throws IOException {
170180 out .writeNamedWriteable (originalQuery );
171- out .writeOptional ((o , v ) -> o .writeMap (v , StreamOutput ::writeNamedWriteable ), inferenceResultsMap );
181+ if (out .getTransportVersion ().supports (INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS )) {
182+ out .writeOptional (
183+ (o , v ) -> o .writeMap (v , StreamOutput ::writeWriteable , StreamOutput ::writeNamedWriteable ),
184+ inferenceResultsMap
185+ );
186+ } else {
187+ out .writeOptional ((o1 , v ) -> o1 .writeMap (v , (o2 , id ) -> {
188+ if (id .clusterAlias ().equals (LOCAL_CLUSTER_GROUP_KEY ) == false ) {
189+ throw new IllegalArgumentException ("Cannot serialize remote cluster inference results in a mixed-version cluster" );
190+ }
191+ o2 .writeString (id .inferenceId ());
192+ }, StreamOutput ::writeNamedWriteable ), inferenceResultsMap );
193+ }
172194 }
173195
174196 @ Override
@@ -227,11 +249,6 @@ private QueryBuilder doRewriteBuildQuery(QueryRewriteContext indexMetadataContex
227249 }
228250
229251 private QueryBuilder doRewriteGetInferenceResults (QueryRewriteContext queryRewriteContext ) {
230- if (this .inferenceResultsMap != null ) {
231- inferenceResultsErrorCheck (this .inferenceResultsMap );
232- return this ;
233- }
234-
235252 QueryBuilder rewrittenBwC = doRewriteBwC (queryRewriteContext );
236253 if (rewrittenBwC != this ) {
237254 return rewrittenBwC ;
@@ -271,17 +288,27 @@ private QueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewri
271288 inferenceIds = Set .of (inferenceIdOverride );
272289 }
273290
274- // If the query is null, there's nothing to generate inference results for. This can happen if pre-computed inference results are
275- // provided by the user.
276- String query = getQuery ();
277- Map <String , InferenceResults > inferenceResultsMap = new ConcurrentHashMap <>();
278- if (query != null ) {
279- for (String inferenceId : inferenceIds ) {
280- SemanticQueryBuilder .registerInferenceAsyncAction (queryRewriteContext , inferenceResultsMap , query , inferenceId );
291+ QueryBuilder rewritten = this ;
292+ if (queryRewriteContext .hasAsyncActions () == false ) {
293+ // If the query is null, there's nothing to generate inference results for. This can happen if pre-computed inference results
294+ // are provided by the user. Ensure that we set an empty inference results map in this case so that it is always non-null after
295+ // coordinator node rewrite.
296+ Map <FullyQualifiedInferenceId , InferenceResults > modifiedInferenceResultsMap = SemanticQueryBuilder .getInferenceResults (
297+ queryRewriteContext ,
298+ inferenceIds ,
299+ this .inferenceResultsMap ,
300+ getQuery ()
301+ );
302+
303+ if (modifiedInferenceResultsMap == this .inferenceResultsMap ) {
304+ // The inference results map is fully populated, so we can perform error checking
305+ inferenceResultsErrorCheck (modifiedInferenceResultsMap );
306+ } else {
307+ rewritten = copy (modifiedInferenceResultsMap );
281308 }
282309 }
283310
284- return copy ( inferenceResultsMap ) ;
311+ return rewritten ;
285312 }
286313
287314 private static Set <String > getInferenceIdsForFields (
@@ -360,9 +387,9 @@ private static void addToInferenceFieldsMap(Map<String, Float> inferenceFields,
360387 inferenceFields .compute (field , (k , v ) -> v == null ? weight : v * weight );
361388 }
362389
363- private static void inferenceResultsErrorCheck (Map <String , InferenceResults > inferenceResultsMap ) {
390+ private static void inferenceResultsErrorCheck (Map <FullyQualifiedInferenceId , InferenceResults > inferenceResultsMap ) {
364391 for (var entry : inferenceResultsMap .entrySet ()) {
365- String inferenceId = entry .getKey ();
392+ String inferenceId = entry .getKey (). inferenceId () ;
366393 InferenceResults inferenceResults = entry .getValue ();
367394
368395 if (inferenceResults instanceof ErrorInferenceResults errorInferenceResults ) {
0 commit comments