Skip to content

Commit f7dc246

Browse files
Add Mistral AI Chat Completion support to Inference Plugin
1 parent 41f186d commit f7dc246

24 files changed

+920
-78
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ static TransportVersion def(int id) {
181181
public static final TransportVersion RERANKER_FAILURES_ALLOWED_8_19 = def(8_841_0_35);
182182
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED_8_19 = def(8_841_0_36);
183183
public static final TransportVersion ML_INFERENCE_SAGEMAKER_CHAT_COMPLETION_8_19 = def(8_841_0_37);
184+
public static final TransportVersion ML_INFERENCE_MISTRAL_CHAT_COMPLETION_ADDED_8_19 = def(8_841_0_38);
184185
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
185186
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
186187
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
@@ -266,6 +267,7 @@ static TransportVersion def(int id) {
266267
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED = def(9_080_0_00);
267268
public static final TransportVersion SETTINGS_IN_DATA_STREAMS_DRY_RUN = def(9_081_0_00);
268269
public static final TransportVersion ML_INFERENCE_SAGEMAKER_CHAT_COMPLETION = def(9_082_0_00);
270+
public static final TransportVersion ML_INFERENCE_MISTRAL_CHAT_COMPLETION_ADDED = def(9_083_0_00);
269271
/*
270272
* STOP! READ THIS FIRST! No, really,
271273
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@
9090
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings;
9191
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankServiceSettings;
9292
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettings;
93+
import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionServiceSettings;
9394
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings;
9495
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings;
9596
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings;
@@ -217,6 +218,13 @@ private static void addMistralNamedWriteables(List<NamedWriteableRegistry.Entry>
217218
MistralEmbeddingsServiceSettings::new
218219
)
219220
);
221+
namedWriteables.add(
222+
new NamedWriteableRegistry.Entry(
223+
ServiceSettings.class,
224+
MistralChatCompletionServiceSettings.NAME,
225+
MistralChatCompletionServiceSettings::new
226+
)
227+
);
220228

221229
// note - no task settings for Mistral embeddings...
222230
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntity.java

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
import java.io.IOException;
1616
import java.util.Objects;
1717

18+
/**
19+
* Represents a unified chat completion request entity.
20+
* This class is used to convert the unified chat input into a format that can be serialized to XContent.
21+
*/
1822
public class UnifiedChatCompletionRequestEntity implements ToXContentFragment {
1923

2024
public static final String NAME_FIELD = "name";
@@ -162,11 +166,19 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
162166

163167
builder.field(STREAM_FIELD, stream);
164168
if (stream) {
165-
builder.startObject(STREAM_OPTIONS_FIELD);
166-
builder.field(INCLUDE_USAGE_FIELD, true);
167-
builder.endObject();
169+
fillStreamOptionsFields(builder);
168170
}
169171

170172
return builder;
171173
}
174+
175+
/**
176+
* This method is used to fill the stream options fields in the request entity.
177+
* It is called when the stream option is set to true.
178+
*/
179+
protected void fillStreamOptionsFields(XContentBuilder builder) throws IOException {
180+
builder.startObject(STREAM_OPTIONS_FIELD);
181+
builder.field(INCLUDE_USAGE_FIELD, true);
182+
builder.endObject();
183+
}
172184
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralConstants.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
public class MistralConstants {
1111
public static final String API_EMBEDDINGS_PATH = "https://api.mistral.ai/v1/embeddings";
12+
public static final String API_COMPLETIONS_PATH = "https://api.mistral.ai/v1/chat/completions";
1213

1314
// note - there is no bounds information available from Mistral,
1415
// so we'll use a sane default here which is the same as Cohere's
@@ -18,4 +19,8 @@ public class MistralConstants {
1819
public static final String MODEL_FIELD = "model";
1920
public static final String INPUT_FIELD = "input";
2021
public static final String ENCODING_FORMAT_FIELD = "encoding_format";
22+
public static final String MAX_TOKENS_FIELD = "max_tokens";
23+
public static final String DETAIL_FIELD = "detail";
24+
public static final String MSG_FIELD = "msg";
25+
public static final String MESSAGE_FIELD = "message";
2126
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralEmbeddingsRequestManager.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import org.elasticsearch.xpack.inference.external.response.ErrorMessageResponseEntity;
2323
import org.elasticsearch.xpack.inference.services.azureopenai.response.AzureMistralOpenAiExternalResponseHandler;
2424
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel;
25-
import org.elasticsearch.xpack.inference.services.mistral.request.MistralEmbeddingsRequest;
25+
import org.elasticsearch.xpack.inference.services.mistral.request.embeddings.MistralEmbeddingsRequest;
2626
import org.elasticsearch.xpack.inference.services.mistral.response.MistralEmbeddingsResponseEntity;
2727

2828
import java.util.List;
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.mistral;

import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

import java.net.URI;
import java.net.URISyntaxException;

/**
 * Represents a Mistral model that can be used for inference tasks.
 * This class extends RateLimitGroupingModel to handle rate limiting based on model and API key.
 */
public abstract class MistralModel extends RateLimitGroupingModel {
    // Model identifier sent to the Mistral API; expected to be populated by subclasses.
    protected String model;
    // Endpoint URI for requests; expected to be populated by subclasses (see also setURI below).
    protected URI uri;
    // Rate limit configuration; expected to be populated by subclasses.
    protected RateLimitSettings rateLimitSettings;

    /**
     * Creates a model from its configurations and secrets (e.g. when parsing a new or persisted model).
     */
    protected MistralModel(ModelConfigurations configurations, ModelSecrets secrets) {
        super(configurations, secrets);
    }

    /**
     * Creates a copy of an existing model with overridden service settings.
     */
    protected MistralModel(RateLimitGroupingModel model, ServiceSettings serviceSettings) {
        super(model, serviceSettings);
    }

    /**
     * @return the Mistral model identifier, or null if not set by the subclass
     */
    public String model() {
        return this.model;
    }

    /**
     * @return the request endpoint URI, or null if not set by the subclass
     */
    public URI uri() {
        return this.uri;
    }

    @Override
    public RateLimitSettings rateLimitSettings() {
        return this.rateLimitSettings;
    }

    // NOTE(review): returns a constant, so every Mistral model falls into a single rate-limit
    // group regardless of model or API key — confirm this is intended (the class javadoc
    // suggests grouping by model and API key).
    @Override
    public int rateLimitGroupingHash() {
        return 0;
    }

    // Needed for testing only
    // NOTE(review): an invalid newUri is silently ignored and leaves the previous URI in place;
    // acceptable for a test-only hook, but callers get no signal of the failure.
    public void setURI(String newUri) {
        try {
            this.uri = new URI(newUri);
        } catch (URISyntaxException e) {
            // swallow any error
        }
    }

    /**
     * Narrows the inherited secret settings to {@link DefaultSecretSettings}.
     * Assumes every Mistral model is constructed with DefaultSecretSettings; a different
     * secret settings implementation would cause a ClassCastException here.
     */
    @Override
    public DefaultSecretSettings getSecretSettings() {
        return (DefaultSecretSettings) super.getSecretSettings();
    }
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java

Lines changed: 76 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@
3030
import org.elasticsearch.rest.RestStatus;
3131
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
3232
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
33+
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
34+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
3335
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
36+
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
3437
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
3538
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
3639
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
@@ -39,15 +42,19 @@
3942
import org.elasticsearch.xpack.inference.services.ServiceComponents;
4043
import org.elasticsearch.xpack.inference.services.ServiceUtils;
4144
import org.elasticsearch.xpack.inference.services.mistral.action.MistralActionCreator;
45+
import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModel;
4246
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel;
4347
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings;
48+
import org.elasticsearch.xpack.inference.services.mistral.request.completion.MistralChatCompletionRequest;
49+
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
4450
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
4551
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
4652

4753
import java.util.EnumSet;
4854
import java.util.HashMap;
4955
import java.util.List;
5056
import java.util.Map;
57+
import java.util.Set;
5158

5259
import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
5360
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
@@ -56,14 +63,26 @@
5663
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
5764
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
5865
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
59-
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
6066
import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.MODEL_FIELD;
6167

68+
/**
69+
* MistralService is an implementation of the SenderService that handles inference tasks
70+
* using Mistral models. It supports text embedding, completion, and chat completion tasks.
71+
* The service uses MistralActionCreator to create actions for executing inference requests.
72+
*/
6273
public class MistralService extends SenderService {
6374
public static final String NAME = "mistral";
6475

6576
private static final String SERVICE_NAME = "Mistral";
66-
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING);
77+
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
78+
TaskType.TEXT_EMBEDDING,
79+
TaskType.COMPLETION,
80+
TaskType.CHAT_COMPLETION
81+
);
82+
private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new MistralUnifiedChatCompletionResponseHandler(
83+
"mistral chat completions",
84+
OpenAiChatCompletionResponseEntity::fromResponse
85+
);
6786

6887
public MistralService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
6988
super(factory, serviceComponents);
@@ -79,11 +98,16 @@ protected void doInfer(
7998
) {
8099
var actionCreator = new MistralActionCreator(getSender(), getServiceComponents());
81100

82-
if (model instanceof MistralEmbeddingsModel mistralEmbeddingsModel) {
83-
var action = mistralEmbeddingsModel.accept(actionCreator, taskSettings);
84-
action.execute(inputs, timeout, listener);
85-
} else {
86-
listener.onFailure(createInvalidModelException(model));
101+
switch (model) {
102+
case MistralEmbeddingsModel mistralEmbeddingsModel -> {
103+
var action = mistralEmbeddingsModel.accept(actionCreator, taskSettings);
104+
action.execute(inputs, timeout, listener);
105+
}
106+
case MistralChatCompletionModel mistralChatCompletionModel -> {
107+
var action = mistralChatCompletionModel.accept(actionCreator);
108+
action.execute(inputs, timeout, listener);
109+
}
110+
default -> listener.onFailure(createInvalidModelException(model));
87111
}
88112
}
89113

@@ -99,7 +123,24 @@ protected void doUnifiedCompletionInfer(
99123
TimeValue timeout,
100124
ActionListener<InferenceServiceResults> listener
101125
) {
102-
throwUnsupportedUnifiedCompletionOperation(NAME);
126+
if (model instanceof MistralChatCompletionModel == false) {
127+
listener.onFailure(createInvalidModelException(model));
128+
return;
129+
}
130+
131+
MistralChatCompletionModel mistralChatCompletionModel = (MistralChatCompletionModel) model;
132+
var overriddenModel = MistralChatCompletionModel.of(mistralChatCompletionModel, inputs.getRequest());
133+
var manager = new GenericRequestManager<>(
134+
getServiceComponents().threadPool(),
135+
overriddenModel,
136+
UNIFIED_CHAT_COMPLETION_HANDLER,
137+
unifiedChatInput -> new MistralChatCompletionRequest(unifiedChatInput, overriddenModel),
138+
UnifiedChatInput.class
139+
);
140+
var errorMessage = MistralActionCreator.buildErrorMessage(TaskType.CHAT_COMPLETION, model.getInferenceEntityId());
141+
var action = new SenderExecutableAction(getSender(), manager, errorMessage);
142+
143+
action.execute(inputs, timeout, listener);
103144
}
104145

105146
@Override
@@ -162,7 +203,7 @@ public void parseRequestConfig(
162203
);
163204
}
164205

165-
MistralEmbeddingsModel model = createModel(
206+
MistralModel model = createModel(
166207
modelId,
167208
taskType,
168209
serviceSettingsMap,
@@ -184,7 +225,7 @@ public void parseRequestConfig(
184225
}
185226

186227
@Override
187-
public Model parsePersistedConfigWithSecrets(
228+
public MistralModel parsePersistedConfigWithSecrets(
188229
String modelId,
189230
TaskType taskType,
190231
Map<String, Object> config,
@@ -211,7 +252,7 @@ public Model parsePersistedConfigWithSecrets(
211252
}
212253

213254
@Override
214-
public Model parsePersistedConfig(String modelId, TaskType taskType, Map<String, Object> config) {
255+
public MistralModel parsePersistedConfig(String modelId, TaskType taskType, Map<String, Object> config) {
215256
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
216257
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
217258

@@ -236,7 +277,12 @@ public TransportVersion getMinimalSupportedVersion() {
236277
return TransportVersions.V_8_15_0;
237278
}
238279

239-
private static MistralEmbeddingsModel createModel(
280+
@Override
281+
public Set<TaskType> supportedStreamingTasks() {
282+
return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
283+
}
284+
285+
private static MistralModel createModel(
240286
String modelId,
241287
TaskType taskType,
242288
Map<String, Object> serviceSettings,
@@ -246,8 +292,8 @@ private static MistralEmbeddingsModel createModel(
246292
String failureMessage,
247293
ConfigurationParseContext context
248294
) {
249-
if (taskType == TaskType.TEXT_EMBEDDING) {
250-
return new MistralEmbeddingsModel(
295+
return switch (taskType) {
296+
case TEXT_EMBEDDING -> new MistralEmbeddingsModel(
251297
modelId,
252298
taskType,
253299
NAME,
@@ -257,12 +303,19 @@ private static MistralEmbeddingsModel createModel(
257303
secretSettings,
258304
context
259305
);
260-
}
261-
262-
throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
306+
case CHAT_COMPLETION, COMPLETION -> new MistralChatCompletionModel(
307+
modelId,
308+
taskType,
309+
NAME,
310+
serviceSettings,
311+
secretSettings,
312+
context
313+
);
314+
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
315+
};
263316
}
264317

265-
private MistralEmbeddingsModel createModelFromPersistent(
318+
private MistralModel createModelFromPersistent(
266319
String inferenceEntityId,
267320
TaskType taskType,
268321
Map<String, Object> serviceSettings,
@@ -284,7 +337,7 @@ private MistralEmbeddingsModel createModelFromPersistent(
284337
}
285338

286339
@Override
287-
public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
340+
public MistralEmbeddingsModel updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
288341
if (model instanceof MistralEmbeddingsModel embeddingsModel) {
289342
var serviceSettings = embeddingsModel.getServiceSettings();
290343

@@ -304,6 +357,10 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
304357
}
305358
}
306359

360+
/**
361+
* Configuration class for the Mistral inference service.
362+
* It provides the settings and configurations required for the service.
363+
*/
307364
public static class Configuration {
308365
public static InferenceServiceConfiguration get() {
309366
return configuration.getOrCompute();

0 commit comments

Comments
 (0)