Skip to content

Commit 91f8ccf

Browse files
Refactor Mistral chat completion integration and add tests
1 parent 69f16b3 commit 91f8ccf

File tree

9 files changed

+334
-45
lines changed

9 files changed

+334
-45
lines changed

docs/changelog/128538.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
pr: 128538
2-
summary: "[ML] Add Mistral Chat Completion support to the Inference Plugin"
2+
summary: "Added Mistral Chat Completion support to the Inference Plugin"
33
area: Machine Learning
44
type: enhancement
55
issues: []

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
124124

125125
public void testGetServicesWithCompletionTaskType() throws IOException {
126126
List<Object> services = getServices(TaskType.COMPLETION);
127-
assertThat(services.size(), equalTo(12));
127+
assertThat(services.size(), equalTo(13));
128128

129129
var providers = providers(services);
130130

@@ -143,15 +143,16 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
143143
"openai",
144144
"streaming_completion_test_service",
145145
"hugging_face",
146-
"amazon_sagemaker"
146+
"amazon_sagemaker",
147+
"mistral"
147148
).toArray()
148149
)
149150
);
150151
}
151152

152153
public void testGetServicesWithChatCompletionTaskType() throws IOException {
153154
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
154-
assertThat(services.size(), equalTo(7));
155+
assertThat(services.size(), equalTo(8));
155156

156157
var providers = providers(services);
157158

@@ -165,7 +166,8 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {
165166
"streaming_completion_test_service",
166167
"hugging_face",
167168
"amazon_sagemaker",
168-
"googlevertexai"
169+
"googlevertexai",
170+
"mistral"
169171
).toArray()
170172
)
171173
);
Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,25 @@
55
* 2.0.
66
*/
77

8-
package org.elasticsearch.xpack.inference.services.openai;
8+
package org.elasticsearch.xpack.inference.services.mistral;
99

1010
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
1111
import org.elasticsearch.xpack.inference.services.mistral.response.MistralErrorResponseEntity;
12+
import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler;
1213

1314
/**
14-
* Handles non-streaming chat completion responses for Mistral models, extending the OpenAI chat completion response handler.
15+
* Handles non-streaming completion responses for Mistral models, extending the OpenAI completion response handler.
1516
* This class is specifically designed to handle Mistral's error response format.
1617
*/
17-
public class MistralChatCompletionResponseHandler extends OpenAiChatCompletionResponseHandler {
18+
public class MistralCompletionResponseHandler extends OpenAiChatCompletionResponseHandler {
1819

1920
/**
20-
* Constructs a MistralChatCompletionResponseHandler with the specified request type and response parser.
21+
* Constructs a MistralCompletionResponseHandler with the specified request type and response parser.
2122
*
22-
* @param requestType The type of request being handled (e.g., "mistral chat completions").
23+
* @param requestType The type of request being handled (e.g., "mistral completions").
2324
* @param parseFunction The function to parse the response.
2425
*/
25-
public MistralChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
26+
public MistralCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
2627
super(requestType, parseFunction, MistralErrorResponseEntity::fromResponse);
2728
}
2829
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -99,15 +99,15 @@ protected void doInfer(
9999
var actionCreator = new MistralActionCreator(getSender(), getServiceComponents());
100100

101101
switch (model) {
102-
case MistralEmbeddingsModel mistralEmbeddingsModel -> {
103-
var action = mistralEmbeddingsModel.accept(actionCreator, taskSettings);
104-
action.execute(inputs, timeout, listener);
105-
}
106-
case MistralChatCompletionModel mistralChatCompletionModel -> {
107-
var action = mistralChatCompletionModel.accept(actionCreator);
108-
action.execute(inputs, timeout, listener);
109-
}
110-
default -> listener.onFailure(createInvalidModelException(model));
102+
case MistralEmbeddingsModel mistralEmbeddingsModel:
103+
mistralEmbeddingsModel.accept(actionCreator, taskSettings).execute(inputs, timeout, listener);
104+
break;
105+
case MistralChatCompletionModel mistralChatCompletionModel:
106+
mistralChatCompletionModel.accept(actionCreator).execute(inputs, timeout, listener);
107+
break;
108+
default:
109+
listener.onFailure(createInvalidModelException(model));
110+
break;
111111
}
112112
}
113113

@@ -292,27 +292,23 @@ private static MistralModel createModel(
292292
String failureMessage,
293293
ConfigurationParseContext context
294294
) {
295-
return switch (taskType) {
296-
case TEXT_EMBEDDING -> new MistralEmbeddingsModel(
297-
modelId,
298-
taskType,
299-
NAME,
300-
serviceSettings,
301-
taskSettings,
302-
chunkingSettings,
303-
secretSettings,
304-
context
305-
);
306-
case CHAT_COMPLETION, COMPLETION -> new MistralChatCompletionModel(
307-
modelId,
308-
taskType,
309-
NAME,
310-
serviceSettings,
311-
secretSettings,
312-
context
313-
);
314-
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
315-
};
295+
switch (taskType) {
296+
case TEXT_EMBEDDING:
297+
return new MistralEmbeddingsModel(
298+
modelId,
299+
taskType,
300+
NAME,
301+
serviceSettings,
302+
taskSettings,
303+
chunkingSettings,
304+
secretSettings,
305+
context
306+
);
307+
case CHAT_COMPLETION, COMPLETION:
308+
return new MistralChatCompletionModel(modelId, taskType, NAME, serviceSettings, secretSettings, context);
309+
default:
310+
throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
311+
}
316312
}
317313

318314
private MistralModel createModelFromPersistent(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionCreator.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
1818
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
1919
import org.elasticsearch.xpack.inference.services.ServiceComponents;
20+
import org.elasticsearch.xpack.inference.services.mistral.MistralCompletionResponseHandler;
2021
import org.elasticsearch.xpack.inference.services.mistral.MistralEmbeddingsRequestManager;
2122
import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModel;
2223
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel;
2324
import org.elasticsearch.xpack.inference.services.mistral.request.completion.MistralChatCompletionRequest;
24-
import org.elasticsearch.xpack.inference.services.openai.MistralChatCompletionResponseHandler;
2525
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
2626

2727
import java.util.Map;
@@ -38,7 +38,7 @@ public class MistralActionCreator implements MistralActionVisitor {
3838

3939
public static final String COMPLETION_ERROR_PREFIX = "Mistral completions";
4040
static final String USER_ROLE = "user";
41-
static final ResponseHandler COMPLETION_HANDLER = new MistralChatCompletionResponseHandler(
41+
static final ResponseHandler COMPLETION_HANDLER = new MistralCompletionResponseHandler(
4242
"mistral completions",
4343
OpenAiChatCompletionResponseEntity::fromResponse
4444
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ public MistralEmbeddingsModel(MistralEmbeddingsModel model, MistralEmbeddingsSer
5858
setPropertiesFromServiceSettings(serviceSettings);
5959
}
6060

61-
protected void setPropertiesFromServiceSettings(MistralEmbeddingsServiceSettings serviceSettings) {
61+
private void setPropertiesFromServiceSettings(MistralEmbeddingsServiceSettings serviceSettings) {
6262
this.model = serviceSettings.modelId();
6363
this.rateLimitSettings = serviceSettings.rateLimitSettings();
6464
setEndpointUrl();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ public HttpRequest createHttpRequest() {
4747
);
4848
httpPost.setEntity(byteEntity);
4949

50-
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
50+
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters());
5151
httpPost.setHeader(createAuthBearerHeader(model.getSecretSettings().apiKey()));
5252

5353
return new HttpRequest(httpPost, getInferenceEntityId());
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.mistral.action;
9+
10+
import org.apache.http.HttpHeaders;
11+
import org.elasticsearch.ElasticsearchException;
12+
import org.elasticsearch.action.support.PlainActionFuture;
13+
import org.elasticsearch.common.settings.Settings;
14+
import org.elasticsearch.core.TimeValue;
15+
import org.elasticsearch.inference.InferenceServiceResults;
16+
import org.elasticsearch.test.ESTestCase;
17+
import org.elasticsearch.test.http.MockResponse;
18+
import org.elasticsearch.test.http.MockWebServer;
19+
import org.elasticsearch.threadpool.ThreadPool;
20+
import org.elasticsearch.xcontent.XContentType;
21+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
22+
import org.elasticsearch.xpack.inference.common.TruncatorTests;
23+
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
24+
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
25+
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
26+
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
27+
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
28+
import org.elasticsearch.xpack.inference.services.ServiceComponents;
29+
import org.elasticsearch.xpack.inference.services.mistral.completions.MistralChatCompletionModelTests;
30+
import org.junit.After;
31+
import org.junit.Before;
32+
33+
import java.io.IOException;
34+
import java.util.List;
35+
import java.util.Map;
36+
import java.util.concurrent.TimeUnit;
37+
38+
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
39+
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
40+
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
41+
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
42+
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
43+
import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
44+
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
45+
import static org.elasticsearch.xpack.inference.logging.ThrottlerManagerTests.mockThrottlerManager;
46+
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
47+
import static org.hamcrest.Matchers.equalTo;
48+
import static org.hamcrest.Matchers.hasSize;
49+
import static org.hamcrest.Matchers.is;
50+
import static org.mockito.Mockito.mock;
51+
52+
public class MistralActionCreatorTests extends ESTestCase {
53+
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
54+
private final MockWebServer webServer = new MockWebServer();
55+
private ThreadPool threadPool;
56+
private HttpClientManager clientManager;
57+
58+
@Before
59+
public void init() throws Exception {
60+
webServer.start();
61+
threadPool = createThreadPool(inferenceUtilityPool());
62+
clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
63+
}
64+
65+
@After
66+
public void shutdown() throws IOException {
67+
clientManager.close();
68+
terminate(threadPool);
69+
webServer.close();
70+
}
71+
72+
public void testExecute_ReturnsSuccessfulResponse_ForChatCompletionAction() throws IOException {
73+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
74+
75+
try (var sender = createSender(senderFactory)) {
76+
sender.start();
77+
78+
String responseJson = """
79+
{
80+
"object": "chat.completion",
81+
"id": "",
82+
"created": 1745855316,
83+
"model": "/repository",
84+
"system_fingerprint": "3.2.3-sha-a1f3ebe",
85+
"choices": [
86+
{
87+
"index": 0,
88+
"message": {
89+
"role": "assistant",
90+
"content": "Hello there, how may I assist you today?"
91+
},
92+
"logprobs": null,
93+
"finish_reason": "stop"
94+
}
95+
],
96+
"usage": {
97+
"prompt_tokens": 8,
98+
"completion_tokens": 50,
99+
"total_tokens": 58
100+
}
101+
}
102+
""";
103+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
104+
105+
PlainActionFuture<InferenceServiceResults> listener = createChatCompletionFuture(sender, createWithEmptySettings(threadPool));
106+
107+
var result = listener.actionGet(TIMEOUT);
108+
109+
assertThat(result.asMap(), is(buildExpectationCompletion(List.of("Hello there, how may I assist you today?"))));
110+
111+
assertChatCompletionRequest();
112+
}
113+
}
114+
115+
public void testSend_FailsFromInvalidResponseFormat_ForChatCompletionAction() throws IOException {
116+
var settings = buildSettingsWithRetryFields(
117+
TimeValue.timeValueMillis(1),
118+
TimeValue.timeValueMinutes(1),
119+
TimeValue.timeValueSeconds(0)
120+
);
121+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
122+
123+
try (var sender = createSender(senderFactory)) {
124+
sender.start();
125+
126+
String responseJson = """
127+
{
128+
"invalid_field": "unexpected"
129+
}
130+
""";
131+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
132+
133+
PlainActionFuture<InferenceServiceResults> listener = createChatCompletionFuture(
134+
sender,
135+
new ServiceComponents(threadPool, mockThrottlerManager(), settings, TruncatorTests.createTruncator())
136+
);
137+
138+
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
139+
assertThat(
140+
thrownException.getMessage(),
141+
is("Failed to send Mistral completion request from inference entity id " + "[id]. Cause: Required [choices]")
142+
);
143+
144+
assertChatCompletionRequest();
145+
}
146+
}
147+
148+
private PlainActionFuture<InferenceServiceResults> createChatCompletionFuture(Sender sender, ServiceComponents threadPool) {
149+
var model = MistralChatCompletionModelTests.createCompletionModel("secret", "model");
150+
model.setURI(getUrl(webServer));
151+
var actionCreator = new MistralActionCreator(sender, threadPool);
152+
var action = actionCreator.create(model);
153+
154+
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
155+
action.execute(new ChatCompletionInput(List.of("Hello"), false), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
156+
return listener;
157+
}
158+
159+
private void assertChatCompletionRequest() throws IOException {
160+
assertThat(webServer.requests(), hasSize(1));
161+
assertNull(webServer.requests().get(0).getUri().getQuery());
162+
assertThat(
163+
webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE),
164+
equalTo(XContentType.JSON.mediaTypeWithoutParameters())
165+
);
166+
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
167+
168+
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
169+
assertThat(requestMap.size(), is(4));
170+
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "Hello"))));
171+
assertThat(requestMap.get("model"), is("model"));
172+
assertThat(requestMap.get("n"), is(1));
173+
assertThat(requestMap.get("stream"), is(false));
174+
}
175+
}

0 commit comments

Comments
 (0)