5 changes: 5 additions & 0 deletions docs/changelog/128538.yaml
@@ -0,0 +1,5 @@
pr: 128538
summary: "Added Mistral Chat Completion support to the Inference Plugin"
area: Machine Learning
type: enhancement
issues: []
2 changes: 2 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersions.java
@@ -236,6 +236,8 @@ static TransportVersion def(int id) {
public static final TransportVersion ILM_ADD_SKIP_SETTING_8_19 = def(8_841_0_43);
public static final TransportVersion ESQL_REGEX_MATCH_WITH_CASE_INSENSITIVITY_8_19 = def(8_841_0_44);
public static final TransportVersion ESQL_QUERY_PLANNING_DURATION_8_19 = def(8_841_0_45);
public static final TransportVersion SEARCH_SOURCE_EXCLUDE_VECTORS_PARAM_8_19 = def(8_841_0_46);
public static final TransportVersion ML_INFERENCE_MISTRAL_CHAT_COMPLETION_ADDED_8_19 = def(8_841_0_47);
/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
@@ -78,6 +78,14 @@ public record UnifiedCompletionRequest(
* {@link #MAX_COMPLETION_TOKENS_FIELD}. Providers are expected to pass in their supported field name.
*/
private static final String MAX_TOKENS_PARAM = "max_tokens_field";
/**
* Indicates whether to include the `stream_options` field in the JSON output.
* Some providers do not support this field. In such cases, this parameter should be set to "false",
* and the `stream_options` field will be excluded from the output.
* For providers that do support stream options, this parameter is left unset (default behavior),
* which implicitly includes the `stream_options` field in the output.
*/
public static final String INCLUDE_STREAM_OPTIONS_PARAM = "include_stream_options";

/**
* Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
@@ -91,6 +99,23 @@ public static Params withMaxTokens(String modelId, Params params) {
);
}

/**
* Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
* - Key: {@link #MODEL_FIELD}, Value: modelId
* - Key: {@link #MAX_TOKENS_FIELD}, Value: {@link #MAX_TOKENS_FIELD}
* - Key: {@link #INCLUDE_STREAM_OPTIONS_PARAM}, Value: "false"
*/
public static Params withMaxTokensAndSkipStreamOptionsField(String modelId, Params params) {
return new DelegatingMapParams(
Map.ofEntries(
Map.entry(MODEL_ID_PARAM, modelId),
Map.entry(MAX_TOKENS_PARAM, MAX_TOKENS_FIELD),
Map.entry(INCLUDE_STREAM_OPTIONS_PARAM, Boolean.FALSE.toString())
),
params
);
}

/**
* Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
* - Key: {@link #MODEL_FIELD}, Value: modelId
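
Not part of the diff: a minimal sketch of how a caller might use the new helper to drop `stream_options` for a provider that does not support it. Only `withMaxTokensAndSkipStreamOptionsField` and `INCLUDE_STREAM_OPTIONS_PARAM` come from this change; the wrapper class, model id, and base params below are illustrative assumptions.

// Illustrative sketch only; not part of this pull request.
package examples;

import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.xcontent.ToXContent;

class SkipStreamOptionsExample {
    static ToXContent.Params paramsForProviderWithoutStreamOptions() {
        // Build Params carrying the model id, the provider's max-tokens field name,
        // and INCLUDE_STREAM_OPTIONS_PARAM=false so serializers omit `stream_options`.
        ToXContent.Params params = UnifiedCompletionRequest.withMaxTokensAndSkipStreamOptionsField(
            "example-model-id",        // model id (made up for this example)
            ToXContent.EMPTY_PARAMS    // real callers pass their existing params here
        );

        // Downstream, UnifiedChatCompletionRequestEntity#toXContent reads the flag like this
        // and skips the `stream_options` object when it resolves to false:
        boolean includeStreamOptions = params.paramAsBoolean(
            UnifiedCompletionRequest.INCLUDE_STREAM_OPTIONS_PARAM,
            true
        ); // -> false for the params built above

        return params;
    }
}
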
@@ -134,7 +134,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {

public void testGetServicesWithCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.COMPLETION);
assertThat(services.size(), equalTo(13));
assertThat(services.size(), equalTo(14));

var providers = providers(services);

@@ -154,15 +154,16 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
"openai",
"streaming_completion_test_service",
"hugging_face",
"amazon_sagemaker"
"amazon_sagemaker",
"mistral"
).toArray()
)
);
}

public void testGetServicesWithChatCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
assertThat(services.size(), equalTo(7));
assertThat(services.size(), equalTo(8));

var providers = providers(services);

@@ -176,7 +177,8 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {
"streaming_completion_test_service",
"hugging_face",
"amazon_sagemaker",
"googlevertexai"
"googlevertexai",
"mistral"
).toArray()
)
);
@@ -100,6 +100,7 @@
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings;
@@ -266,6 +267,13 @@ private static void addMistralNamedWriteables(List<NamedWriteableRegistry.Entry>
MistralEmbeddingsServiceSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
MistralChatCompletionServiceSettings.NAME,
MistralChatCompletionServiceSettings::new
)
);

// note - no task settings for Mistral embeddings...
}
@@ -21,12 +21,13 @@
* A pattern is emerging in how external providers provide error responses.
*
* At a minimum, these return:
* <pre><code>
* {
* "error: {
* "message": "(error message)"
* }
* }
*
* </code></pre>
* Others may return additional information such as error codes specific to the service.
*
* This currently covers error handling for Azure AI Studio, however this pattern
@@ -0,0 +1,128 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.response.streaming;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.response.ErrorMessageResponseEntity;

import java.util.Objects;
import java.util.Optional;

/**
* Represents an error response from a streaming inference service.
* This class extends {@link ErrorResponse} and provides additional fields
* specific to streaming errors, such as code, param, and type.
* An example error response for a streaming service might look like:
* <pre><code>
* {
* "error": {
* "message": "Invalid input",
* "code": "400",
* "param": "input",
* "type": "invalid_request_error"
* }
* }
* </code></pre>
* TODO: {@link ErrorMessageResponseEntity} is nearly identical to this, but doesn't parse as many fields. We must remove the duplication.
*/
public class StreamingErrorResponse extends ErrorResponse {
private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>(
"streaming_error",
true,
args -> Optional.ofNullable((StreamingErrorResponse) args[0])
);
private static final ConstructingObjectParser<StreamingErrorResponse, Void> ERROR_BODY_PARSER = new ConstructingObjectParser<>(
"streaming_error",
true,
args -> new StreamingErrorResponse((String) args[0], (String) args[1], (String) args[2], (String) args[3])
);

static {
ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("message"));
ERROR_BODY_PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField("code"));
ERROR_BODY_PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField("param"));
ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("type"));

ERROR_PARSER.declareObjectOrNull(
ConstructingObjectParser.optionalConstructorArg(),
ERROR_BODY_PARSER,
null,
new ParseField("error")
);
}

/**
* Standard error response parser. This can be overridden for those subclasses that
* have a different error response structure.
* @param response The error response as an HttpResult
*/
public static ErrorResponse fromResponse(HttpResult response) {
try (
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
.createParser(XContentParserConfiguration.EMPTY, response.body())
) {
return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
} catch (Exception e) {
// swallow the error
}

return ErrorResponse.UNDEFINED_ERROR;
}

/**
* Standard error response parser. This can be overridden for those subclasses that
* have a different error response structure.
* @param response The error response as a string
*/
public static ErrorResponse fromString(String response) {
try (
XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response)
) {
return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
} catch (Exception e) {
// swallow the error
}

return ErrorResponse.UNDEFINED_ERROR;
}

@Nullable
private final String code;
@Nullable
private final String param;
private final String type;

StreamingErrorResponse(String errorMessage, @Nullable String code, @Nullable String param, String type) {
super(errorMessage);
this.code = code;
this.param = param;
this.type = Objects.requireNonNull(type);
}

@Nullable
public String code() {
return code;
}

@Nullable
public String param() {
return param;
}

public String type() {
return type;
}
}
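
Not part of the diff: a minimal sketch of parsing a provider error body with the new class. The JSON payload is a made-up example matching the shape documented in the javadoc above.

// Illustrative sketch only; not part of this pull request.
package examples;

import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.response.streaming.StreamingErrorResponse;

class StreamingErrorResponseExample {
    static void demo() {
        String body = """
            {"error": {"message": "Invalid input", "code": "400", "param": "input", "type": "invalid_request_error"}}
            """;

        ErrorResponse parsed = StreamingErrorResponse.fromString(body);
        if (parsed instanceof StreamingErrorResponse streamingError) {
            String type = streamingError.type();   // "invalid_request_error"
            String code = streamingError.code();   // "400"; may be null when the provider omits it
            String param = streamingError.param(); // "input"; may also be null
        }
        // Bodies that cannot be parsed fall back to ErrorResponse.UNDEFINED_ERROR rather than throwing.
    }
}
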
@@ -15,6 +15,12 @@
import java.io.IOException;
import java.util.Objects;

import static org.elasticsearch.inference.UnifiedCompletionRequest.INCLUDE_STREAM_OPTIONS_PARAM;

/**
* Represents a unified chat completion request entity.
* This class is used to convert the unified chat input into a format that can be serialized to XContent.
*/
public class UnifiedChatCompletionRequestEntity implements ToXContentFragment {

public static final String STREAM_FIELD = "stream";
@@ -42,7 +48,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1);

builder.field(STREAM_FIELD, stream);
if (stream) {
// If the request is streamed and INCLUDE_STREAM_OPTIONS_PARAM has not been set to false, include stream options in the request.
if (stream && params.paramAsBoolean(INCLUDE_STREAM_OPTIONS_PARAM, true)) {
builder.startObject(STREAM_OPTIONS_FIELD);
builder.field(INCLUDE_USAGE_FIELD, true);
builder.endObject();
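
Not part of the diff, shown only to make the branch above concrete: the two serialized fragments a streaming request ends up with, assuming the field constants resolve to the OpenAI-style names `stream_options` and `include_usage` (their values are not visible in this hunk).

// With INCLUDE_STREAM_OPTIONS_PARAM unset or "true" (the default path):
//   "stream": true,
//   "stream_options": { "include_usage": true }
//
// With INCLUDE_STREAM_OPTIONS_PARAM set to "false"
// (e.g. params built via UnifiedCompletionRequest.withMaxTokensAndSkipStreamOptionsField):
//   "stream": true
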
@@ -0,0 +1,29 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.mistral;

import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.services.mistral.response.MistralErrorResponse;
import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler;

/**
* Handles non-streaming completion responses for Mistral models, extending the OpenAI completion response handler.
* This class is specifically designed to handle Mistral's error response format.
*/
public class MistralCompletionResponseHandler extends OpenAiChatCompletionResponseHandler {

/**
* Constructs a MistralCompletionResponseHandler with the specified request type and response parser.
*
* @param requestType The type of request being handled (e.g., "mistral completions").
* @param parseFunction The function to parse the response.
*/
public MistralCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
super(requestType, parseFunction, MistralErrorResponse::fromResponse);
}
}
@@ -9,6 +9,7 @@

public class MistralConstants {
public static final String API_EMBEDDINGS_PATH = "https://api.mistral.ai/v1/embeddings";
public static final String API_COMPLETIONS_PATH = "https://api.mistral.ai/v1/chat/completions";

// note - there is no bounds information available from Mistral,
// so we'll use a sane default here which is the same as Cohere's
Expand All @@ -18,4 +19,8 @@ public class MistralConstants {
public static final String MODEL_FIELD = "model";
public static final String INPUT_FIELD = "input";
public static final String ENCODING_FORMAT_FIELD = "encoding_format";
public static final String MAX_TOKENS_FIELD = "max_tokens";
public static final String DETAIL_FIELD = "detail";
public static final String MSG_FIELD = "msg";
public static final String MESSAGE_FIELD = "message";
}
@@ -22,7 +22,7 @@
import org.elasticsearch.xpack.inference.external.response.ErrorMessageResponseEntity;
import org.elasticsearch.xpack.inference.services.azureopenai.response.AzureMistralOpenAiExternalResponseHandler;
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.mistral.request.MistralEmbeddingsRequest;
import org.elasticsearch.xpack.inference.services.mistral.request.embeddings.MistralEmbeddingsRequest;
import org.elasticsearch.xpack.inference.services.mistral.response.MistralEmbeddingsResponseEntity;

import java.util.List;