3030import org .elasticsearch .rest .RestStatus ;
3131import org .elasticsearch .xpack .inference .chunking .ChunkingSettingsBuilder ;
3232import org .elasticsearch .xpack .inference .chunking .EmbeddingRequestChunker ;
33+ import org .elasticsearch .xpack .inference .external .action .SenderExecutableAction ;
34+ import org .elasticsearch .xpack .inference .external .http .retry .ResponseHandler ;
3335import org .elasticsearch .xpack .inference .external .http .sender .EmbeddingsInput ;
36+ import org .elasticsearch .xpack .inference .external .http .sender .GenericRequestManager ;
3437import org .elasticsearch .xpack .inference .external .http .sender .HttpRequestSender ;
3538import org .elasticsearch .xpack .inference .external .http .sender .InferenceInputs ;
3639import org .elasticsearch .xpack .inference .external .http .sender .UnifiedChatInput ;
3942import org .elasticsearch .xpack .inference .services .ServiceComponents ;
4043import org .elasticsearch .xpack .inference .services .ServiceUtils ;
4144import org .elasticsearch .xpack .inference .services .mistral .action .MistralActionCreator ;
45+ import org .elasticsearch .xpack .inference .services .mistral .completion .MistralChatCompletionModel ;
4246import org .elasticsearch .xpack .inference .services .mistral .embeddings .MistralEmbeddingsModel ;
4347import org .elasticsearch .xpack .inference .services .mistral .embeddings .MistralEmbeddingsServiceSettings ;
48+ import org .elasticsearch .xpack .inference .services .mistral .request .completion .MistralChatCompletionRequest ;
49+ import org .elasticsearch .xpack .inference .services .openai .response .OpenAiChatCompletionResponseEntity ;
4450import org .elasticsearch .xpack .inference .services .settings .DefaultSecretSettings ;
4551import org .elasticsearch .xpack .inference .services .settings .RateLimitSettings ;
4652
4753import java .util .EnumSet ;
4854import java .util .HashMap ;
4955import java .util .List ;
5056import java .util .Map ;
57+ import java .util .Set ;
5158
5259import static org .elasticsearch .xpack .inference .services .ServiceFields .MAX_INPUT_TOKENS ;
5360import static org .elasticsearch .xpack .inference .services .ServiceUtils .createInvalidModelException ;
5663import static org .elasticsearch .xpack .inference .services .ServiceUtils .removeFromMapOrDefaultEmpty ;
5764import static org .elasticsearch .xpack .inference .services .ServiceUtils .removeFromMapOrThrowIfNull ;
5865import static org .elasticsearch .xpack .inference .services .ServiceUtils .throwIfNotEmptyMap ;
59- import static org .elasticsearch .xpack .inference .services .ServiceUtils .throwUnsupportedUnifiedCompletionOperation ;
6066import static org .elasticsearch .xpack .inference .services .mistral .MistralConstants .MODEL_FIELD ;
6167
68+ /**
69+ * MistralService is the {@link SenderService} implementation for Mistral models. It supports the
70+ * text embedding, completion, and chat completion task types. Regular inference requests are executed through
71+ * actions created by {@link MistralActionCreator}; unified chat completion requests are sent via a {@link GenericRequestManager}.
72+ */
6273public class MistralService extends SenderService {
6374 public static final String NAME = "mistral" ;
6475
6576 private static final String SERVICE_NAME = "Mistral" ;
66- private static final EnumSet <TaskType > supportedTaskTypes = EnumSet .of (TaskType .TEXT_EMBEDDING );
77+ private static final EnumSet <TaskType > supportedTaskTypes = EnumSet .of (
78+ TaskType .TEXT_EMBEDDING ,
79+ TaskType .COMPLETION ,
80+ TaskType .CHAT_COMPLETION
81+ );
82+ private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new MistralUnifiedChatCompletionResponseHandler (
83+ "mistral chat completions" ,
84+ OpenAiChatCompletionResponseEntity ::fromResponse
85+ );
6786
6887 public MistralService (HttpRequestSender .Factory factory , ServiceComponents serviceComponents ) {
6988 super (factory , serviceComponents );
@@ -79,11 +98,16 @@ protected void doInfer(
7998 ) {
8099 var actionCreator = new MistralActionCreator (getSender (), getServiceComponents ());
81100
82- if (model instanceof MistralEmbeddingsModel mistralEmbeddingsModel ) {
83- var action = mistralEmbeddingsModel .accept (actionCreator , taskSettings );
84- action .execute (inputs , timeout , listener );
85- } else {
86- listener .onFailure (createInvalidModelException (model ));
101+ switch (model ) {
102+ case MistralEmbeddingsModel mistralEmbeddingsModel -> {
103+ var action = mistralEmbeddingsModel .accept (actionCreator , taskSettings );
104+ action .execute (inputs , timeout , listener );
105+ }
106+ case MistralChatCompletionModel mistralChatCompletionModel -> {
107+ var action = mistralChatCompletionModel .accept (actionCreator );
108+ action .execute (inputs , timeout , listener );
109+ }
110+ default -> listener .onFailure (createInvalidModelException (model ));
87111 }
88112 }
89113
@@ -99,7 +123,24 @@ protected void doUnifiedCompletionInfer(
99123 TimeValue timeout ,
100124 ActionListener <InferenceServiceResults > listener
101125 ) {
102- throwUnsupportedUnifiedCompletionOperation (NAME );
126+ if (model instanceof MistralChatCompletionModel == false ) {
127+ listener .onFailure (createInvalidModelException (model ));
128+ return ;
129+ }
130+
131+ MistralChatCompletionModel mistralChatCompletionModel = (MistralChatCompletionModel ) model ;
132+ var overriddenModel = MistralChatCompletionModel .of (mistralChatCompletionModel , inputs .getRequest ());
133+ var manager = new GenericRequestManager <>(
134+ getServiceComponents ().threadPool (),
135+ overriddenModel ,
136+ UNIFIED_CHAT_COMPLETION_HANDLER ,
137+ unifiedChatInput -> new MistralChatCompletionRequest (unifiedChatInput , overriddenModel ),
138+ UnifiedChatInput .class
139+ );
140+ var errorMessage = MistralActionCreator .buildErrorMessage (TaskType .CHAT_COMPLETION , model .getInferenceEntityId ());
141+ var action = new SenderExecutableAction (getSender (), manager , errorMessage );
142+
143+ action .execute (inputs , timeout , listener );
103144 }
104145
105146 @ Override
@@ -162,7 +203,7 @@ public void parseRequestConfig(
162203 );
163204 }
164205
165- MistralEmbeddingsModel model = createModel (
206+ MistralModel model = createModel (
166207 modelId ,
167208 taskType ,
168209 serviceSettingsMap ,
@@ -184,7 +225,7 @@ public void parseRequestConfig(
184225 }
185226
186227 @ Override
187- public Model parsePersistedConfigWithSecrets (
228+ public MistralModel parsePersistedConfigWithSecrets (
188229 String modelId ,
189230 TaskType taskType ,
190231 Map <String , Object > config ,
@@ -211,7 +252,7 @@ public Model parsePersistedConfigWithSecrets(
211252 }
212253
213254 @ Override
214- public Model parsePersistedConfig (String modelId , TaskType taskType , Map <String , Object > config ) {
255+ public MistralModel parsePersistedConfig (String modelId , TaskType taskType , Map <String , Object > config ) {
215256 Map <String , Object > serviceSettingsMap = removeFromMapOrThrowIfNull (config , ModelConfigurations .SERVICE_SETTINGS );
216257 Map <String , Object > taskSettingsMap = removeFromMapOrDefaultEmpty (config , ModelConfigurations .TASK_SETTINGS );
217258
@@ -236,7 +277,12 @@ public TransportVersion getMinimalSupportedVersion() {
236277 return TransportVersions .V_8_15_0 ;
237278 }
238279
239- private static MistralEmbeddingsModel createModel (
/**
 * Lists the task types that this override names as supporting streaming.
 * Each call builds a new {@link EnumSet} holding {@code TaskType.COMPLETION} and {@code TaskType.CHAT_COMPLETION}.
 *
 * @return a set containing {@code COMPLETION} and {@code CHAT_COMPLETION}
 */
280+ @ Override
281+ public Set <TaskType > supportedStreamingTasks () {
282+ return EnumSet .of (TaskType .COMPLETION , TaskType .CHAT_COMPLETION );
283+ }
284+
285+ private static MistralModel createModel (
240286 String modelId ,
241287 TaskType taskType ,
242288 Map <String , Object > serviceSettings ,
@@ -246,8 +292,8 @@ private static MistralEmbeddingsModel createModel(
246292 String failureMessage ,
247293 ConfigurationParseContext context
248294 ) {
249- if (taskType == TaskType . TEXT_EMBEDDING ) {
250- return new MistralEmbeddingsModel (
295+ return switch (taskType ) {
296+ case TEXT_EMBEDDING -> new MistralEmbeddingsModel (
251297 modelId ,
252298 taskType ,
253299 NAME ,
@@ -257,12 +303,19 @@ private static MistralEmbeddingsModel createModel(
257303 secretSettings ,
258304 context
259305 );
260- }
261-
262- throw new ElasticsearchStatusException (failureMessage , RestStatus .BAD_REQUEST );
306+ case CHAT_COMPLETION , COMPLETION -> new MistralChatCompletionModel (
307+ modelId ,
308+ taskType ,
309+ NAME ,
310+ serviceSettings ,
311+ secretSettings ,
312+ context
313+ );
314+ default -> throw new ElasticsearchStatusException (failureMessage , RestStatus .BAD_REQUEST );
315+ };
263316 }
264317
265- private MistralEmbeddingsModel createModelFromPersistent (
318+ private MistralModel createModelFromPersistent (
266319 String inferenceEntityId ,
267320 TaskType taskType ,
268321 Map <String , Object > serviceSettings ,
@@ -284,7 +337,7 @@ private MistralEmbeddingsModel createModelFromPersistent(
284337 }
285338
286339 @ Override
287- public Model updateModelWithEmbeddingDetails (Model model , int embeddingSize ) {
340+ public MistralEmbeddingsModel updateModelWithEmbeddingDetails (Model model , int embeddingSize ) {
288341 if (model instanceof MistralEmbeddingsModel embeddingsModel ) {
289342 var serviceSettings = embeddingsModel .getServiceSettings ();
290343
@@ -304,6 +357,10 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
304357 }
305358 }
306359
360+ /**
361+ * Holder for the {@link InferenceServiceConfiguration} describing the Mistral inference service.
362+ * {@link #get()} returns the value produced by {@code configuration.getOrCompute()}.
363+ */
307364 public static class Configuration {
308365 public static InferenceServiceConfiguration get () {
309366 return configuration .getOrCompute ();
0 commit comments