|
39 | 39 | import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; |
40 | 40 | import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; |
41 | 41 | import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; |
| 42 | +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; |
42 | 43 | import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; |
43 | 44 | import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; |
44 | 45 | import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; |
|
72 | 73 | import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; |
73 | 74 | import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; |
74 | 75 | import static org.elasticsearch.xpack.inference.services.ServiceUtils.useChatCompletionUrlMessage; |
| 76 | +import static org.elasticsearch.xpack.inference.services.openai.action.OpenAiActionCreator.USER_ROLE; |
75 | 77 |
|
76 | 78 | public class ElasticInferenceService extends SenderService { |
77 | 79 |
|
@@ -164,7 +166,8 @@ protected void doUnifiedCompletionInfer( |
164 | 166 | TimeValue timeout, |
165 | 167 | ActionListener<InferenceServiceResults> listener |
166 | 168 | ) { |
167 | | - if (model instanceof ElasticInferenceServiceCompletionModel == false || model.getTaskType() != TaskType.CHAT_COMPLETION) { |
| 169 | + if (model instanceof ElasticInferenceServiceCompletionModel == false |
| 170 | + || (model.getTaskType() != TaskType.CHAT_COMPLETION && model.getTaskType() != TaskType.COMPLETION)) { |
168 | 171 | listener.onFailure(createInvalidModelException(model)); |
169 | 172 | return; |
170 | 173 | } |
@@ -214,10 +217,17 @@ protected void doInfer( |
214 | 217 |
|
215 | 218 | var elasticInferenceServiceModel = (ElasticInferenceServiceModel) model; |
216 | 219 |
|
| 220 | + // The completion request manager expects a UnifiedChatInput, so wrap a plain |
| 221 | + // ChatCompletionInput (tagging its messages with USER_ROLE) before creating the action |
| 222 | + final InferenceInputs finalInputs = (elasticInferenceServiceModel instanceof ElasticInferenceServiceCompletionModel |
| 223 | + && inputs instanceof ChatCompletionInput) |
| 224 | + ? new UnifiedChatInput((ChatCompletionInput) inputs, USER_ROLE) |
| 225 | + : inputs; |
| 226 | + |
217 | 227 | actionCreator.create( |
218 | 228 | elasticInferenceServiceModel, |
219 | 229 | currentTraceInfo, |
220 | | - listener.delegateFailureAndWrap((delegate, action) -> action.execute(inputs, timeout, delegate)) |
| 230 | + listener.delegateFailureAndWrap((delegate, action) -> action.execute(finalInputs, timeout, delegate)) |
221 | 231 | ); |
222 | 232 | } |
223 | 233 |
|
|
0 commit comments