
Commit 6e734ed

Fix regression introduced during tests.
1 parent 66e73b4 commit 6e734ed

File tree

2 files changed: 17 additions, 7 deletions

2 files changed

+17
-7
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java

Lines changed: 5 additions & 5 deletions
@@ -305,12 +305,12 @@ private void executeTaskImmediately(RejectableTask task) {
                 e
             );
 
-            task.onRejection(
-                new EsRejectedExecutionException(
-                    format("Failed to execute request for inference id [%s]", task.getRequestManager().inferenceEntityId()),
-                    false
-                )
+            var rejectionException = new EsRejectedExecutionException(
+                format("Failed to execute request for inference id [%s]", task.getRequestManager().inferenceEntityId()),
+                false
             );
+            rejectionException.initCause(e);
+            task.onRejection(rejectionException);
         }
     }

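The hunk above builds the EsRejectedExecutionException first and attaches the triggering exception via initCause before handing it to task.onRejection, so the rejection no longer drops the original failure. A minimal standalone sketch of that cause-chaining pattern, using plain JDK exceptions rather than the actual Elasticsearch classes (all names below are illustrative stand-ins):

// CauseChainingSketch.java - illustrative only; mirrors the initCause flow in the hunk above
public class CauseChainingSketch {
    public static void main(String[] args) {
        Exception underlying = new IllegalStateException("executor is shutting down");

        // Build the rejection message first, then attach the original failure as its cause
        RuntimeException rejection = new RuntimeException("Failed to execute request for inference id [my-endpoint]");
        rejection.initCause(underlying);

        // The printed trace now ends with "Caused by: java.lang.IllegalStateException: executor is shutting down"
        rejection.printStackTrace();
    }
}
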
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

Lines changed: 12 additions & 2 deletions
@@ -39,6 +39,7 @@
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
+import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
 import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
@@ -72,6 +73,7 @@
7273
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
7374
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
7475
import static org.elasticsearch.xpack.inference.services.ServiceUtils.useChatCompletionUrlMessage;
76+
import static org.elasticsearch.xpack.inference.services.openai.action.OpenAiActionCreator.USER_ROLE;
7577

7678
public class ElasticInferenceService extends SenderService {
7779

@@ -164,7 +166,8 @@ protected void doUnifiedCompletionInfer(
         TimeValue timeout,
         ActionListener<InferenceServiceResults> listener
     ) {
-        if (model instanceof ElasticInferenceServiceCompletionModel == false || model.getTaskType() != TaskType.CHAT_COMPLETION) {
+        if (model instanceof ElasticInferenceServiceCompletionModel == false
+            || (model.getTaskType() != TaskType.CHAT_COMPLETION && model.getTaskType() != TaskType.COMPLETION)) {
             listener.onFailure(createInvalidModelException(model));
             return;
         }
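
The guard above now accepts an ElasticInferenceServiceCompletionModel whose task type is either CHAT_COMPLETION or COMPLETION, instead of CHAT_COMPLETION only. A small standalone sketch of the same validation logic, with a stand-in enum rather than the real TaskType (names below are illustrative, not the actual API):

import java.util.EnumSet;

// TaskTypeGuardSketch.java - illustrative stand-in for the updated doUnifiedCompletionInfer check
public class TaskTypeGuardSketch {
    enum TaskType { CHAT_COMPLETION, COMPLETION, SPARSE_EMBEDDING }

    private static final EnumSet<TaskType> SUPPORTED = EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION);

    // Mirrors the new condition: reject anything that is not a completion model with a supported task type
    static boolean shouldReject(boolean isCompletionModel, TaskType taskType) {
        return !isCompletionModel || !SUPPORTED.contains(taskType);
    }

    public static void main(String[] args) {
        System.out.println(shouldReject(true, TaskType.COMPLETION));        // false - now accepted
        System.out.println(shouldReject(true, TaskType.SPARSE_EMBEDDING));  // true  - still rejected
    }
}
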
@@ -214,10 +217,17 @@ protected void doInfer(
 
         var elasticInferenceServiceModel = (ElasticInferenceServiceModel) model;
 
+        // For ElasticInferenceServiceCompletionModel, convert ChatCompletionInput to UnifiedChatInput
+        // since the request manager expects UnifiedChatInput
+        final InferenceInputs finalInputs = (elasticInferenceServiceModel instanceof ElasticInferenceServiceCompletionModel
+            && inputs instanceof ChatCompletionInput)
+                ? new UnifiedChatInput((ChatCompletionInput) inputs, USER_ROLE)
+                : inputs;
+
         actionCreator.create(
             elasticInferenceServiceModel,
             currentTraceInfo,
-            listener.delegateFailureAndWrap((delegate, action) -> action.execute(inputs, timeout, delegate))
+            listener.delegateFailureAndWrap((delegate, action) -> action.execute(finalInputs, timeout, delegate))
         );
     }

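The doInfer hunk above wraps a ChatCompletionInput in a UnifiedChatInput (with USER_ROLE) before the action executes, because the completion request manager works with the unified input type, as the added comment notes. A minimal sketch of that adapt-before-dispatch pattern, using hypothetical stand-in types (Inputs, ChatInput and UnifiedInput are illustrative, not the real inference classes):

import java.util.List;

// InputAdapterSketch.java - illustrative only; mirrors the ChatCompletionInput -> UnifiedChatInput conversion
public class InputAdapterSketch {
    interface Inputs {}
    record ChatInput(List<String> messages) implements Inputs {}
    record UnifiedInput(List<String> messages, String role) implements Inputs {}

    // Wrap the plain chat input in the unified form only when the model needs it; otherwise pass it through
    static Inputs adapt(Inputs inputs, boolean isCompletionModel, String userRole) {
        return (isCompletionModel && inputs instanceof ChatInput chat)
            ? new UnifiedInput(chat.messages(), userRole)
            : inputs;
    }

    public static void main(String[] args) {
        Inputs adapted = adapt(new ChatInput(List.of("hello")), true, "user");
        System.out.println(adapted); // UnifiedInput[messages=[hello], role=user]
    }
}
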