diff --git a/docs/changelog/135701.yaml b/docs/changelog/135701.yaml new file mode 100644 index 0000000000000..8285b9d69b74a --- /dev/null +++ b/docs/changelog/135701.yaml @@ -0,0 +1,5 @@ +pr: 135701 +summary: Add Google Model Garden's Meta, Mistral, Hugging Face and Ai21 providers support to Inference Plugin +area: Machine Learning +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java index 91543710d695e..bd672f8ed138d 100644 --- a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java +++ b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java @@ -89,31 +89,34 @@ public record UnifiedCompletionRequest( /** * Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values: - * - Key: {@link #MODEL_FIELD}, Value: modelId + * - Key: {@link #MODEL_FIELD}, Value: modelId, if modelId is not null * - Key: {@link #MAX_TOKENS_FIELD}, Value: {@link #maxCompletionTokens()} */ - public static Params withMaxTokens(String modelId, Params params) { - return new DelegatingMapParams( - Map.ofEntries(Map.entry(MODEL_ID_PARAM, modelId), Map.entry(MAX_TOKENS_PARAM, MAX_TOKENS_FIELD)), - params - ); + public static Params withMaxTokens(@Nullable String modelId, Params params) { + Map entries = modelId != null + ? 
Map.ofEntries(Map.entry(MODEL_ID_PARAM, modelId), Map.entry(MAX_TOKENS_PARAM, MAX_TOKENS_FIELD)) + : Map.ofEntries(Map.entry(MAX_TOKENS_PARAM, MAX_TOKENS_FIELD)); + return new DelegatingMapParams(entries, params); } /** * Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values: - * - Key: {@link #MODEL_FIELD}, Value: modelId - * - Key: {@link #MAX_TOKENS_FIELD}, Value: {@link #MAX_TOKENS_FIELD} + * - Key: {@link #MODEL_FIELD}, Value: modelId, if modelId is not null + * - Key: {@link #MAX_TOKENS_FIELD}, Value: {@link #maxCompletionTokens()} * - Key: {@link #INCLUDE_STREAM_OPTIONS_PARAM}, Value: "false" */ - public static Params withMaxTokensAndSkipStreamOptionsField(String modelId, Params params) { - return new DelegatingMapParams( - Map.ofEntries( + public static Params withMaxTokensAndSkipStreamOptionsField(@Nullable String modelId, Params params) { + Map entries = modelId != null + ? Map.ofEntries( Map.entry(MODEL_ID_PARAM, modelId), Map.entry(MAX_TOKENS_PARAM, MAX_TOKENS_FIELD), Map.entry(INCLUDE_STREAM_OPTIONS_PARAM, Boolean.FALSE.toString()) - ), - params - ); + ) + : Map.ofEntries( + Map.entry(MAX_TOKENS_PARAM, MAX_TOKENS_FIELD), + Map.entry(INCLUDE_STREAM_OPTIONS_PARAM, Boolean.FALSE.toString()) + ); + return new DelegatingMapParams(entries, params); } /** diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/request/Ai21ChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/request/Ai21ChatCompletionRequest.java index 1f744bebbd197..711b22ea281f5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/request/Ai21ChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/request/Ai21ChatCompletionRequest.java @@ -43,7 +43,8 @@ public HttpRequest createHttpRequest() { HttpPost httpPost = new 
HttpPost(model.uri()); ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString(new Ai21ChatCompletionRequestEntity(chatInput, model)).getBytes(StandardCharsets.UTF_8) + Strings.toString(new Ai21ChatCompletionRequestEntity(chatInput, model.getServiceSettings().modelId())) + .getBytes(StandardCharsets.UTF_8) ); httpPost.setEntity(byteEntity); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/request/Ai21ChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/request/Ai21ChatCompletionRequestEntity.java index 4cd2cf3f4b6f7..be7a1ce4d9f42 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/request/Ai21ChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/request/Ai21ChatCompletionRequestEntity.java @@ -7,15 +7,14 @@ package org.elasticsearch.xpack.inference.services.ai21.request; +import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity; -import org.elasticsearch.xpack.inference.services.ai21.completion.Ai21ChatCompletionModel; import java.io.IOException; -import java.util.Objects; /** * Ai21ChatCompletionRequestEntity is responsible for creating the request entity for Ai21 chat completion. 
@@ -23,18 +22,18 @@ */ public class Ai21ChatCompletionRequestEntity implements ToXContentObject { - private final Ai21ChatCompletionModel model; + private final String modelId; private final UnifiedChatCompletionRequestEntity unifiedRequestEntity; - public Ai21ChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, Ai21ChatCompletionModel model) { + public Ai21ChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, @Nullable String modelId) { this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput); - this.model = Objects.requireNonNull(model); + this.modelId = modelId; } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxTokens(model.getServiceSettings().modelId(), params)); + unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxTokens(modelId, params)); builder.endObject(); return builder; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleModelGardenProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleModelGardenProvider.java index 9017faf1459fb..48a310ef6165a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleModelGardenProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleModelGardenProvider.java @@ -7,16 +7,178 @@ package org.elasticsearch.xpack.inference.services.googlevertexai; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import 
org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.ai21.completion.Ai21ChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.services.ai21.request.Ai21ChatCompletionRequestEntity; +import org.elasticsearch.xpack.inference.services.anthropic.AnthropicChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.services.anthropic.AnthropicResponseHandler; +import org.elasticsearch.xpack.inference.services.anthropic.response.AnthropicChatCompletionResponseEntity; +import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionTaskSettings; +import org.elasticsearch.xpack.inference.services.googlevertexai.request.completion.GoogleModelGardenAnthropicChatCompletionRequestEntity; +import org.elasticsearch.xpack.inference.services.googlevertexai.request.completion.GoogleVertexAiUnifiedChatCompletionRequestEntity; +import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiCompletionResponseEntity; +import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequestEntity; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaCompletionResponseHandler; +import org.elasticsearch.xpack.inference.services.llama.request.completion.LlamaChatCompletionRequestEntity; +import org.elasticsearch.xpack.inference.services.mistral.MistralUnifiedChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.services.mistral.request.completion.MistralChatCompletionRequestEntity; +import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler; +import 
org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; + import java.util.Locale; /** * Enum representing the supported model garden providers. */ public enum GoogleModelGardenProvider { - GOOGLE, - ANTHROPIC; + GOOGLE( + CompletionResponseHandlerHolder.GOOGLE_VERTEX_AI_COMPLETION_HANDLER, + ChatCompletionResponseHandlerHolder.GOOGLE_VERTEX_AI_CHAT_COMPLETION_HANDLER, + (unifiedChatInput, modelId, taskSettings) -> new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + taskSettings.thinkingConfig() + ) + ), + ANTHROPIC( + CompletionResponseHandlerHolder.ANTHROPIC_COMPLETION_HANDLER, + ChatCompletionResponseHandlerHolder.ANTHROPIC_CHAT_COMPLETION_HANDLER, + (unifiedChatInput, modelId, taskSettings) -> new GoogleModelGardenAnthropicChatCompletionRequestEntity( + unifiedChatInput, + taskSettings + ) + ), + META( + CompletionResponseHandlerHolder.META_COMPLETION_HANDLER, + ChatCompletionResponseHandlerHolder.META_CHAT_COMPLETION_HANDLER, + (unifiedChatInput, modelId, taskSettings) -> new LlamaChatCompletionRequestEntity(unifiedChatInput, modelId) + ), + HUGGING_FACE( + CompletionResponseHandlerHolder.HUGGING_FACE_COMPLETION_HANDLER, + ChatCompletionResponseHandlerHolder.HUGGING_FACE_CHAT_COMPLETION_HANDLER, + (unifiedChatInput, modelId, taskSettings) -> new HuggingFaceUnifiedChatCompletionRequestEntity(unifiedChatInput, modelId) + ), + MISTRAL( + CompletionResponseHandlerHolder.MISTRAL_COMPLETION_HANDLER, + ChatCompletionResponseHandlerHolder.MISTRAL_CHAT_COMPLETION_HANDLER, + (unifiedChatInput, modelId, taskSettings) -> new MistralChatCompletionRequestEntity(unifiedChatInput, modelId) + ), + AI21( + CompletionResponseHandlerHolder.AI21_COMPLETION_HANDLER, + ChatCompletionResponseHandlerHolder.AI21_CHAT_COMPLETION_HANDLER, + (unifiedChatInput, modelId, taskSettings) -> new Ai21ChatCompletionRequestEntity(unifiedChatInput, modelId) + ); + + private final ResponseHandler completionResponseHandler; + private 
final ResponseHandler chatCompletionResponseHandler; + private final RequestEntityCreator entityCreator; + + GoogleModelGardenProvider( + ResponseHandler completionResponseHandler, + ResponseHandler chatCompletionResponseHandler, + RequestEntityCreator entityCreator + ) { + this.completionResponseHandler = completionResponseHandler; + this.chatCompletionResponseHandler = chatCompletionResponseHandler; + this.entityCreator = entityCreator; + } + + public ResponseHandler getCompletionResponseHandler() { + return completionResponseHandler; + } + + public ResponseHandler getChatCompletionResponseHandler() { + return chatCompletionResponseHandler; + } + + public ToXContentObject createRequestEntity( + UnifiedChatInput unifiedChatInput, + @Nullable String modelId, + GoogleVertexAiChatCompletionTaskSettings taskSettings + ) { + return entityCreator.create(unifiedChatInput, modelId, taskSettings); + } + + private static class CompletionResponseHandlerHolder { + static final ResponseHandler GOOGLE_VERTEX_AI_COMPLETION_HANDLER = new GoogleVertexAiResponseHandler( + "Google Vertex AI completion", + GoogleVertexAiCompletionResponseEntity::fromResponse, + GoogleVertexAiUnifiedChatCompletionResponseHandler.GoogleVertexAiErrorResponse::fromResponse, + true + ); + + static final ResponseHandler ANTHROPIC_COMPLETION_HANDLER = new AnthropicResponseHandler( + "Google Model Garden Anthropic completion", + AnthropicChatCompletionResponseEntity::fromResponse, + true + ); + + static final ResponseHandler META_COMPLETION_HANDLER = new LlamaCompletionResponseHandler( + "Google Model Garden Meta completion", + OpenAiChatCompletionResponseEntity::fromResponse + ); - public static final String NAME = "google_model_garden_provider"; + static final ResponseHandler HUGGING_FACE_COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler( + "Google Model Garden Hugging Face completion", + OpenAiChatCompletionResponseEntity::fromResponse + ); + + static final ResponseHandler 
MISTRAL_COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler( + "Google Model Garden Mistral completion", + OpenAiChatCompletionResponseEntity::fromResponse, + ErrorResponse::fromResponse + ); + + static final ResponseHandler AI21_COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler( + "Google Model Garden AI21 completion", + OpenAiChatCompletionResponseEntity::fromResponse, + ErrorResponse::fromResponse + ); + } + + private static class ChatCompletionResponseHandlerHolder { + static final ResponseHandler GOOGLE_VERTEX_AI_CHAT_COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler( + "Google Vertex AI chat completion" + ); + + static final ResponseHandler ANTHROPIC_CHAT_COMPLETION_HANDLER = new AnthropicChatCompletionResponseHandler( + "Google Model Garden Anthropic chat completion" + ); + + static final ResponseHandler META_CHAT_COMPLETION_HANDLER = new LlamaChatCompletionResponseHandler( + "Google Model Garden Meta chat completion", + OpenAiChatCompletionResponseEntity::fromResponse + ); + + static final ResponseHandler HUGGING_FACE_CHAT_COMPLETION_HANDLER = new HuggingFaceChatCompletionResponseHandler( + "Google Model Garden Hugging Face chat completion", + OpenAiChatCompletionResponseEntity::fromResponse + ); + + static final ResponseHandler MISTRAL_CHAT_COMPLETION_HANDLER = new MistralUnifiedChatCompletionResponseHandler( + "Google Model Garden Mistral chat completions", + OpenAiChatCompletionResponseEntity::fromResponse + ); + + static final ResponseHandler AI21_CHAT_COMPLETION_HANDLER = new Ai21ChatCompletionResponseHandler( + "Google Model Garden AI21 chat completions", + OpenAiChatCompletionResponseEntity::fromResponse + ); + } + + @FunctionalInterface + private interface RequestEntityCreator { + ToXContentObject create( + UnifiedChatInput unifiedChatInput, + @Nullable String modelId, + GoogleVertexAiChatCompletionTaskSettings taskSettings + ); + } public static GoogleModelGardenProvider fromString(String name) { return 
valueOf(name.trim().toUpperCase(Locale.ROOT)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 66a4dd0649730..117bed24dfbbe 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -33,7 +33,6 @@ import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; -import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; @@ -43,7 +42,6 @@ import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.ServiceUtils; -import org.elasticsearch.xpack.inference.services.anthropic.AnthropicChatCompletionResponseHandler; import org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator; import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel; @@ -92,14 +90,6 @@ public class GoogleVertexAiService extends SenderService implements RerankingInf InputType.INTERNAL_SEARCH ); - public static final ResponseHandler 
GOOGLE_VERTEX_AI_CHAT_COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler( - "Google Vertex AI chat completion" - ); - - public static final ResponseHandler GOOGLE_MODEL_GARDEN_ANTHROPIC_CHAT_COMPLETION_HANDLER = new AnthropicChatCompletionResponseHandler( - "Google Model Garden Anthropic chat completion" - ); - @Override public Set supportedStreamingTasks() { return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION); @@ -265,28 +255,16 @@ protected void doUnifiedCompletionInfer( } } + /** + * Creates a GenericRequestManager that uses the chat completion response handler of the model's provider. + * @param model The GoogleVertexAiChatCompletionModel to be used for requests. + * @return A GenericRequestManager configured with the provider's chat completion response handler. + */ private GenericRequestManager createRequestManager(GoogleVertexAiChatCompletionModel model) { - switch (model.getServiceSettings().provider()) { - case ANTHROPIC -> { - return createRequestManagerWithHandler(model, GOOGLE_MODEL_GARDEN_ANTHROPIC_CHAT_COMPLETION_HANDLER); - } - case GOOGLE -> { - return createRequestManagerWithHandler(model, GOOGLE_VERTEX_AI_CHAT_COMPLETION_HANDLER); - } - case null, default -> throw new ElasticsearchException( - "Unsupported Google Model Garden provider: " + model.getServiceSettings().provider() - ); - } - } - - private GenericRequestManager createRequestManagerWithHandler( - GoogleVertexAiChatCompletionModel model, - ResponseHandler responseHandler - ) { return new GenericRequestManager<>( getServiceComponents().threadPool(), model, - responseHandler, + model.getServiceSettings().provider().getChatCompletionResponseHandler(), unifiedChatInput -> new GoogleVertexAiUnifiedChatCompletionRequest(unifiedChatInput, model), UnifiedChatInput.class ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionCreator.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionCreator.java index f06dcdc46a926..87bd39f8b45ed 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionCreator.java @@ -7,27 +7,20 @@ package org.elasticsearch.xpack.inference.services.googlevertexai.action; -import org.elasticsearch.ElasticsearchException; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; -import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ServiceComponents; -import org.elasticsearch.xpack.inference.services.anthropic.AnthropicResponseHandler; -import org.elasticsearch.xpack.inference.services.anthropic.response.AnthropicChatCompletionResponseEntity; import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiEmbeddingsRequestManager; import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiRerankRequestManager; -import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiResponseHandler; -import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiUnifiedChatCompletionResponseHandler; import 
org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel; import org.elasticsearch.xpack.inference.services.googlevertexai.request.completion.GoogleVertexAiUnifiedChatCompletionRequest; import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel; -import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiCompletionResponseEntity; import java.util.Map; import java.util.Objects; @@ -41,19 +34,6 @@ public class GoogleVertexAiActionCreator implements GoogleVertexAiActionVisitor private final ServiceComponents serviceComponents; - static final ResponseHandler GOOGLE_VERTEX_AI_COMPLETION_HANDLER = new GoogleVertexAiResponseHandler( - "Google Vertex AI completion", - GoogleVertexAiCompletionResponseEntity::fromResponse, - GoogleVertexAiUnifiedChatCompletionResponseHandler.GoogleVertexAiErrorResponse::fromResponse, - true - ); - - static final ResponseHandler GOOGLE_MODEL_GARDEN_ANTHROPIC_COMPLETION_HANDLER = new AnthropicResponseHandler( - "Google Model Garden Anthropic completion", - AnthropicChatCompletionResponseEntity::fromResponse, - true - ); - static final String USER_ROLE = "user"; public GoogleVertexAiActionCreator(Sender sender, ServiceComponents serviceComponents) { @@ -90,28 +70,11 @@ public ExecutableAction create(GoogleVertexAiChatCompletionModel model, Map createRequestManager(GoogleVertexAiChatCompletionModel model) { - switch (model.getServiceSettings().provider()) { - case ANTHROPIC -> { - return createRequestManagerWithHandler(model, GOOGLE_MODEL_GARDEN_ANTHROPIC_COMPLETION_HANDLER); - } - case GOOGLE -> { - return createRequestManagerWithHandler(model, GOOGLE_VERTEX_AI_COMPLETION_HANDLER); - } - case null, default -> throw new ElasticsearchException( - "Unsupported Google Model Garden provider: " + model.getServiceSettings().provider() 
- ); - } - } - - private GenericRequestManager createRequestManagerWithHandler( - GoogleVertexAiChatCompletionModel overriddenModel, - ResponseHandler responseHandler - ) { return new GenericRequestManager<>( serviceComponents.threadPool(), - overriddenModel, - responseHandler, - inputs -> new GoogleVertexAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), overriddenModel), + model, + model.getServiceSettings().provider().getCompletionResponseHandler(), + inputs -> new GoogleVertexAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model), ChatCompletionInput.class ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/completion/GoogleVertexAiUnifiedChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/completion/GoogleVertexAiUnifiedChatCompletionRequest.java index cdf36ad76dac5..2b6144164a53a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/completion/GoogleVertexAiUnifiedChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/completion/GoogleVertexAiUnifiedChatCompletionRequest.java @@ -10,7 +10,6 @@ import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpPost; import org.apache.http.entity.ByteArrayEntity; -import org.elasticsearch.ElasticsearchException; import org.elasticsearch.common.Strings; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentType; @@ -41,7 +40,9 @@ public HttpRequest createHttpRequest() { HttpPost httpPost = new HttpPost(uri); ToXContentObject requestEntity; - requestEntity = createRequestEntity(); + requestEntity = model.getServiceSettings() + .provider() + .createRequestEntity(unifiedChatInput, extractModelId(), model.getTaskSettings()); 
ByteArrayEntity byteEntity = new ByteArrayEntity(Strings.toString(requestEntity).getBytes(StandardCharsets.UTF_8)); httpPost.setEntity(byteEntity); @@ -52,18 +53,13 @@ public HttpRequest createHttpRequest() { return new HttpRequest(httpPost, getInferenceEntityId()); } - private ToXContentObject createRequestEntity() { - switch (model.getServiceSettings().provider()) { - case ANTHROPIC -> { - return new GoogleModelGardenAnthropicChatCompletionRequestEntity(unifiedChatInput, model.getTaskSettings()); - } - case GOOGLE -> { - return new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getTaskSettings().thinkingConfig()); - } - case null, default -> throw new ElasticsearchException( - "Unsupported Google Model Garden provider: " + model.getServiceSettings().provider() - ); - } + /** + * Extracts the model ID to be used for the request. If the request contains a model ID, it is preferred. + * Otherwise, the model ID from the configuration is used. + * @return the model ID to be used for the request + */ + private String extractModelId() { + return unifiedChatInput.getRequest().model() != null ? 
unifiedChatInput.getRequest().model() : model.getServiceSettings().modelId(); } public void decorateWithAuth(HttpPost httpPost) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequest.java index 71ce78b8bd0b2..fd3607395765f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequest.java @@ -50,7 +50,8 @@ public HttpRequest createHttpRequest() { HttpPost httpPost = new HttpPost(getURI()); ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString(new HuggingFaceUnifiedChatCompletionRequestEntity(unifiedChatInput, model)).getBytes(StandardCharsets.UTF_8) + Strings.toString(new HuggingFaceUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId())) + .getBytes(StandardCharsets.UTF_8) ); httpPost.setEntity(byteEntity); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequestEntity.java index e71cfb5dccf00..c37d14f196161 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequestEntity.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequestEntity.java @@ -7,30 +7,39 @@ package org.elasticsearch.xpack.inference.services.huggingface.request.completion; +import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity; -import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel; import java.io.IOException; -import java.util.Objects; +/** + * HuggingFaceUnifiedChatCompletionRequestEntity is responsible for creating the request entity for Hugging Face chat completion. + * It implements ToXContentObject to allow serialization to XContent format. + */ public class HuggingFaceUnifiedChatCompletionRequestEntity implements ToXContentObject { - private final HuggingFaceChatCompletionModel model; + private final String modelId; private final UnifiedChatCompletionRequestEntity unifiedRequestEntity; - public HuggingFaceUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, HuggingFaceChatCompletionModel model) { + /** + * Constructs a HuggingFaceUnifiedChatCompletionRequestEntity with the specified unified chat input and model ID. 
+ * + * @param unifiedChatInput the unified chat input containing messages and parameters for the completion request + * @param modelId the Hugging Face chat completion model ID to be used for the request + */ + public HuggingFaceUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, @Nullable String modelId) { this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput); - this.model = Objects.requireNonNull(model); + this.modelId = modelId; } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxTokens(model.getServiceSettings().modelId(), params)); + unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxTokens(modelId, params)); builder.endObject(); return builder; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequest.java index 3bb01f215087e..b1edd6f299eb0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequest.java @@ -55,7 +55,8 @@ public HttpRequest createHttpRequest() { HttpPost httpPost = new HttpPost(model.uri()); ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString(new LlamaChatCompletionRequestEntity(chatInput, model)).getBytes(StandardCharsets.UTF_8) + Strings.toString(new LlamaChatCompletionRequestEntity(chatInput, model.getServiceSettings().modelId())) + .getBytes(StandardCharsets.UTF_8) ); httpPost.setEntity(byteEntity); diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestEntity.java index fc80dab09f6f5..7d4534f2ce4d7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestEntity.java @@ -7,15 +7,14 @@ package org.elasticsearch.xpack.inference.services.llama.request.completion; +import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity; -import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel; import java.io.IOException; -import java.util.Objects; /** * LlamaChatCompletionRequestEntity is responsible for creating the request entity for Llama chat completion. @@ -23,24 +22,24 @@ */ public class LlamaChatCompletionRequestEntity implements ToXContentObject { - private final LlamaChatCompletionModel model; + private final String modelId; private final UnifiedChatCompletionRequestEntity unifiedRequestEntity; /** * Constructs a LlamaChatCompletionRequestEntity with the specified unified chat input and model. 
* * @param unifiedChatInput the unified chat input containing messages and parameters for the completion request - * @param model the Llama chat completion model to be used for the request + * @param modelId the Llama chat completion model id to be used for the request */ - public LlamaChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, LlamaChatCompletionModel model) { + public LlamaChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, @Nullable String modelId) { this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput); - this.model = Objects.requireNonNull(model); + this.modelId = modelId; } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxTokens(model.getServiceSettings().modelId(), params)); + unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxTokensAndSkipStreamOptionsField(modelId, params)); builder.endObject(); return builder; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequest.java index 64051ee0d83b1..be60e07d01c27 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequest.java @@ -43,7 +43,8 @@ public HttpRequest createHttpRequest() { HttpPost httpPost = new HttpPost(model.uri()); ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString(new MistralChatCompletionRequestEntity(chatInput, model)).getBytes(StandardCharsets.UTF_8) + Strings.toString(new 
MistralChatCompletionRequestEntity(chatInput, model.getServiceSettings().modelId())) + .getBytes(StandardCharsets.UTF_8) ); httpPost.setEntity(byteEntity); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequestEntity.java index 3fe640335c47e..50278abbf7944 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequestEntity.java @@ -7,15 +7,14 @@ package org.elasticsearch.xpack.inference.services.mistral.request.completion; +import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity; -import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModel; import java.io.IOException; -import java.util.Objects; /** * MistralChatCompletionRequestEntity is responsible for creating the request entity for Mistral chat completion. 
@@ -23,21 +22,18 @@ */ public class MistralChatCompletionRequestEntity implements ToXContentObject { - private final MistralChatCompletionModel model; + private final String modelId; private final UnifiedChatCompletionRequestEntity unifiedRequestEntity; - public MistralChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, MistralChatCompletionModel model) { + public MistralChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, @Nullable String modelId) { this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput); - this.model = Objects.requireNonNull(model); + this.modelId = modelId; } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - unifiedRequestEntity.toXContent( - builder, - UnifiedCompletionRequest.withMaxTokensAndSkipStreamOptionsField(model.getServiceSettings().modelId(), params) - ); + unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxTokensAndSkipStreamOptionsField(modelId, params)); builder.endObject(); return builder; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/request/Ai21ChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/request/Ai21ChatCompletionRequestEntityTests.java index 88f7ac11a0d0b..e70e1de605cb5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/request/Ai21ChatCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/request/Ai21ChatCompletionRequestEntityTests.java @@ -15,17 +15,78 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; -import 
org.elasticsearch.xpack.inference.services.ai21.completion.Ai21ChatCompletionModel; import java.io.IOException; import java.util.ArrayList; -import static org.elasticsearch.xpack.inference.services.ai21.completion.Ai21ChatCompletionModelTests.createCompletionModel; - public class Ai21ChatCompletionRequestEntityTests extends ESTestCase { private static final String ROLE = "user"; - public void testModelUserFieldsSerialization() throws IOException { + public void testSerializationWithModelIdStreaming() throws IOException { + testSerialization("test-model", true, """ + { + "messages": [{ + "content": "Hello, world!", + "role": "user" + } + ], + "model": "test-model", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """); + } + + public void testSerializationWithModelIdNonStreaming() throws IOException { + testSerialization("test-model", false, """ + { + "messages": [{ + "content": "Hello, world!", + "role": "user" + } + ], + "model": "test-model", + "n": 1, + "stream": false + } + """); + } + + public void testSerializationWithoutModelIdStreaming() throws IOException { + testSerialization(null, true, """ + { + "messages": [{ + "content": "Hello, world!", + "role": "user" + } + ], + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """); + } + + public void testSerializationWithoutModelIdNonStreaming() throws IOException { + testSerialization(null, false, """ + { + "messages": [{ + "content": "Hello, world!", + "role": "user" + } + ], + "n": 1, + "stream": false + } + """); + } + + private static void testSerialization(String modelId, boolean isStreaming, String expectedJson) throws IOException { UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( new UnifiedCompletionRequest.ContentString("Hello, world!"), ROLE, @@ -37,28 +98,12 @@ public void testModelUserFieldsSerialization() throws IOException { var unifiedRequest = UnifiedCompletionRequest.of(messageList); - 
UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - Ai21ChatCompletionModel model = createCompletionModel("api-key", "test-model"); + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, isStreaming); - Ai21ChatCompletionRequestEntity entity = new Ai21ChatCompletionRequestEntity(unifiedChatInput, model); + Ai21ChatCompletionRequestEntity entity = new Ai21ChatCompletionRequestEntity(unifiedChatInput, modelId); XContentBuilder builder = JsonXContent.contentBuilder(); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); - String expectedJson = """ - { - "messages": [{ - "content": "Hello, world!", - "role": "user" - } - ], - "model": "test-model", - "n": 1, - "stream": true, - "stream_options": { - "include_usage": true - } - } - """; assertEquals(XContentHelper.stripWhitespace(expectedJson), Strings.toString(builder)); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiUnifiedChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiUnifiedChatCompletionActionTests.java index 4d1c590390648..f63de923b5f5c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiUnifiedChatCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiUnifiedChatCompletionActionTests.java @@ -27,7 +27,6 @@ import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleModelGardenProvider; -import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiService; import 
org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModelTests; import org.elasticsearch.xpack.inference.services.googlevertexai.completion.ThinkingConfig; import org.elasticsearch.xpack.inference.services.googlevertexai.request.completion.GoogleVertexAiUnifiedChatCompletionRequest; @@ -36,16 +35,12 @@ import org.junit.Before; import java.io.IOException; -import java.net.URI; -import java.net.URISyntaxException; import java.util.List; import java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; -import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiService.GOOGLE_VERTEX_AI_CHAT_COMPLETION_HANDLER; -import static org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator.GOOGLE_MODEL_GARDEN_ANTHROPIC_COMPLETION_HANDLER; import static org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator.USER_ROLE; import static org.hamcrest.Matchers.is; import static org.mockito.ArgumentMatchers.any; @@ -82,120 +77,103 @@ private static UnifiedChatInput createUnifiedChatInput(List messages) { // Successful case would typically be tested via end-to-end notebook tests in AppEx repo public void testExecute_ThrowsElasticsearchExceptionGoogleVertexAi() { - testExecute_ThrowsElasticsearchException( - "us-central1", - "test-project-id", - "chat-bison", - null, - null, - GoogleVertexAiService.GOOGLE_VERTEX_AI_CHAT_COMPLETION_HANDLER - ); + testExecute_ThrowsElasticsearchException(GoogleModelGardenProvider.GOOGLE); } public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalledGoogleVertexAi() { - testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled( - 
"us-central1", - "test-project-id", - "chat-bison", - null, - null, - GoogleVertexAiService.GOOGLE_VERTEX_AI_CHAT_COMPLETION_HANDLER - ); + testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled(GoogleModelGardenProvider.GOOGLE); } public void testExecute_ThrowsExceptionGoogleVertexAi() { - testExecute_ThrowsException("us-central1", "test-project-id", "chat-bison", null, null, GOOGLE_VERTEX_AI_CHAT_COMPLETION_HANDLER); + testExecute_ThrowsIllegalArgumentException(GoogleModelGardenProvider.GOOGLE); } - public void testExecute_ThrowsElasticsearchExceptionAnthropic() throws URISyntaxException { - testExecute_ThrowsElasticsearchException( - null, - null, - null, - GoogleModelGardenProvider.ANTHROPIC, - new URI("http://localhost:9200"), - GOOGLE_MODEL_GARDEN_ANTHROPIC_COMPLETION_HANDLER - ); + public void testExecute_ThrowsElasticsearchExceptionAnthropic() { + testExecute_ThrowsElasticsearchException(GoogleModelGardenProvider.ANTHROPIC); } - public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalledAnthropic() throws URISyntaxException { - testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled( - null, - null, - null, - GoogleModelGardenProvider.ANTHROPIC, - new URI("http://localhost:9200"), - GOOGLE_MODEL_GARDEN_ANTHROPIC_COMPLETION_HANDLER - ); + public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalledAnthropic() { + testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled(GoogleModelGardenProvider.ANTHROPIC); } - public void testExecute_ThrowsExceptionAnthropic() throws URISyntaxException { - testExecute_ThrowsException( - null, - null, - null, - GoogleModelGardenProvider.ANTHROPIC, - new URI("http://localhost:9200"), - GoogleVertexAiActionCreator.GOOGLE_MODEL_GARDEN_ANTHROPIC_COMPLETION_HANDLER - ); + public void testExecute_ThrowsExceptionAnthropic() { + testExecute_ThrowsIllegalArgumentException(GoogleModelGardenProvider.ANTHROPIC); } - private void testExecute_ThrowsException( - 
String location, - String projectId, - String actualModelId, - GoogleModelGardenProvider provider, - URI uri, - ResponseHandler handler - ) { - var sender = mock(Sender.class); - doThrow(new IllegalArgumentException("failed")).when(sender).send(any(), any(), any(), any()); + public void testExecute_ThrowsElasticsearchExceptionMeta() { + testExecute_ThrowsElasticsearchException(GoogleModelGardenProvider.META); + } - var action = createAction(location, projectId, actualModelId, sender, provider, uri, handler); + public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalledMeta() { + testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled(GoogleModelGardenProvider.META); + } - PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(createUnifiedChatInput(List.of("test query")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + public void testExecute_ThrowsExceptionMeta() { + testExecute_ThrowsIllegalArgumentException(GoogleModelGardenProvider.META); + } - var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); - assertThat(thrownException.getMessage(), is("Failed to send Google Vertex AI chat completion request. 
Cause: failed")); + public void testExecute_ThrowsElasticsearchExceptionMistral() { + testExecute_ThrowsElasticsearchException(GoogleModelGardenProvider.MISTRAL); } - private void testExecute_ThrowsElasticsearchException( - String location, - String projectId, - String actualModelId, - GoogleModelGardenProvider googleModelGardenProvider, - URI uri, - ResponseHandler googleModelGardenAnthropicCompletionHandler - ) { - var sender = mock(Sender.class); - doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any()); - - var action = createAction( - location, - projectId, - actualModelId, - sender, - googleModelGardenProvider, - uri, - googleModelGardenAnthropicCompletionHandler + public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalledMistral() { + testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled(GoogleModelGardenProvider.MISTRAL); + } + + public void testExecute_ThrowsExceptionMistral() { + testExecute_ThrowsIllegalArgumentException(GoogleModelGardenProvider.MISTRAL); + } + + public void testExecute_ThrowsElasticsearchExceptionHuggingFace() { + testExecute_ThrowsElasticsearchException(GoogleModelGardenProvider.HUGGING_FACE); + } + + public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalledHuggingFace() { + testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled(GoogleModelGardenProvider.HUGGING_FACE); + } + + public void testExecute_ThrowsExceptionHuggingFace() { + testExecute_ThrowsIllegalArgumentException(GoogleModelGardenProvider.HUGGING_FACE); + } + + public void testExecute_ThrowsElasticsearchExceptionAi21() { + testExecute_ThrowsElasticsearchException(GoogleModelGardenProvider.AI21); + } + + public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalledAi21() { + testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled(GoogleModelGardenProvider.AI21); + } + + public void testExecute_ThrowsExceptionAi21() { + 
testExecute_ThrowsIllegalArgumentException(GoogleModelGardenProvider.AI21); + } + + private void testExecute_ThrowsIllegalArgumentException(GoogleModelGardenProvider provider) { + testExecute_ThrowsException( + provider, + new IllegalArgumentException("failed"), + "Failed to send Google Vertex AI chat completion request. Cause: failed" ); + } + + private void testExecute_ThrowsElasticsearchException(GoogleModelGardenProvider provider) { + testExecute_ThrowsException(provider, new ElasticsearchException("failed"), "failed"); + } + + private void testExecute_ThrowsException(GoogleModelGardenProvider provider, Exception exception, String expectedExceptionMessage) { + var sender = mock(Sender.class); + doThrow(exception).when(sender).send(any(), any(), any(), any()); + + var action = createAction(sender, provider, provider.getChatCompletionResponseHandler()); PlainActionFuture listener = new PlainActionFuture<>(); action.execute(createUnifiedChatInput(List.of("test query")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); - assertThat(thrownException.getMessage(), is("failed")); + assertThat(thrownException.getMessage(), is(expectedExceptionMessage)); } - private void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled( - String location, - String projectId, - String actualModelId, - GoogleModelGardenProvider googleModelGardenProvider, - URI uri, - ResponseHandler googleModelGardenAnthropicCompletionHandler - ) { + private void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled(GoogleModelGardenProvider provider) { var sender = mock(Sender.class); doAnswer(invocation -> { @@ -204,15 +182,7 @@ private void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalle return Void.TYPE; }).when(sender).send(any(), any(), any(), any()); - var action = createAction( - location, - projectId, - actualModelId, - sender, - 
googleModelGardenProvider, - uri, - googleModelGardenAnthropicCompletionHandler - ); + var action = createAction(sender, provider, provider.getChatCompletionResponseHandler()); PlainActionFuture listener = new PlainActionFuture<>(); action.execute(createUnifiedChatInput(List.of("test query")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); @@ -221,24 +191,16 @@ private void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalle assertThat(thrownException.getMessage(), is("Failed to send Google Vertex AI chat completion request. Cause: failed")); } - private ExecutableAction createAction( - String location, - String projectId, - String actualModelId, - Sender sender, - GoogleModelGardenProvider provider, - URI uri, - ResponseHandler handler - ) { + private ExecutableAction createAction(Sender sender, GoogleModelGardenProvider provider, ResponseHandler handler) { var model = GoogleVertexAiChatCompletionModelTests.createCompletionModel( - projectId, - location, - actualModelId, + null, + null, + null, "api-key", new RateLimitSettings(100), new ThinkingConfig(256), provider, - uri, + null, 123 ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAIChatCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAIChatCompletionServiceSettingsTests.java index 57584ae617953..33b121effeb5b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAIChatCompletionServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAIChatCompletionServiceSettingsTests.java @@ -35,7 +35,6 @@ protected Writeable.Reader instance @Override protected GoogleVertexAiChatCompletionServiceSettings fromMutableMap(Map mutableMap) { return 
GoogleVertexAiChatCompletionServiceSettings.fromMap(mutableMap, ConfigurationParseContext.PERSISTENT); - } @Override @@ -211,7 +210,13 @@ private static GoogleVertexAiChatCompletionServiceSettings createRandomWithGoogl randomOptionalString(), optionalUri, optionalUri == null ? createUri(randomString()) : createOptionalUri(randomOptionalString()), - randomFrom(GoogleModelGardenProvider.ANTHROPIC), + randomFrom( + GoogleModelGardenProvider.ANTHROPIC, + GoogleModelGardenProvider.META, + GoogleModelGardenProvider.MISTRAL, + GoogleModelGardenProvider.HUGGING_FACE, + GoogleModelGardenProvider.AI21 + ), new RateLimitSettings(randomIntBetween(1, 1000)) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModelTests.java index b74ff256659a5..141f08e869d11 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModelTests.java @@ -177,25 +177,47 @@ public void testOf_doesNotOverrideTaskSettings_whenNotPresent() { public void testModelCreationForAnthropicBothUrls() throws URISyntaxException { var uri = new URI("http://example.com"); var streamingUri = new URI("http://example-streaming.com"); - testModelCreationForAnthropic(uri, streamingUri, uri, streamingUri); + testModelCreation(uri, streamingUri, uri, streamingUri, GoogleModelGardenProvider.ANTHROPIC); } public void testModelCreationForAnthropicOnlyNonStreamingUrl() throws URISyntaxException { var uri = new URI("http://example.com"); - testModelCreationForAnthropic(uri, null, uri, uri); + testModelCreation(uri, null, uri, uri, 
GoogleModelGardenProvider.ANTHROPIC); } public void testModelCreationForAnthropicOnlyStreamingUrl() throws URISyntaxException { var streamingUri = new URI("http://example-streaming.com"); - testModelCreationForAnthropic(null, streamingUri, streamingUri, streamingUri); + testModelCreation(null, streamingUri, streamingUri, streamingUri, GoogleModelGardenProvider.ANTHROPIC); } - private static void testModelCreationForAnthropic(URI uri, URI streamingUri, URI expectedNonStreamingUri, URI expectedStreamingUri) { - var model = createAnthropicChatCompletionModel( + public void testModelCreationForMetaBothUrls() throws URISyntaxException { + var uri = new URI("http://example.com"); + var streamingUri = new URI("http://example-streaming.com"); + testModelCreation(uri, streamingUri, uri, streamingUri, GoogleModelGardenProvider.META); + } + + public void testModelCreationForMetaOnlyNonStreamingUrl() throws URISyntaxException { + var uri = new URI("http://example.com"); + testModelCreation(uri, null, uri, uri, GoogleModelGardenProvider.META); + } + + public void testModelCreationForMetaOnlyStreamingUrl() throws URISyntaxException { + var streamingUri = new URI("http://example-streaming.com"); + testModelCreation(null, streamingUri, streamingUri, streamingUri, GoogleModelGardenProvider.META); + } + + private static void testModelCreation( + URI uri, + URI streamingUri, + URI expectedNonStreamingUri, + URI expectedStreamingUri, + GoogleModelGardenProvider provider + ) { + var model = createGoogleModelGardenChatCompletionModel( DEFAULT_API_KEY, DEFAULT_RATE_LIMIT, EMPTY_THINKING_CONFIG, - GoogleModelGardenProvider.ANTHROPIC, + provider, uri, streamingUri, 123 @@ -220,7 +242,7 @@ private static void testModelCreationForAnthropic(URI uri, URI streamingUri, URI assertThat(overriddenModel.getServiceSettings().rateLimitSettings(), is(DEFAULT_RATE_LIMIT)); assertThat(overriddenModel.getServiceSettings().uri(), is(uri)); assertThat(overriddenModel.getServiceSettings().streamingUri(), 
is(streamingUri)); - assertThat(overriddenModel.getServiceSettings().provider(), is(GoogleModelGardenProvider.ANTHROPIC)); + assertThat(overriddenModel.getServiceSettings().provider(), is(provider)); assertThat(overriddenModel.getSecretSettings().serviceAccountJson(), equalTo(new SecureString(DEFAULT_API_KEY.toCharArray()))); assertThat(overriddenModel.getTaskSettings().thinkingConfig(), is(EMPTY_THINKING_CONFIG)); assertThat(overriddenModel.getTaskSettings().maxTokens(), is(123)); @@ -249,7 +271,7 @@ public static GoogleVertexAiChatCompletionModel createCompletionModel( ); } - public static GoogleVertexAiChatCompletionModel createAnthropicChatCompletionModel( + public static GoogleVertexAiChatCompletionModel createGoogleModelGardenChatCompletionModel( String apiKey, RateLimitSettings rateLimitSettings, ThinkingConfig thinkingConfig, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceUnifiedChatCompletionRequestEntityTests.java index 81d26036036c6..c5abfa87b84f6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceUnifiedChatCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceUnifiedChatCompletionRequestEntityTests.java @@ -8,26 +8,93 @@ package org.elasticsearch.xpack.inference.services.huggingface.request; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.json.JsonXContent; import 
org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; -import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel; import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequestEntity; import java.io.IOException; import java.util.ArrayList; import static org.elasticsearch.xpack.inference.Utils.assertJsonEquals; -import static org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModelTests.createCompletionModel; public class HuggingFaceUnifiedChatCompletionRequestEntityTests extends ESTestCase { private static final String ROLE = "user"; - public void testModelUserFieldsSerialization() throws IOException { + public void testSerializationWithModelIdStreaming() throws IOException { + testSerialization("test-endpoint", true, """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "test-endpoint", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """); + } + + public void testSerializationWithModelIdNonStreaming() throws IOException { + testSerialization("test-endpoint", false, """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "test-endpoint", + "n": 1, + "stream": false + } + """); + } + + public void testSerializationWithoutModelIdStreaming() throws IOException { + testSerialization(null, true, """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """); + } + + public void testSerializationWithoutModelIdNonStreaming() throws IOException { + testSerialization(null, false, """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "n": 1, + "stream": false + } + """); + } + + private static void testSerialization(String modelId, boolean isStreaming, String 
expectedJson) throws IOException { UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( new UnifiedCompletionRequest.ContentString("Hello, world!"), ROLE, @@ -39,31 +106,14 @@ public void testModelUserFieldsSerialization() throws IOException { var unifiedRequest = UnifiedCompletionRequest.of(messageList); - UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - HuggingFaceChatCompletionModel model = createCompletionModel("test-url", "api-key", "test-endpoint"); + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, isStreaming); - HuggingFaceUnifiedChatCompletionRequestEntity entity = new HuggingFaceUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + HuggingFaceUnifiedChatCompletionRequestEntity entity = new HuggingFaceUnifiedChatCompletionRequestEntity(unifiedChatInput, modelId); XContentBuilder builder = JsonXContent.contentBuilder(); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); String jsonString = Strings.toString(builder); - String expectedJson = """ - { - "messages": [ - { - "content": "Hello, world!", - "role": "user" - } - ], - "model": "test-endpoint", - "n": 1, - "stream": true, - "stream_options": { - "include_usage": true - } - } - """; - assertJsonEquals(jsonString, expectedJson); + assertJsonEquals(jsonString, XContentHelper.stripWhitespace(expectedJson)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestEntityTests.java index dd8b3d7dfa38c..4871850fb9b8b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestEntityTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestEntityTests.java @@ -15,8 +15,6 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; -import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel; -import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModelTests; import java.io.IOException; import java.util.ArrayList; @@ -24,41 +22,77 @@ public class LlamaChatCompletionRequestEntityTests extends ESTestCase { private static final String ROLE = "user"; - public void testModelUserFieldsSerialization() throws IOException { - UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( - new UnifiedCompletionRequest.ContentString("Hello, world!"), - ROLE, - null, - null - ); - var messageList = new ArrayList(); - messageList.add(message); - - var unifiedRequest = UnifiedCompletionRequest.of(messageList); + public void testSerializationWithModelIdStreaming() throws IOException { + testSerialization("modelId", true, """ + { + "messages": [{ + "content": "Hello, world!", + "role": "user" + } + ], + "model": "modelId", + "n": 1, + "stream": true + } + """); + } - UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - LlamaChatCompletionModel model = LlamaChatCompletionModelTests.createChatCompletionModel("model", "url", "api-key"); + public void testSerializationWithModelIdNonStreaming() throws IOException { + testSerialization("modelId", false, """ + { + "messages": [{ + "content": "Hello, world!", + "role": "user" + } + ], + "model": "modelId", + "n": 1, + "stream": false + } + """); + } - LlamaChatCompletionRequestEntity entity = new LlamaChatCompletionRequestEntity(unifiedChatInput, model); + public void testSerializationWithoutModelIdStreaming() 
throws IOException { + testSerialization(null, true, """ + { + "messages": [{ + "content": "Hello, world!", + "role": "user" + } + ], + "n": 1, + "stream": true + } + """); + } - XContentBuilder builder = JsonXContent.contentBuilder(); - entity.toXContent(builder, ToXContent.EMPTY_PARAMS); - String expectedJson = """ + public void testSerializationWithoutModelIdNonStreaming() throws IOException { + testSerialization(null, false, """ { "messages": [{ "content": "Hello, world!", "role": "user" } ], - "model": "model", "n": 1, - "stream": true, - "stream_options": { - "include_usage": true - } + "stream": false } - """; - assertEquals(XContentHelper.stripWhitespace(expectedJson), Strings.toString(builder)); + """); } + private static void testSerialization(String modelId, boolean isStreaming, String expectedJson) throws IOException { + var message = new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("Hello, world!"), ROLE, null, null); + + var messageList = new ArrayList(); + messageList.add(message); + + var unifiedRequest = UnifiedCompletionRequest.of(messageList); + var unifiedChatInput = new UnifiedChatInput(unifiedRequest, isStreaming); + + var entity = new LlamaChatCompletionRequestEntity(unifiedChatInput, modelId); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertEquals(XContentHelper.stripWhitespace(expectedJson), Strings.toString(builder)); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestTests.java index 6f0701a810fb1..d4167a395911f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestTests.java @@ -36,7 +36,7 @@ public void testCreateRequest_WithStreaming() throws IOException { assertThat(requestMap.get("stream"), is(true)); assertThat(requestMap.get("model"), is("model")); assertThat(requestMap.get("n"), is(1)); - assertThat(requestMap.get("stream_options"), is(Map.of("include_usage", true))); + assertNull(requestMap.get("stream_options")); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", input)))); assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequestEntityTests.java index f968f1b84d75b..010ae11ec845d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequestEntityTests.java @@ -15,18 +15,77 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; -import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModel; import java.io.IOException; import java.util.ArrayList; -import static org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModelTests.createCompletionModel; - public class MistralChatCompletionRequestEntityTests extends ESTestCase { private static final String ROLE = "user"; - public void 
testModelUserFieldsSerialization() throws IOException { + public void testSerializationWithModelIdStreaming() throws IOException { + testSerialization("test-endpoint", true, """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "test-endpoint", + "n": 1, + "stream": true + } + """); + } + + public void testSerializationWithModelIdNonStreaming() throws IOException { + testSerialization("test-endpoint", false, """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "test-endpoint", + "n": 1, + "stream": false + } + """); + } + + public void testSerializationWithoutModelIdStreaming() throws IOException { + testSerialization(null, true, """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "n": 1, + "stream": true + } + """); + } + + public void testSerializationWithoutModelIdNonStreaming() throws IOException { + testSerialization(null, false, """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "n": 1, + "stream": false + } + """); + } + + private static void testSerialization(String modelId, boolean isStreaming, String expectedJson) throws IOException { UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( new UnifiedCompletionRequest.ContentString("Hello, world!"), ROLE, @@ -38,26 +97,12 @@ public void testModelUserFieldsSerialization() throws IOException { var unifiedRequest = UnifiedCompletionRequest.of(messageList); - UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - MistralChatCompletionModel model = createCompletionModel("api-key", "test-endpoint"); + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, isStreaming); - MistralChatCompletionRequestEntity entity = new MistralChatCompletionRequestEntity(unifiedChatInput, model); + MistralChatCompletionRequestEntity entity = new MistralChatCompletionRequestEntity(unifiedChatInput, modelId); 
XContentBuilder builder = JsonXContent.contentBuilder(); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); - String expectedJson = """ - { - "messages": [ - { - "content": "Hello, world!", - "role": "user" - } - ], - "model": "test-endpoint", - "n": 1, - "stream": true - } - """; assertEquals(XContentHelper.stripWhitespace(expectedJson), Strings.toString(builder)); } }