Skip to content

Commit a69b474

Browse files
committed
Check for inference errors on the coordinator node
1 parent ac844f2 commit a69b474

File tree

1 file changed

+30
-21
lines changed

1 file changed

+30
-21
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -238,8 +238,8 @@ private QueryBuilder doRewriteBuildSemanticQuery(SearchExecutionContext searchEx
238238
inferenceResults = inferenceResultsMap.get(inferenceId);
239239
}
240240

241-
return switch (inferenceResults) {
242-
case null -> throw new IllegalStateException(
241+
if (inferenceResults == null) {
242+
throw new IllegalStateException(
243243
"No inference results set for ["
244244
+ semanticTextFieldType.typeName()
245245
+ "] field ["
@@ -248,25 +248,9 @@ private QueryBuilder doRewriteBuildSemanticQuery(SearchExecutionContext searchEx
248248
+ inferenceId
249249
+ "]"
250250
);
251-
case ErrorInferenceResults errorInferenceResults -> throw new InferenceException(
252-
"Field [" + fieldName + "] with inference ID [" + inferenceId + "] query inference error",
253-
errorInferenceResults.getException()
254-
); // Use InferenceException here so that the status code is set by the cause
255-
case WarningInferenceResults warningInferenceResults -> throw new IllegalStateException(
256-
"Field ["
257-
+ fieldName
258-
+ "] with inference ID ["
259-
+ inferenceId
260-
+ "] query inference warning: "
261-
+ warningInferenceResults.getWarning()
262-
);
263-
default -> semanticTextFieldType.semanticQuery(
264-
inferenceResults,
265-
searchExecutionContext.requestSize(),
266-
boost(),
267-
queryName()
268-
);
269-
};
251+
}
252+
253+
return semanticTextFieldType.semanticQuery(inferenceResults, searchExecutionContext.requestSize(), boost(), queryName());
270254
} else if (lenient != null && lenient) {
271255
return new MatchNoneQueryBuilder();
272256
} else {
@@ -278,6 +262,7 @@ private QueryBuilder doRewriteBuildSemanticQuery(SearchExecutionContext searchEx
278262

279263
private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewriteContext) {
280264
if (inferenceResultsMap != null) {
265+
inferenceResultsErrorCheck();
281266
return this;
282267
}
283268

@@ -378,6 +363,30 @@ private static InferenceResults validateAndConvertInferenceResults(
378363
return inferenceResults;
379364
}
380365

366+
private void inferenceResultsErrorCheck() {
367+
for (var entry : inferenceResultsMap.entrySet()) {
368+
String inferenceId = entry.getKey();
369+
InferenceResults inferenceResults = entry.getValue();
370+
371+
if (inferenceResults instanceof ErrorInferenceResults errorInferenceResults) {
372+
// Use InferenceException here so that the status code is set by the cause
373+
throw new InferenceException(
374+
"Field [" + fieldName + "] with inference ID [" + inferenceId + "] query inference error",
375+
errorInferenceResults.getException()
376+
);
377+
} else if (inferenceResults instanceof WarningInferenceResults warningInferenceResults) {
378+
throw new IllegalStateException(
379+
"Field ["
380+
+ fieldName
381+
+ "] with inference ID ["
382+
+ inferenceId
383+
+ "] query inference warning: "
384+
+ warningInferenceResults.getWarning()
385+
);
386+
}
387+
}
388+
}
389+
381390
@Override
382391
protected Query doToQuery(SearchExecutionContext context) throws IOException {
383392
throw new IllegalStateException(NAME + " should have been rewritten to another query type");

0 commit comments

Comments
 (0)