Skip to content

Commit 5512000

Browse files
committed
Use a map instead of InferenceResultsProvider
1 parent c7b5390 commit 5512000

File tree

2 files changed

+66
-30
lines changed

2 files changed

+66
-30
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java

Lines changed: 50 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
import java.util.Map;
4646
import java.util.Objects;
4747
import java.util.Set;
48+
import java.util.concurrent.ConcurrentHashMap;
4849

4950
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
5051
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
@@ -56,6 +57,9 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
5657

5758
public static final NodeFeature SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS = new NodeFeature("semantic_query.multiple_inference_ids");
5859

60+
// Use a placeholder inference ID that will never overlap with a real inference endpoint (user-created or internal)
61+
private static final String PLACEHOLDER_INFERENCE_ID = "$PLACEHOLDER";
62+
5963
private static final ParseField FIELD_FIELD = new ParseField("field");
6064
private static final ParseField QUERY_FIELD = new ParseField("query");
6165
private static final ParseField LENIENT_FIELD = new ParseField("lenient");
@@ -75,7 +79,7 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
7579

7680
private final String fieldName;
7781
private final String query;
78-
private final InferenceResultsProvider inferenceResultsProvider;
82+
private final Map<String, InferenceResults> inferenceResultsMap;
7983
private final Boolean lenient;
8084

8185
public SemanticQueryBuilder(String fieldName, String query) {
@@ -86,7 +90,7 @@ public SemanticQueryBuilder(String fieldName, String query, Boolean lenient) {
8690
this(fieldName, query, lenient, null);
8791
}
8892

89-
protected SemanticQueryBuilder(String fieldName, String query, Boolean lenient, InferenceResultsProvider inferenceResultsProvider) {
93+
protected SemanticQueryBuilder(String fieldName, String query, Boolean lenient, Map<String, InferenceResults> inferenceResultsMap) {
9094
if (fieldName == null) {
9195
throw new IllegalArgumentException("[" + NAME + "] requires a " + FIELD_FIELD.getPreferredName() + " value");
9296
}
@@ -95,7 +99,7 @@ protected SemanticQueryBuilder(String fieldName, String query, Boolean lenient,
9599
}
96100
this.fieldName = fieldName;
97101
this.query = query;
98-
this.inferenceResultsProvider = inferenceResultsProvider;
102+
this.inferenceResultsMap = inferenceResultsMap != null ? Map.copyOf(inferenceResultsMap) : null;
99103
this.lenient = lenient;
100104
}
101105

@@ -104,10 +108,10 @@ public SemanticQueryBuilder(StreamInput in) throws IOException {
104108
this.fieldName = in.readString();
105109
this.query = in.readString();
106110
if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS)) {
107-
this.inferenceResultsProvider = in.readOptionalNamedWriteable(InferenceResultsProvider.class);
111+
this.inferenceResultsMap = in.readOptional(i1 -> i1.readImmutableMap(i2 -> i2.readNamedWriteable(InferenceResults.class)));
108112
} else {
109113
InferenceResults inferenceResults = in.readOptionalNamedWriteable(InferenceResults.class);
110-
this.inferenceResultsProvider = inferenceResults != null ? new SingleInferenceResultsProvider(inferenceResults) : null;
114+
this.inferenceResultsMap = inferenceResults != null ? buildBwcInferenceResultsMap(inferenceResults) : null;
111115
in.readBoolean(); // Discard noInferenceResults, it is no longer necessary
112116
}
113117
if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_QUERY_LENIENT)) {
@@ -122,15 +126,14 @@ protected void doWriteTo(StreamOutput out) throws IOException {
122126
out.writeString(fieldName);
123127
out.writeString(query);
124128
if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS)) {
125-
out.writeOptionalNamedWriteable(inferenceResultsProvider);
129+
out.writeOptional((o, v) -> o.writeMap(v, StreamOutput::writeNamedWriteable), inferenceResultsMap);
126130
} else {
127131
InferenceResults inferenceResults = null;
128-
if (inferenceResultsProvider != null) {
129-
Collection<InferenceResults> allInferenceResults = inferenceResultsProvider.getAllInferenceResults();
130-
if (allInferenceResults.size() > 1) {
132+
if (inferenceResultsMap != null) {
133+
if (inferenceResultsMap.size() > 1) {
131134
throw new IllegalArgumentException("Cannot query multiple inference IDs in a mixed-version cluster");
132-
} else if (allInferenceResults.size() == 1) {
133-
inferenceResults = allInferenceResults.iterator().next();
135+
} else if (inferenceResultsMap.size() == 1) {
136+
inferenceResults = inferenceResultsMap.values().iterator().next();
134137
}
135138
}
136139

@@ -142,12 +145,13 @@ protected void doWriteTo(StreamOutput out) throws IOException {
142145
}
143146
}
144147

145-
private SemanticQueryBuilder(SemanticQueryBuilder other, InferenceResultsProvider inferenceResultsProvider) {
148+
private SemanticQueryBuilder(SemanticQueryBuilder other, Map<String, InferenceResults> inferenceResultsMap) {
146149
this.fieldName = other.fieldName;
147150
this.query = other.query;
148151
this.boost = other.boost;
149152
this.queryName = other.queryName;
150-
this.inferenceResultsProvider = inferenceResultsProvider;
153+
// No need to copy the map here since this is only called internally. We can safely assume that the caller will not modify the map.
154+
this.inferenceResultsMap = inferenceResultsMap;
151155
this.lenient = other.lenient;
152156
}
153157

@@ -173,6 +177,27 @@ public static SemanticQueryBuilder fromXContent(XContentParser parser) throws IO
173177
return PARSER.apply(parser, null);
174178
}
175179

180+
/**
181+
* Build an inference results map to store a single inference result that is not associated with an inference ID.
182+
*
183+
* @param inferenceResults The inference result
184+
* @return An inference results map
185+
*/
186+
protected static Map<String, InferenceResults> buildBwcInferenceResultsMap(InferenceResults inferenceResults) {
187+
return Map.of(PLACEHOLDER_INFERENCE_ID, inferenceResults);
188+
}
189+
190+
/**
191+
* Extract an inference result not associated with an inference ID from an inference results map. Returns null if no such inference
192+
* result exists in the map.
193+
*
194+
* @param inferenceResultsMap The inference results map
195+
* @return The inference result
196+
*/
197+
private static InferenceResults getBwcInferenceResults(Map<String, InferenceResults> inferenceResultsMap) {
198+
return inferenceResultsMap.get(PLACEHOLDER_INFERENCE_ID);
199+
}
200+
176201
@Override
177202
protected void doXContent(XContentBuilder builder, Params params) throws IOException {
178203
builder.startObject(NAME);
@@ -200,15 +225,19 @@ private QueryBuilder doRewriteBuildSemanticQuery(SearchExecutionContext searchEx
200225
if (fieldType == null) {
201226
return new MatchNoneQueryBuilder();
202227
} else if (fieldType instanceof SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType) {
203-
if (inferenceResultsProvider == null) {
228+
if (inferenceResultsMap == null) {
204229
// This should never happen, but throw on it in case it ever does
205230
throw new IllegalStateException(
206231
"No inference results set for [" + semanticTextFieldType.typeName() + "] field [" + fieldName + "]"
207232
);
208233
}
209234

210235
String inferenceId = semanticTextFieldType.getSearchInferenceId();
211-
InferenceResults inferenceResults = inferenceResultsProvider.getInferenceResults(inferenceId);
236+
InferenceResults inferenceResults = getBwcInferenceResults(inferenceResultsMap);
237+
if (inferenceResults == null) {
238+
inferenceResults = inferenceResultsMap.get(inferenceId);
239+
}
240+
212241
return switch (inferenceResults) {
213242
case null -> throw new IllegalStateException(
214243
"No inference results set for ["
@@ -248,7 +277,7 @@ private QueryBuilder doRewriteBuildSemanticQuery(SearchExecutionContext searchEx
248277
}
249278

250279
private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewriteContext) {
251-
if (inferenceResultsProvider != null) {
280+
if (inferenceResultsMap != null) {
252281
return this;
253282
}
254283

@@ -261,7 +290,7 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu
261290
throw new IllegalArgumentException(NAME + " query does not support cross-cluster search");
262291
}
263292

264-
MapInferenceResultsProvider mapInferenceResultsProvider = new MapInferenceResultsProvider();
293+
Map<String, InferenceResults> inferenceResultsMap = new ConcurrentHashMap<>();
265294
Set<String> inferenceIds = getInferenceIdsForForField(resolvedIndices.getConcreteLocalIndicesMetadata().values(), fieldName);
266295
for (String inferenceId : inferenceIds) {
267296
InferenceAction.Request inferenceRequest = new InferenceAction.Request(
@@ -284,7 +313,7 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu
284313
InferenceAction.INSTANCE,
285314
inferenceRequest,
286315
listener.delegateFailureAndWrap((l, inferenceResponse) -> {
287-
mapInferenceResultsProvider.addInferenceResults(
316+
inferenceResultsMap.put(
288317
inferenceId,
289318
validateAndConvertInferenceResults(inferenceResponse.getResults(), fieldName, inferenceId)
290319
);
@@ -294,7 +323,7 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu
294323
);
295324
}
296325

297-
return new SemanticQueryBuilder(this, mapInferenceResultsProvider);
326+
return new SemanticQueryBuilder(this, inferenceResultsMap);
298327
}
299328

300329
private static InferenceResults validateAndConvertInferenceResults(
@@ -371,11 +400,11 @@ private static Set<String> getInferenceIdsForForField(Collection<IndexMetadata>
371400
protected boolean doEquals(SemanticQueryBuilder other) {
372401
return Objects.equals(fieldName, other.fieldName)
373402
&& Objects.equals(query, other.query)
374-
&& Objects.equals(inferenceResultsProvider, other.inferenceResultsProvider);
403+
&& Objects.equals(inferenceResultsMap, other.inferenceResultsMap);
375404
}
376405

377406
@Override
378407
protected int doHashCode() {
379-
return Objects.hash(fieldName, query, inferenceResultsProvider);
408+
return Objects.hash(fieldName, query, inferenceResultsMap);
380409
}
381410
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787
import java.util.ArrayList;
8888
import java.util.Arrays;
8989
import java.util.Collection;
90+
import java.util.HashMap;
9091
import java.util.List;
9192
import java.util.Map;
9293
import java.util.function.Supplier;
@@ -387,12 +388,18 @@ public void testSerializationBwc() throws IOException {
387388
String fieldName = randomAlphaOfLength(5);
388389
String query = randomAlphaOfLength(5);
389390

390-
MapInferenceResultsProvider mapInferenceResultsProvider = new MapInferenceResultsProvider();
391-
mapInferenceResultsProvider.addInferenceResults(randomAlphaOfLength(5), inferenceResults);
392-
SemanticQueryBuilder originalQuery = new SemanticQueryBuilder(fieldName, query, null, mapInferenceResultsProvider);
393-
394-
SingleInferenceResultsProvider singleInferenceResultsProvider = new SingleInferenceResultsProvider(inferenceResults);
395-
SemanticQueryBuilder bwcQuery = new SemanticQueryBuilder(fieldName, query, null, singleInferenceResultsProvider);
391+
SemanticQueryBuilder originalQuery = new SemanticQueryBuilder(
392+
fieldName,
393+
query,
394+
null,
395+
Map.of(randomAlphaOfLength(5), inferenceResults)
396+
);
397+
SemanticQueryBuilder bwcQuery = new SemanticQueryBuilder(
398+
fieldName,
399+
query,
400+
null,
401+
SemanticQueryBuilder.buildBwcInferenceResultsMap(inferenceResults)
402+
);
396403

397404
try (BytesStreamOutput output = new BytesStreamOutput()) {
398405
output.setTransportVersion(version);
@@ -422,13 +429,13 @@ public void testSerializationBwc() throws IOException {
422429
CheckedBiConsumer<List<InferenceResults>, TransportVersion, IOException> assertMultipleInferenceResults = (
423430
inferenceResultsList,
424431
version) -> {
425-
MapInferenceResultsProvider mapInferenceResultsProvider = new MapInferenceResultsProvider();
426-
inferenceResultsList.forEach(result -> mapInferenceResultsProvider.addInferenceResults(randomAlphaOfLength(5), result));
432+
Map<String, InferenceResults> inferenceResultsMap = new HashMap<>(inferenceResultsList.size());
433+
inferenceResultsList.forEach(result -> inferenceResultsMap.put(randomAlphaOfLength(5), result));
427434
SemanticQueryBuilder originalQuery = new SemanticQueryBuilder(
428435
randomAlphaOfLength(5),
429436
randomAlphaOfLength(5),
430437
null,
431-
mapInferenceResultsProvider
438+
inferenceResultsMap
432439
);
433440

434441
try (BytesStreamOutput output = new BytesStreamOutput()) {

0 commit comments

Comments
 (0)