4545import java .util .Map ;
4646import java .util .Objects ;
4747import java .util .Set ;
48+ import java .util .concurrent .ConcurrentHashMap ;
4849
4950import static org .elasticsearch .xcontent .ConstructingObjectParser .constructorArg ;
5051import static org .elasticsearch .xcontent .ConstructingObjectParser .optionalConstructorArg ;
@@ -56,6 +57,9 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
5657
5758 public static final NodeFeature SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS = new NodeFeature ("semantic_query.multiple_inference_ids" );
5859
60+ // Use a placeholder inference ID that will never overlap with a real inference endpoint (user-created or internal)
61+ private static final String PLACEHOLDER_INFERENCE_ID = "$PLACEHOLDER" ;
62+
5963 private static final ParseField FIELD_FIELD = new ParseField ("field" );
6064 private static final ParseField QUERY_FIELD = new ParseField ("query" );
6165 private static final ParseField LENIENT_FIELD = new ParseField ("lenient" );
@@ -75,7 +79,7 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
7579
7680 private final String fieldName ;
7781 private final String query ;
78- private final InferenceResultsProvider inferenceResultsProvider ;
82+ private final Map < String , InferenceResults > inferenceResultsMap ;
7983 private final Boolean lenient ;
8084
8185 public SemanticQueryBuilder (String fieldName , String query ) {
@@ -86,7 +90,7 @@ public SemanticQueryBuilder(String fieldName, String query, Boolean lenient) {
8690 this (fieldName , query , lenient , null );
8791 }
8892
89- protected SemanticQueryBuilder (String fieldName , String query , Boolean lenient , InferenceResultsProvider inferenceResultsProvider ) {
93+ protected SemanticQueryBuilder (String fieldName , String query , Boolean lenient , Map < String , InferenceResults > inferenceResultsMap ) {
9094 if (fieldName == null ) {
9195 throw new IllegalArgumentException ("[" + NAME + "] requires a " + FIELD_FIELD .getPreferredName () + " value" );
9296 }
@@ -95,7 +99,7 @@ protected SemanticQueryBuilder(String fieldName, String query, Boolean lenient,
9599 }
96100 this .fieldName = fieldName ;
97101 this .query = query ;
98- this .inferenceResultsProvider = inferenceResultsProvider ;
102+ this .inferenceResultsMap = inferenceResultsMap != null ? Map . copyOf ( inferenceResultsMap ) : null ;
99103 this .lenient = lenient ;
100104 }
101105
@@ -104,10 +108,10 @@ public SemanticQueryBuilder(StreamInput in) throws IOException {
104108 this .fieldName = in .readString ();
105109 this .query = in .readString ();
106110 if (in .getTransportVersion ().onOrAfter (TransportVersions .SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS )) {
107- this .inferenceResultsProvider = in .readOptionalNamedWriteable ( InferenceResultsProvider . class );
111+ this .inferenceResultsMap = in .readOptional ( i1 -> i1 . readImmutableMap ( i2 -> i2 . readNamedWriteable ( InferenceResults . class )) );
108112 } else {
109113 InferenceResults inferenceResults = in .readOptionalNamedWriteable (InferenceResults .class );
110- this .inferenceResultsProvider = inferenceResults != null ? new SingleInferenceResultsProvider (inferenceResults ) : null ;
114+ this .inferenceResultsMap = inferenceResults != null ? buildBwcInferenceResultsMap (inferenceResults ) : null ;
111115 in .readBoolean (); // Discard noInferenceResults, it is no longer necessary
112116 }
113117 if (in .getTransportVersion ().onOrAfter (TransportVersions .SEMANTIC_QUERY_LENIENT )) {
@@ -122,15 +126,14 @@ protected void doWriteTo(StreamOutput out) throws IOException {
122126 out .writeString (fieldName );
123127 out .writeString (query );
124128 if (out .getTransportVersion ().onOrAfter (TransportVersions .SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS )) {
125- out .writeOptionalNamedWriteable ( inferenceResultsProvider );
129+ out .writeOptional (( o , v ) -> o . writeMap ( v , StreamOutput :: writeNamedWriteable ), inferenceResultsMap );
126130 } else {
127131 InferenceResults inferenceResults = null ;
128- if (inferenceResultsProvider != null ) {
129- Collection <InferenceResults > allInferenceResults = inferenceResultsProvider .getAllInferenceResults ();
130- if (allInferenceResults .size () > 1 ) {
132+ if (inferenceResultsMap != null ) {
133+ if (inferenceResultsMap .size () > 1 ) {
131134 throw new IllegalArgumentException ("Cannot query multiple inference IDs in a mixed-version cluster" );
132- } else if (allInferenceResults .size () == 1 ) {
133- inferenceResults = allInferenceResults .iterator ().next ();
135+ } else if (inferenceResultsMap .size () == 1 ) {
136+ inferenceResults = inferenceResultsMap . values () .iterator ().next ();
134137 }
135138 }
136139
@@ -142,12 +145,13 @@ protected void doWriteTo(StreamOutput out) throws IOException {
142145 }
143146 }
144147
145- private SemanticQueryBuilder (SemanticQueryBuilder other , InferenceResultsProvider inferenceResultsProvider ) {
148+ private SemanticQueryBuilder (SemanticQueryBuilder other , Map < String , InferenceResults > inferenceResultsMap ) {
146149 this .fieldName = other .fieldName ;
147150 this .query = other .query ;
148151 this .boost = other .boost ;
149152 this .queryName = other .queryName ;
150- this .inferenceResultsProvider = inferenceResultsProvider ;
153+ // No need to copy the map here since this is only called internally. We can safely assume that the caller will not modify the map.
154+ this .inferenceResultsMap = inferenceResultsMap ;
151155 this .lenient = other .lenient ;
152156 }
153157
@@ -173,6 +177,27 @@ public static SemanticQueryBuilder fromXContent(XContentParser parser) throws IO
173177 return PARSER .apply (parser , null );
174178 }
175179
180+ /**
181+ * Build an inference results map to store a single inference result that is not associated with an inference ID.
182+ *
183+ * @param inferenceResults The inference result
184+ * @return An inference results map
185+ */
186+ protected static Map <String , InferenceResults > buildBwcInferenceResultsMap (InferenceResults inferenceResults ) {
187+ return Map .of (PLACEHOLDER_INFERENCE_ID , inferenceResults );
188+ }
189+
190+ /**
191+ * Extract an inference result not associated with an inference ID from an inference results map. Returns null if no such inference
192+ * result exists in the map.
193+ *
194+ * @param inferenceResultsMap The inference results map
195+ * @return The inference result
196+ */
197+ private static InferenceResults getBwcInferenceResults (Map <String , InferenceResults > inferenceResultsMap ) {
198+ return inferenceResultsMap .get (PLACEHOLDER_INFERENCE_ID );
199+ }
200+
176201 @ Override
177202 protected void doXContent (XContentBuilder builder , Params params ) throws IOException {
178203 builder .startObject (NAME );
@@ -200,15 +225,19 @@ private QueryBuilder doRewriteBuildSemanticQuery(SearchExecutionContext searchEx
200225 if (fieldType == null ) {
201226 return new MatchNoneQueryBuilder ();
202227 } else if (fieldType instanceof SemanticTextFieldMapper .SemanticTextFieldType semanticTextFieldType ) {
203- if (inferenceResultsProvider == null ) {
228+ if (inferenceResultsMap == null ) {
204229 // This should never happen, but throw on it in case it ever does
205230 throw new IllegalStateException (
206231 "No inference results set for [" + semanticTextFieldType .typeName () + "] field [" + fieldName + "]"
207232 );
208233 }
209234
210235 String inferenceId = semanticTextFieldType .getSearchInferenceId ();
211- InferenceResults inferenceResults = inferenceResultsProvider .getInferenceResults (inferenceId );
236+ InferenceResults inferenceResults = getBwcInferenceResults (inferenceResultsMap );
237+ if (inferenceResults == null ) {
238+ inferenceResults = inferenceResultsMap .get (inferenceId );
239+ }
240+
212241 return switch (inferenceResults ) {
213242 case null -> throw new IllegalStateException (
214243 "No inference results set for ["
@@ -248,7 +277,7 @@ private QueryBuilder doRewriteBuildSemanticQuery(SearchExecutionContext searchEx
248277 }
249278
250279 private SemanticQueryBuilder doRewriteGetInferenceResults (QueryRewriteContext queryRewriteContext ) {
251- if (inferenceResultsProvider != null ) {
280+ if (inferenceResultsMap != null ) {
252281 return this ;
253282 }
254283
@@ -261,7 +290,7 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu
261290 throw new IllegalArgumentException (NAME + " query does not support cross-cluster search" );
262291 }
263292
264- MapInferenceResultsProvider mapInferenceResultsProvider = new MapInferenceResultsProvider ();
293+ Map < String , InferenceResults > inferenceResultsMap = new ConcurrentHashMap <> ();
265294 Set <String > inferenceIds = getInferenceIdsForForField (resolvedIndices .getConcreteLocalIndicesMetadata ().values (), fieldName );
266295 for (String inferenceId : inferenceIds ) {
267296 InferenceAction .Request inferenceRequest = new InferenceAction .Request (
@@ -284,7 +313,7 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu
284313 InferenceAction .INSTANCE ,
285314 inferenceRequest ,
286315 listener .delegateFailureAndWrap ((l , inferenceResponse ) -> {
287- mapInferenceResultsProvider . addInferenceResults (
316+ inferenceResultsMap . put (
288317 inferenceId ,
289318 validateAndConvertInferenceResults (inferenceResponse .getResults (), fieldName , inferenceId )
290319 );
@@ -294,7 +323,7 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu
294323 );
295324 }
296325
297- return new SemanticQueryBuilder (this , mapInferenceResultsProvider );
326+ return new SemanticQueryBuilder (this , inferenceResultsMap );
298327 }
299328
300329 private static InferenceResults validateAndConvertInferenceResults (
@@ -371,11 +400,11 @@ private static Set<String> getInferenceIdsForForField(Collection<IndexMetadata>
371400 protected boolean doEquals (SemanticQueryBuilder other ) {
372401 return Objects .equals (fieldName , other .fieldName )
373402 && Objects .equals (query , other .query )
374- && Objects .equals (inferenceResultsProvider , other .inferenceResultsProvider );
403+ && Objects .equals (inferenceResultsMap , other .inferenceResultsMap );
375404 }
376405
377406 @ Override
378407 protected int doHashCode () {
379- return Objects .hash (fieldName , query , inferenceResultsProvider );
408+ return Objects .hash (fieldName , query , inferenceResultsMap );
380409 }
381410}
0 commit comments