37
37
import java .util .Map ;
38
38
import java .util .Objects ;
39
39
import java .util .Set ;
40
- import java .util .concurrent .ConcurrentHashMap ;
41
40
41
+ import static org .elasticsearch .TransportVersions .INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS ;
42
42
import static org .elasticsearch .index .IndexSettings .DEFAULT_FIELD_SETTING ;
43
+ import static org .elasticsearch .transport .RemoteClusterAware .LOCAL_CLUSTER_GROUP_KEY ;
44
+ import static org .elasticsearch .xpack .inference .queries .SemanticQueryBuilder .convertFromBwcInferenceResultsMap ;
43
45
44
46
/**
45
47
* <p>
@@ -60,7 +62,7 @@ public abstract class InterceptedInferenceQueryBuilder<T extends AbstractQueryBu
60
62
public static final NodeFeature NEW_SEMANTIC_QUERY_INTERCEPTORS = new NodeFeature ("search.new_semantic_query_interceptors" );
61
63
62
64
protected final T originalQuery ;
63
- protected final Map <String , InferenceResults > inferenceResultsMap ;
65
+ protected final Map <FullyQualifiedInferenceId , InferenceResults > inferenceResultsMap ;
64
66
65
67
protected InterceptedInferenceQueryBuilder (T originalQuery ) {
66
68
Objects .requireNonNull (originalQuery , "original query must not be null" );
@@ -72,12 +74,20 @@ protected InterceptedInferenceQueryBuilder(T originalQuery) {
72
74
protected InterceptedInferenceQueryBuilder (StreamInput in ) throws IOException {
73
75
super (in );
74
76
this .originalQuery = (T ) in .readNamedWriteable (QueryBuilder .class );
75
- this .inferenceResultsMap = in .readOptional (i1 -> i1 .readImmutableMap (i2 -> i2 .readNamedWriteable (InferenceResults .class )));
77
+ if (in .getTransportVersion ().supports (INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS )) {
78
+ this .inferenceResultsMap = in .readOptional (
79
+ i1 -> i1 .readImmutableMap (FullyQualifiedInferenceId ::new , i2 -> i2 .readNamedWriteable (InferenceResults .class ))
80
+ );
81
+ } else {
82
+ this .inferenceResultsMap = convertFromBwcInferenceResultsMap (
83
+ in .readOptional (i1 -> i1 .readImmutableMap (i2 -> i2 .readNamedWriteable (InferenceResults .class )))
84
+ );
85
+ }
76
86
}
77
87
78
88
protected InterceptedInferenceQueryBuilder (
79
89
InterceptedInferenceQueryBuilder <T > other ,
80
- Map <String , InferenceResults > inferenceResultsMap
90
+ Map <FullyQualifiedInferenceId , InferenceResults > inferenceResultsMap
81
91
) {
82
92
this .originalQuery = other .originalQuery ;
83
93
this .inferenceResultsMap = inferenceResultsMap ;
@@ -122,7 +132,7 @@ protected InterceptedInferenceQueryBuilder(
122
132
* @param inferenceResultsMap The inference results map
123
133
* @return A copy of {@code this} with the provided inference results map
124
134
*/
125
- protected abstract QueryBuilder copy (Map <String , InferenceResults > inferenceResultsMap );
135
+ protected abstract QueryBuilder copy (Map <FullyQualifiedInferenceId , InferenceResults > inferenceResultsMap );
126
136
127
137
/**
128
138
* Rewrite to a {@link QueryBuilder} appropriate for a specific index's mappings. The implementation can use
@@ -168,7 +178,19 @@ protected void coordinatorNodeValidate(ResolvedIndices resolvedIndices) {}
168
178
@ Override
169
179
protected void doWriteTo (StreamOutput out ) throws IOException {
170
180
out .writeNamedWriteable (originalQuery );
171
- out .writeOptional ((o , v ) -> o .writeMap (v , StreamOutput ::writeNamedWriteable ), inferenceResultsMap );
181
+ if (out .getTransportVersion ().supports (INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS )) {
182
+ out .writeOptional (
183
+ (o , v ) -> o .writeMap (v , StreamOutput ::writeWriteable , StreamOutput ::writeNamedWriteable ),
184
+ inferenceResultsMap
185
+ );
186
+ } else {
187
+ out .writeOptional ((o1 , v ) -> o1 .writeMap (v , (o2 , id ) -> {
188
+ if (id .clusterAlias ().equals (LOCAL_CLUSTER_GROUP_KEY ) == false ) {
189
+ throw new IllegalArgumentException ("Cannot serialize remote cluster inference results in a mixed-version cluster" );
190
+ }
191
+ o2 .writeString (id .inferenceId ());
192
+ }, StreamOutput ::writeNamedWriteable ), inferenceResultsMap );
193
+ }
172
194
}
173
195
174
196
@ Override
@@ -227,11 +249,6 @@ private QueryBuilder doRewriteBuildQuery(QueryRewriteContext indexMetadataContex
227
249
}
228
250
229
251
private QueryBuilder doRewriteGetInferenceResults (QueryRewriteContext queryRewriteContext ) {
230
- if (this .inferenceResultsMap != null ) {
231
- inferenceResultsErrorCheck (this .inferenceResultsMap );
232
- return this ;
233
- }
234
-
235
252
QueryBuilder rewrittenBwC = doRewriteBwC (queryRewriteContext );
236
253
if (rewrittenBwC != this ) {
237
254
return rewrittenBwC ;
@@ -271,17 +288,27 @@ private QueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewri
271
288
inferenceIds = Set .of (inferenceIdOverride );
272
289
}
273
290
274
- // If the query is null, there's nothing to generate inference results for. This can happen if pre-computed inference results are
275
- // provided by the user.
276
- String query = getQuery ();
277
- Map <String , InferenceResults > inferenceResultsMap = new ConcurrentHashMap <>();
278
- if (query != null ) {
279
- for (String inferenceId : inferenceIds ) {
280
- SemanticQueryBuilder .registerInferenceAsyncAction (queryRewriteContext , inferenceResultsMap , query , inferenceId );
291
+ QueryBuilder rewritten = this ;
292
+ if (queryRewriteContext .hasAsyncActions () == false ) {
293
+ // If the query is null, there's nothing to generate inference results for. This can happen if pre-computed inference results
294
+ // are provided by the user. Ensure that we set an empty inference results map in this case so that it is always non-null after
295
+ // coordinator node rewrite.
296
+ Map <FullyQualifiedInferenceId , InferenceResults > modifiedInferenceResultsMap = SemanticQueryBuilder .getInferenceResults (
297
+ queryRewriteContext ,
298
+ inferenceIds ,
299
+ this .inferenceResultsMap ,
300
+ getQuery ()
301
+ );
302
+
303
+ if (modifiedInferenceResultsMap == this .inferenceResultsMap ) {
304
+ // The inference results map is fully populated, so we can perform error checking
305
+ inferenceResultsErrorCheck (modifiedInferenceResultsMap );
306
+ } else {
307
+ rewritten = copy (modifiedInferenceResultsMap );
281
308
}
282
309
}
283
310
284
- return copy ( inferenceResultsMap ) ;
311
+ return rewritten ;
285
312
}
286
313
287
314
private static Set <String > getInferenceIdsForFields (
@@ -360,9 +387,9 @@ private static void addToInferenceFieldsMap(Map<String, Float> inferenceFields,
360
387
inferenceFields .compute (field , (k , v ) -> v == null ? weight : v * weight );
361
388
}
362
389
363
- private static void inferenceResultsErrorCheck (Map <String , InferenceResults > inferenceResultsMap ) {
390
+ private static void inferenceResultsErrorCheck (Map <FullyQualifiedInferenceId , InferenceResults > inferenceResultsMap ) {
364
391
for (var entry : inferenceResultsMap .entrySet ()) {
365
- String inferenceId = entry .getKey ();
392
+ String inferenceId = entry .getKey (). inferenceId () ;
366
393
InferenceResults inferenceResults = entry .getValue ();
367
394
368
395
if (inferenceResults instanceof ErrorInferenceResults errorInferenceResults ) {
0 commit comments