Skip to content

Commit d5d972b

Browse files
Merge branch 'main' into fix_sort_tests
2 parents 6162d6d + 767d53f commit d5d972b

File tree

45 files changed

+2615
-364
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+2615
-364
lines changed

docs/changelog/128538.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 128538
2+
summary: "Added Mistral Chat Completion support to the Inference Plugin"
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

muted-tests.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,9 @@ tests:
204204
- class: org.elasticsearch.packaging.test.DockerTests
205205
method: test151MachineDependentHeapWithSizeOverride
206206
issue: https://github.com/elastic/elasticsearch/issues/123437
207+
- class: org.elasticsearch.action.admin.cluster.node.tasks.CancellableTasksIT
208+
method: testChildrenTasksCancelledOnTimeout
209+
issue: https://github.com/elastic/elasticsearch/issues/123568
207210
- class: org.elasticsearch.xpack.searchablesnapshots.FrozenSearchableSnapshotsIntegTests
208211
method: testCreateAndRestorePartialSearchableSnapshot
209212
issue: https://github.com/elastic/elasticsearch/issues/123773

server/src/internalClusterTest/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksIT.java

Lines changed: 1 addition & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import org.elasticsearch.action.ActionRequestValidationException;
1818
import org.elasticsearch.action.ActionResponse;
1919
import org.elasticsearch.action.ActionType;
20-
import org.elasticsearch.action.DelegatingActionListener;
2120
import org.elasticsearch.action.LatchedActionListener;
2221
import org.elasticsearch.action.LegacyActionRequest;
2322
import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksResponse;
@@ -27,7 +26,6 @@
2726
import org.elasticsearch.action.support.PlainActionFuture;
2827
import org.elasticsearch.client.internal.node.NodeClient;
2928
import org.elasticsearch.cluster.node.DiscoveryNode;
30-
import org.elasticsearch.common.Strings;
3129
import org.elasticsearch.common.io.stream.StreamInput;
3230
import org.elasticsearch.common.io.stream.StreamOutput;
3331
import org.elasticsearch.common.util.CollectionUtils;
@@ -36,8 +34,6 @@
3634
import org.elasticsearch.common.util.set.Sets;
3735
import org.elasticsearch.core.TimeValue;
3836
import org.elasticsearch.injection.guice.Inject;
39-
import org.elasticsearch.logging.LogManager;
40-
import org.elasticsearch.logging.Logger;
4137
import org.elasticsearch.plugins.ActionPlugin;
4238
import org.elasticsearch.plugins.Plugin;
4339
import org.elasticsearch.tasks.CancellableTask;
@@ -47,7 +43,6 @@
4743
import org.elasticsearch.tasks.TaskInfo;
4844
import org.elasticsearch.tasks.TaskManager;
4945
import org.elasticsearch.test.ESIntegTestCase;
50-
import org.elasticsearch.test.junit.annotations.TestIssueLogging;
5146
import org.elasticsearch.threadpool.ThreadPool;
5247
import org.elasticsearch.transport.ReceiveTimeoutTransportException;
5348
import org.elasticsearch.transport.SendRequestTransportException;
@@ -79,9 +74,6 @@
7974
@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST)
8075
public class CancellableTasksIT extends ESIntegTestCase {
8176

82-
// Temporary addition for investigation into https://github.com/elastic/elasticsearch/issues/123568
83-
private static final Logger logger = LogManager.getLogger(CancellableTasksIT.class);
84-
8577
static int idGenerator = 0;
8678
static final Map<TestRequest, CountDownLatch> beforeSendLatches = ConcurrentCollections.newConcurrentMap();
8779
static final Map<TestRequest, CountDownLatch> arrivedLatches = ConcurrentCollections.newConcurrentMap();
@@ -374,42 +366,18 @@ public void testRemoveBanParentsOnDisconnect() throws Exception {
374366
}
375367
}
376368

377-
@TestIssueLogging(
378-
issueUrl = "https://github.com/elastic/elasticsearch/issues/123568",
379-
value = "org.elasticsearch.transport.TransportService.tracer:TRACE"
380-
+ ",org.elasticsearch.tasks.TaskManager:TRACE"
381-
+ ",org.elasticsearch.action.admin.cluster.node.tasks.CancellableTasksIT:DEBUG"
382-
)
383369
public void testChildrenTasksCancelledOnTimeout() throws Exception {
384370
Set<DiscoveryNode> nodes = clusterService().state().nodes().stream().collect(Collectors.toSet());
385371
final TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 4), true);
386-
logger.info("generated request\n{}", buildTestRequestDescription(rootRequest, "", new StringBuilder()).toString());
387372
ActionFuture<TestResponse> rootTaskFuture = client().execute(TransportTestAction.ACTION, rootRequest);
388-
logger.info("action executed");
389373
allowEntireRequest(rootRequest);
390-
logger.info("execution released");
391374
waitForRootTask(rootTaskFuture, true);
392-
logger.info("root task completed");
393375
ensureBansAndCancellationsConsistency();
394-
logger.info("ensureBansAndCancellationsConsistency completed");
395376

396377
// Make sure all descendent requests have completed
397378
for (TestRequest subRequest : rootRequest.descendants()) {
398-
logger.info("awaiting completion of request {}", subRequest.getDescription());
399379
safeAwait(completedLatches.get(subRequest));
400380
}
401-
logger.info("all requests completed");
402-
}
403-
404-
// Temporary addition for investigation into https://github.com/elastic/elasticsearch/issues/123568
405-
static StringBuilder buildTestRequestDescription(TestRequest request, String prefix, StringBuilder stringBuilder) {
406-
stringBuilder.append(prefix)
407-
.append(Strings.format("id=%d [timeout=%s] %s", request.id, request.timeout, request.node.descriptionWithoutAttributes()))
408-
.append('\n');
409-
for (TestRequest subRequest : request.subRequests) {
410-
buildTestRequestDescription(subRequest, prefix + " ", stringBuilder);
411-
}
412-
return stringBuilder;
413381
}
414382

415383
static TaskId getRootTaskId(TestRequest request) throws Exception {
@@ -538,8 +506,6 @@ public void writeTo(StreamOutput out) throws IOException {
538506

539507
public static class TransportTestAction extends HandledTransportAction<TestRequest, TestResponse> {
540508

541-
private static final Logger logger = CancellableTasksIT.logger;
542-
543509
public static ActionType<TestResponse> ACTION = new ActionType<>("internal::test_action");
544510
private final TransportService transportService;
545511
private final NodeClient client;
@@ -599,22 +565,7 @@ protected void doExecute(Task task, TestRequest request, ActionListener<TestResp
599565
protected void startSubTask(TaskId parentTaskId, TestRequest subRequest, ActionListener<TestResponse> listener) {
600566
subRequest.setParentTask(parentTaskId);
601567
CountDownLatch completeLatch = completedLatches.get(subRequest);
602-
ActionListener<TestResponse> latchedListener = new DelegatingActionListener<>(
603-
new LatchedActionListener<>(listener, completeLatch)
604-
) {
605-
// Temporary logging addition for investigation into https://github.com/elastic/elasticsearch/issues/123568
606-
@Override
607-
public void onResponse(TestResponse testResponse) {
608-
logger.debug("processing successful response to request [{}]", subRequest.getDescription());
609-
delegate.onResponse(testResponse);
610-
}
611-
612-
@Override
613-
public void onFailure(Exception e) {
614-
logger.debug("processing exceptional response to request [{}]: {}", subRequest.getDescription(), e.getMessage());
615-
super.onFailure(e);
616-
}
617-
};
568+
LatchedActionListener<TestResponse> latchedListener = new LatchedActionListener<>(listener, completeLatch);
618569
transportService.getThreadPool().generic().execute(new AbstractRunnable() {
619570
@Override
620571
public void onFailure(Exception e) {
@@ -645,13 +596,6 @@ protected void doRun() throws Exception {
645596
TransportResponseHandler.TRANSPORT_WORKER
646597
)
647598
);
648-
// Temporary addition for investigation into https://github.com/elastic/elasticsearch/issues/123568
649-
logger.debug(
650-
"sent test request [{}] from [{}] to [{}]",
651-
subRequest.getDescription(),
652-
client.getLocalNodeId(),
653-
subRequest.node.descriptionWithoutAttributes()
654-
);
655599
}
656600
}
657601
});

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ static TransportVersion def(int id) {
189189
public static final TransportVersion DATA_STREAM_OPTIONS_API_REMOVE_INCLUDE_DEFAULTS_8_19 = def(8_841_0_41);
190190
public static final TransportVersion JOIN_ON_ALIASES_8_19 = def(8_841_0_42);
191191
public static final TransportVersion ILM_ADD_SKIP_SETTING_8_19 = def(8_841_0_43);
192+
public static final TransportVersion ML_INFERENCE_MISTRAL_CHAT_COMPLETION_ADDED_8_19 = def(8_841_0_44);
192193
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
193194
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
194195
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
@@ -282,6 +283,7 @@ static TransportVersion def(int id) {
282283
public static final TransportVersion IDP_CUSTOM_SAML_ATTRIBUTES = def(9_087_0_00);
283284
public static final TransportVersion JOIN_ON_ALIASES = def(9_088_0_00);
284285
public static final TransportVersion ILM_ADD_SKIP_SETTING = def(9_089_0_00);
286+
public static final TransportVersion ML_INFERENCE_MISTRAL_CHAT_COMPLETION_ADDED = def(9_090_0_00);
285287

286288
/*
287289
* STOP! READ THIS FIRST! No, really,

server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,14 @@ public record UnifiedCompletionRequest(
7878
* {@link #MAX_COMPLETION_TOKENS_FIELD}. Providers are expected to pass in their supported field name.
7979
*/
8080
private static final String MAX_TOKENS_PARAM = "max_tokens_field";
81+
/**
82+
* Indicates whether to include the `stream_options` field in the JSON output.
83+
* Some providers do not support this field. In such cases, this parameter should be set to "false",
84+
* and the `stream_options` field will be excluded from the output.
85+
* For providers that do support stream options, this parameter is left unset (default behavior),
86+
* which implicitly includes the `stream_options` field in the output.
87+
*/
88+
public static final String INCLUDE_STREAM_OPTIONS_PARAM = "include_stream_options";
8189

8290
/**
8391
* Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
@@ -91,6 +99,23 @@ public static Params withMaxTokens(String modelId, Params params) {
9199
);
92100
}
93101

102+
/**
103+
* Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
104+
* - Key: {@link #MODEL_ID_PARAM}, Value: modelId
105+
* - Key: {@link #MAX_TOKENS_PARAM}, Value: {@link #MAX_TOKENS_FIELD}
106+
* - Key: {@link #INCLUDE_STREAM_OPTIONS_PARAM}, Value: "false"
107+
*/
108+
public static Params withMaxTokensAndSkipStreamOptionsField(String modelId, Params params) {
109+
return new DelegatingMapParams(
110+
Map.ofEntries(
111+
Map.entry(MODEL_ID_PARAM, modelId),
112+
Map.entry(MAX_TOKENS_PARAM, MAX_TOKENS_FIELD),
113+
Map.entry(INCLUDE_STREAM_OPTIONS_PARAM, Boolean.FALSE.toString())
114+
),
115+
params
116+
);
117+
}
118+
94119
/**
95120
* Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
96121
* - Key: {@link #MODEL_FIELD}, Value: modelId

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushTopNToSourceTests.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,6 @@ public void testSimpleSortEvalSumLiteralAndField() {
193193
assertNoPushdownSort(query.asTimeSeries(), "for time series index mode");
194194
}
195195

196-
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/114515")
197196
public void testPartiallyPushableSort() {
198197
// FROM index | EVAL sum = 1 + integer | SORT integer, sum, field | LIMIT 10
199198
var query = from("index").eval("sum", b -> b.add(b.i(1), b.field("integer"))).sort("integer").sort("sum").sort("field").limit(10);

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
134134

135135
public void testGetServicesWithCompletionTaskType() throws IOException {
136136
List<Object> services = getServices(TaskType.COMPLETION);
137-
assertThat(services.size(), equalTo(13));
137+
assertThat(services.size(), equalTo(14));
138138

139139
var providers = providers(services);
140140

@@ -154,15 +154,16 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
154154
"openai",
155155
"streaming_completion_test_service",
156156
"hugging_face",
157-
"amazon_sagemaker"
157+
"amazon_sagemaker",
158+
"mistral"
158159
).toArray()
159160
)
160161
);
161162
}
162163

163164
public void testGetServicesWithChatCompletionTaskType() throws IOException {
164165
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
165-
assertThat(services.size(), equalTo(7));
166+
assertThat(services.size(), equalTo(8));
166167

167168
var providers = providers(services);
168169

@@ -176,7 +177,8 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {
176177
"streaming_completion_test_service",
177178
"hugging_face",
178179
"amazon_sagemaker",
179-
"googlevertexai"
180+
"googlevertexai",
181+
"mistral"
180182
).toArray()
181183
)
182184
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@
100100
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings;
101101
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankServiceSettings;
102102
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettings;
103+
import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionServiceSettings;
103104
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings;
104105
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings;
105106
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings;
@@ -266,6 +267,13 @@ private static void addMistralNamedWriteables(List<NamedWriteableRegistry.Entry>
266267
MistralEmbeddingsServiceSettings::new
267268
)
268269
);
270+
namedWriteables.add(
271+
new NamedWriteableRegistry.Entry(
272+
ServiceSettings.class,
273+
MistralChatCompletionServiceSettings.NAME,
274+
MistralChatCompletionServiceSettings::new
275+
)
276+
);
269277

270278
// note - no task settings for Mistral embeddings...
271279
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/ErrorMessageResponseEntity.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,13 @@
2121
* A pattern is emerging in how external providers provide error responses.
2222
*
2323
* At a minimum, these return:
24+
* <pre><code>
2425
* {
2526
* "error": {
2627
* "message": "(error message)"
2728
* }
2829
* }
29-
*
30+
* </code></pre>
3031
* Others may return additional information such as error codes specific to the service.
3132
*
3233
* This currently covers error handling for Azure AI Studio, however this pattern

0 commit comments

Comments
 (0)