diff --git a/docs/changelog/129146.yaml b/docs/changelog/129146.yaml new file mode 100644 index 0000000000000..6303e8ab089cd --- /dev/null +++ b/docs/changelog/129146.yaml @@ -0,0 +1,5 @@ +pr: 129146 +summary: "[ML] Add IBM watsonx Completion and Chat Completion support to the Inference Plugin" +area: Machine Learning +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index b8494c5b76012..486804a1f57ee 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -328,7 +328,7 @@ static TransportVersion def(int id) { public static final TransportVersion MAPPINGS_IN_DATA_STREAMS = def(9_112_0_00); public static final TransportVersion PROJECT_STATE_REGISTRY_RECORDS_DELETIONS = def(9_113_0_00); public static final TransportVersion ESQL_SERIALIZE_TIMESERIES_FIELD_TYPE = def(9_114_0_00); - + public static final TransportVersion ML_INFERENCE_IBM_WATSONX_COMPLETION_ADDED = def(9_115_0_00); /* * STOP! READ THIS FIRST! No, really, * ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _ diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index b96c94db438a7..471ed789ff69d 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -151,7 +151,8 @@ public void testGetServicesWithCompletionTaskType() throws IOException { "completion_test_service", "hugging_face", "amazon_sagemaker", - "mistral" + "mistral", + "watsonxai" ).toArray() ) ); @@ -169,7 +170,8 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException { "hugging_face", "amazon_sagemaker", "googlevertexai", - "mistral" + "mistral", + "watsonxai" ).toArray() ) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index b697530e9848b..c347fa1dca4ce 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -95,6 +95,7 @@ import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings; import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankServiceSettings; import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings; +import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings; import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings; @@ -469,6 +470,13 @@ private static void addIbmWatsonxNamedWritables(List namedWriteables) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankModel.java index 2244afc135582..ef8c0b56f33b3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankModel.java @@ -80,8 +80,8 @@ public DefaultSecretSettings getSecretSettings() { /** * Accepts a visitor to create an executable action. The returned action will not return documents in the response. - * @param visitor _ - * @param taskSettings _ + * @param visitor Interface for creating {@link ExecutableAction} instances for Cohere models. + * @param taskSettings Settings in the request to override the model's defaults * @return the rerank action */ @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonUnifiedChatCompletionResponseHandler.java new file mode 100644 index 0000000000000..41b82bbf2cd02 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonUnifiedChatCompletionResponseHandler.java @@ -0,0 +1,51 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.ibmwatsonx; + +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.ibmwatsonx.response.IbmWatsonxErrorResponseEntity; +import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler; + +import java.util.Locale; + +/** + * Handles streaming chat completion responses and error parsing for Watsonx inference endpoints. + * Adapts the OpenAI handler to support Watsonx's error schema. + */ +public class IbmWatsonUnifiedChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler { + + private static final String WATSONX_ERROR = "watsonx_error"; + + public IbmWatsonUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, IbmWatsonxErrorResponseEntity::fromResponse); + } + + @Override + protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { + assert request.isStreaming() : "Only streaming requests support this format"; + var responseStatusCode = result.response().getStatusLine().getStatusCode(); + if (request.isStreaming()) { + var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode); + var restStatus = toRestStatus(responseStatusCode); + return errorResponse instanceof IbmWatsonxErrorResponseEntity + ? new UnifiedChatCompletionException(restStatus, errorMessage, WATSONX_ERROR, restStatus.name().toLowerCase(Locale.ROOT)) + : new UnifiedChatCompletionException( + restStatus, + errorMessage, + createErrorType(errorResponse), + restStatus.name().toLowerCase(Locale.ROOT) + ); + } else { + return super.buildError(message, request, result, errorResponse); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxCompletionResponseHandler.java new file mode 100644 index 0000000000000..32f3fef4ed070 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxCompletionResponseHandler.java @@ -0,0 +1,25 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.ibmwatsonx; + +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.services.ibmwatsonx.response.IbmWatsonxErrorResponseEntity; +import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler; + +public class IbmWatsonxCompletionResponseHandler extends OpenAiChatCompletionResponseHandler { + + /** + * Constructs a IbmWatsonxCompletionResponseHandler with the specified request type and response parser. + * + * @param requestType The type of request being handled (e.g., "IBM watsonx completions"). + * @param parseFunction The function to parse the response. + */ + public IbmWatsonxCompletionResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, IbmWatsonxErrorResponseEntity::fromResponse); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxEmbeddingsRequestManager.java index b7c679d3cda54..c7a556321148e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxEmbeddingsRequestManager.java @@ -35,7 +35,7 @@ public class IbmWatsonxEmbeddingsRequestManager extends IbmWatsonxRequestManager private static final ResponseHandler HANDLER = createEmbeddingsHandler(); private static ResponseHandler createEmbeddingsHandler() { - return new IbmWatsonxResponseHandler("ibm watsonx embeddings", IbmWatsonxEmbeddingsResponseEntity::fromResponse); + return new IbmWatsonxResponseHandler("IBM watsonx embeddings", IbmWatsonxEmbeddingsResponseEntity::fromResponse); } private final IbmWatsonxEmbeddingsModel model; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxModel.java index dbdff0316199d..4e8672e6b7361 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxModel.java @@ -7,18 +7,19 @@ package org.elasticsearch.xpack.inference.services.ibmwatsonx; -import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel; import org.elasticsearch.xpack.inference.services.ibmwatsonx.action.IbmWatsonxActionVisitor; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import java.util.Map; import java.util.Objects; -public abstract class IbmWatsonxModel extends Model { +public abstract class IbmWatsonxModel extends RateLimitGroupingModel { private final IbmWatsonxRateLimitServiceSettings rateLimitServiceSettings; @@ -49,4 +50,14 @@ public IbmWatsonxModel(IbmWatsonxModel model, TaskSettings taskSettings) { public IbmWatsonxRateLimitServiceSettings rateLimitServiceSettings() { return rateLimitServiceSettings; } + + @Override + public int rateLimitGroupingHash() { + return Objects.hash(this.rateLimitServiceSettings); + } + + @Override + public RateLimitSettings rateLimitSettings() { + return this.rateLimitServiceSettings().rateLimitSettings(); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxRerankRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxRerankRequestManager.java index d62722a2b593b..509d50f81f4a1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxRerankRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxRerankRequestManager.java @@ -31,7 +31,7 @@ public class IbmWatsonxRerankRequestManager extends IbmWatsonxRequestManager { private static ResponseHandler createIbmWatsonxResponseHandler() { return new IbmWatsonxResponseHandler( - "ibm watsonx rerank", + "IBM watsonx rerank", (request, response) -> IbmWatsonxRankedResponseEntity.fromResponse(response) ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index 7dfb0002bb062..9bc63be1f9e7e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -30,7 +30,10 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.Sender; @@ -40,14 +43,18 @@ import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.ibmwatsonx.action.IbmWatsonxActionCreator; +import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModel; import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel; import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.ibmwatsonx.request.IbmWatsonxChatCompletionRequest; import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel; +import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; @@ -56,7 +63,6 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings.URL; import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.API_VERSION; import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.EMBEDDING_MAX_BATCH_SIZE; @@ -66,8 +72,16 @@ public class IbmWatsonxService extends SenderService { public static final String NAME = "watsonxai"; - private static final String SERVICE_NAME = "IBM Watsonx"; - private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING); + private static final String SERVICE_NAME = "IBM watsonx"; + private static final EnumSet supportedTaskTypes = EnumSet.of( + TaskType.TEXT_EMBEDDING, + TaskType.COMPLETION, + TaskType.CHAT_COMPLETION + ); + private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new IbmWatsonUnifiedChatCompletionResponseHandler( + "IBM watsonx chat completions", + OpenAiChatCompletionResponseEntity::fromResponse + ); public IbmWatsonxService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { super(factory, serviceComponents); @@ -148,6 +162,14 @@ private static IbmWatsonxModel createModel( secretSettings, context ); + case CHAT_COMPLETION, COMPLETION -> new IbmWatsonxChatCompletionModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + secretSettings, + context + ); default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); }; } @@ -236,6 +258,11 @@ public TransportVersion getMinimalSupportedVersion() { return TransportVersions.V_8_16_0; } + @Override + public Set supportedStreamingTasks() { + return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION); + } + @Override public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { if (model instanceof IbmWatsonxEmbeddingsModel embeddingsModel) { @@ -291,7 +318,24 @@ protected void doUnifiedCompletionInfer( TimeValue timeout, ActionListener listener ) { - throwUnsupportedUnifiedCompletionOperation(NAME); + if (model instanceof IbmWatsonxChatCompletionModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + IbmWatsonxChatCompletionModel ibmWatsonxChatCompletionModel = (IbmWatsonxChatCompletionModel) model; + var overriddenModel = IbmWatsonxChatCompletionModel.of(ibmWatsonxChatCompletionModel, inputs.getRequest()); + var manager = new GenericRequestManager<>( + getServiceComponents().threadPool(), + overriddenModel, + UNIFIED_CHAT_COMPLETION_HANDLER, + unifiedChatInput -> new IbmWatsonxChatCompletionRequest(unifiedChatInput, overriddenModel), + UnifiedChatInput.class + ); + var errorMessage = IbmWatsonxActionCreator.buildErrorMessage(TaskType.CHAT_COMPLETION, model.getInferenceEntityId()); + var action = new SenderExecutableAction(getSender(), manager, errorMessage); + + action.execute(inputs, timeout, listener); } @Override @@ -331,7 +375,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( API_VERSION, - new SettingsConfiguration.Builder(supportedTaskTypes).setDescription("The IBM Watsonx API version ID to use.") + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription("The IBM watsonx API version ID to use.") .setLabel("API Version") .setRequired(true) .setSensitive(false) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxActionCreator.java index 46b84689ee3bf..e70358561164e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxActionCreator.java @@ -7,26 +7,48 @@ package org.elasticsearch.xpack.inference.services.ibmwatsonx.action; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.common.Truncator; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxCompletionResponseHandler; import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxEmbeddingsRequestManager; import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxRerankRequestManager; +import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModel; import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.ibmwatsonx.request.IbmWatsonxChatCompletionRequest; import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel; +import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; import java.util.Map; import java.util.Objects; +import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; +/** + * IbmWatsonxActionCreator is responsible for creating executable actions for various models. + * It implements the IbmWatsonxActionVisitor interface to provide specific implementations. + */ public class IbmWatsonxActionCreator implements IbmWatsonxActionVisitor { private final Sender sender; private final ServiceComponents serviceComponents; + static final String COMPLETION_REQUEST_TYPE = "IBM watsonx completions"; + static final String USER_ROLE = "user"; + static final ResponseHandler COMPLETION_HANDLER = new IbmWatsonxCompletionResponseHandler( + COMPLETION_REQUEST_TYPE, + OpenAiChatCompletionResponseEntity::fromResponse + ); + public IbmWatsonxActionCreator(Sender sender, ServiceComponents serviceComponents) { this.sender = Objects.requireNonNull(sender); this.serviceComponents = Objects.requireNonNull(serviceComponents); @@ -34,7 +56,7 @@ public IbmWatsonxActionCreator(Sender sender, ServiceComponents serviceComponent @Override public ExecutableAction create(IbmWatsonxEmbeddingsModel model, Map taskSettings) { - var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("IBM WatsonX embeddings"); + var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("IBM watsonx embeddings"); return new SenderExecutableAction( sender, getEmbeddingsRequestManager(model, serviceComponents.truncator(), serviceComponents.threadPool()), @@ -46,10 +68,24 @@ public ExecutableAction create(IbmWatsonxEmbeddingsModel model, Map taskSettings) { var overriddenModel = IbmWatsonxRerankModel.of(model, taskSettings); var requestCreator = IbmWatsonxRerankRequestManager.of(overriddenModel, serviceComponents.threadPool()); - var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Ibm Watsonx rerank"); + var failedToSendRequestErrorMessage = buildErrorMessage(TaskType.RERANK, overriddenModel.getInferenceEntityId()); return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage); } + @Override + public ExecutableAction create(IbmWatsonxChatCompletionModel chatCompletionModel) { + var manager = new GenericRequestManager<>( + serviceComponents.threadPool(), + chatCompletionModel, + COMPLETION_HANDLER, + inputs -> new IbmWatsonxChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), chatCompletionModel), + ChatCompletionInput.class + ); + + var failedToSendRequestErrorMessage = buildErrorMessage(TaskType.COMPLETION, chatCompletionModel.getInferenceEntityId()); + return new SingleInputSenderExecutableAction(sender, manager, failedToSendRequestErrorMessage, COMPLETION_REQUEST_TYPE); + } + protected IbmWatsonxEmbeddingsRequestManager getEmbeddingsRequestManager( IbmWatsonxEmbeddingsModel model, Truncator truncator, @@ -57,4 +93,15 @@ protected IbmWatsonxEmbeddingsRequestManager getEmbeddingsRequestManager( ) { return new IbmWatsonxEmbeddingsRequestManager(model, truncator, threadPool); } + + /** + * Builds an error message for IBM watsonx actions. + * + * @param requestType The type of request (e.g. COMPLETION, EMBEDDING, RERANK). + * @param inferenceId The ID of the inference entity. + * @return A formatted error message. + */ + public static String buildErrorMessage(TaskType requestType, String inferenceId) { + return format("Failed to send IBM watsonx %s request from inference entity id [%s]", requestType.toString(), inferenceId); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxActionVisitor.java index 64e05769f4a8d..3ef2ac26bc54c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxActionVisitor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxActionVisitor.java @@ -8,13 +8,42 @@ package org.elasticsearch.xpack.inference.services.ibmwatsonx.action; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModel; import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel; import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel; import java.util.Map; +/** + * Interface for creating {@link ExecutableAction} instances for IBM watsonx models. + *

+ * This interface is used to create {@link ExecutableAction} instances for different types of IBM watsonx models, such as + * {@link IbmWatsonxEmbeddingsModel} and {@link IbmWatsonxRerankModel} and {@link IbmWatsonxChatCompletionModel}. + */ public interface IbmWatsonxActionVisitor { + + /** + * Creates an {@link ExecutableAction} for the given {@link IbmWatsonxEmbeddingsModel}. + * + * @param model The model to create the action for. + * @param taskSettings The task settings to use. + * @return An {@link ExecutableAction} for the given model. + */ ExecutableAction create(IbmWatsonxEmbeddingsModel model, Map taskSettings); + /** + * Creates an {@link ExecutableAction} for the given {@link IbmWatsonxRerankModel}. + * + * @param model The model to create the action for. + * @return An {@link ExecutableAction} for the given model. + */ ExecutableAction create(IbmWatsonxRerankModel model, Map taskSettings); + + /** + * Creates an {@link ExecutableAction} for the given {@link IbmWatsonxChatCompletionModel}. + * + * @param model The model to create the action for. + * @return An {@link ExecutableAction} for the given model. + */ + ExecutableAction create(IbmWatsonxChatCompletionModel model); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/completion/IbmWatsonxChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/completion/IbmWatsonxChatCompletionModel.java new file mode 100644 index 0000000000000..c10759aa7b6db --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/completion/IbmWatsonxChatCompletionModel.java @@ -0,0 +1,143 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.ibmwatsonx.completion; + +import org.apache.http.client.utils.URIBuilder; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxModel; +import org.elasticsearch.xpack.inference.services.ibmwatsonx.action.IbmWatsonxActionVisitor; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.ibmwatsonx.request.IbmWatsonxUtils.COMPLETIONS; +import static org.elasticsearch.xpack.inference.services.ibmwatsonx.request.IbmWatsonxUtils.ML; +import static org.elasticsearch.xpack.inference.services.ibmwatsonx.request.IbmWatsonxUtils.TEXT; +import static org.elasticsearch.xpack.inference.services.ibmwatsonx.request.IbmWatsonxUtils.V1; + +public class IbmWatsonxChatCompletionModel extends IbmWatsonxModel { + + /** + * Constructor for IbmWatsonxChatCompletionModel. + * + * @param inferenceEntityId The unique identifier for the inference entity. + * @param taskType The type of task this model is designed for. + * @param service The name of the service this model belongs to. + * @param serviceSettings The settings specific to the Ibm Granite chat completion service. + * @param secrets The secrets required for accessing the service. + * @param context The context for parsing configuration settings. + */ + public IbmWatsonxChatCompletionModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettings, + @Nullable Map secrets, + ConfigurationParseContext context + ) { + this( + inferenceEntityId, + taskType, + service, + IbmWatsonxChatCompletionServiceSettings.fromMap(serviceSettings, context), + DefaultSecretSettings.fromMap(secrets) + ); + } + + /** + * Creates a new IbmWatsonxChatCompletionModel with overridden service settings. + * + * @param model The original IbmWatsonxChatCompletionModel. + * @param request The UnifiedCompletionRequest containing the model override. + * @return A new IbmWatsonxChatCompletionModel with the overridden model ID. + */ + public static IbmWatsonxChatCompletionModel of(IbmWatsonxChatCompletionModel model, UnifiedCompletionRequest request) { + if (request.model() == null) { + // If no model is specified in the request, return the original model + return model; + } + + var originalModelServiceSettings = model.getServiceSettings(); + var overriddenServiceSettings = new IbmWatsonxChatCompletionServiceSettings( + originalModelServiceSettings.uri(), + originalModelServiceSettings.apiVersion(), + request.model(), + originalModelServiceSettings.projectId(), + originalModelServiceSettings.rateLimitSettings() + ); + + return new IbmWatsonxChatCompletionModel( + model.getInferenceEntityId(), + model.getTaskType(), + model.getConfigurations().getService(), + overriddenServiceSettings, + model.getSecretSettings() + ); + } + + // should only be used for testing + IbmWatsonxChatCompletionModel( + String inferenceEntityId, + TaskType taskType, + String service, + IbmWatsonxChatCompletionServiceSettings serviceSettings, + @Nullable DefaultSecretSettings secretSettings + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings), + new ModelSecrets(secretSettings), + serviceSettings + ); + } + + @Override + public IbmWatsonxChatCompletionServiceSettings getServiceSettings() { + return (IbmWatsonxChatCompletionServiceSettings) super.getServiceSettings(); + } + + @Override + public DefaultSecretSettings getSecretSettings() { + return (DefaultSecretSettings) super.getSecretSettings(); + } + + public URI uri() { + URI uri; + try { + uri = buildUri(this.getServiceSettings().uri().toString(), this.getServiceSettings().apiVersion()); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + + return uri; + } + + /** + * Accepts a visitor to create an executable action. The returned action will not return documents in the response. + * @param visitor Interface for creating {@link ExecutableAction} instances for IBM watsonx models. + * @return the completion action + */ + public ExecutableAction accept(IbmWatsonxActionVisitor visitor, Map taskSettings) { + return visitor.create(this); + } + + public static URI buildUri(String uri, String apiVersion) throws URISyntaxException { + return new URIBuilder().setScheme("https") + .setHost(uri) + .setPathSegments(ML, V1, TEXT, COMPLETIONS) + .setParameter("version", apiVersion) + .build(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/completion/IbmWatsonxChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/completion/IbmWatsonxChatCompletionServiceSettings.java new file mode 100644 index 0000000000000..0492a626787cf --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/completion/IbmWatsonxChatCompletionServiceSettings.java @@ -0,0 +1,193 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.ibmwatsonx.completion; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.net.URI; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertToUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; +import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.API_VERSION; +import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.PROJECT_ID; + +public class IbmWatsonxChatCompletionServiceSettings extends FilteredXContentObject + implements + ServiceSettings, + IbmWatsonxRateLimitServiceSettings { + public static final String NAME = "ibm_watsonx_completion_service_settings"; + + /** + * Rate limits are defined at + * Watson Machine Learning plans. + * For the Lite plan, the limit is 120 requests per minute. + */ + private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(120); + + public static IbmWatsonxChatCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + String url = extractRequiredString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException); + URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException); + String apiVersion = extractRequiredString(map, API_VERSION, ModelConfigurations.SERVICE_SETTINGS, validationException); + + String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + String projectId = extractRequiredString(map, PROJECT_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + IbmWatsonxService.NAME, + context + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new IbmWatsonxChatCompletionServiceSettings(uri, apiVersion, modelId, projectId, rateLimitSettings); + } + + private final URI uri; + + private final String apiVersion; + + private final String modelId; + + private final String projectId; + + private final RateLimitSettings rateLimitSettings; + + public IbmWatsonxChatCompletionServiceSettings( + URI uri, + String apiVersion, + String modelId, + String projectId, + @Nullable RateLimitSettings rateLimitSettings + ) { + this.uri = uri; + this.apiVersion = apiVersion; + this.projectId = projectId; + this.modelId = modelId; + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + public IbmWatsonxChatCompletionServiceSettings(StreamInput in) throws IOException { + this.uri = createUri(in.readString()); + this.apiVersion = in.readString(); + this.modelId = in.readString(); + this.projectId = in.readString(); + this.rateLimitSettings = new RateLimitSettings(in); + + } + + public URI uri() { + return uri; + } + + public String apiVersion() { + return apiVersion; + } + + @Override + public String modelId() { + return modelId; + } + + public String projectId() { + return projectId; + } + + @Override + public RateLimitSettings rateLimitSettings() { + return rateLimitSettings; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + toXContentFragmentOfExposedFields(builder, params); + + builder.endObject(); + return builder; + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + builder.field(URL, uri.toString()); + + builder.field(API_VERSION, apiVersion); + + builder.field(MODEL_ID, modelId); + + builder.field(PROJECT_ID, projectId); + + rateLimitSettings.toXContent(builder, params); + + return builder; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_IBM_WATSONX_COMPLETION_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(uri.toString()); + out.writeString(apiVersion); + + out.writeString(modelId); + out.writeString(projectId); + + rateLimitSettings.writeTo(out); + } + + @Override + public boolean equals(Object object) { + if (this == object) return true; + if (object == null || getClass() != object.getClass()) return false; + IbmWatsonxChatCompletionServiceSettings that = (IbmWatsonxChatCompletionServiceSettings) object; + return Objects.equals(uri, that.uri) + && Objects.equals(apiVersion, that.apiVersion) + && Objects.equals(modelId, that.modelId) + && Objects.equals(projectId, that.projectId) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(uri, apiVersion, modelId, projectId, rateLimitSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/embeddings/IbmWatsonxEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/embeddings/IbmWatsonxEmbeddingsServiceSettings.java index 3a9625aef31c7..404df1c9cb4b4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/embeddings/IbmWatsonxEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/embeddings/IbmWatsonxEmbeddingsServiceSettings.java @@ -52,7 +52,7 @@ public class IbmWatsonxEmbeddingsServiceSettings extends FilteredXContentObject /** * Rate limits are defined at * Watson Machine Learning plans. - * For Lite plan, you've 120 requests per minute. + * For the Lite plan, the limit is 120 requests per minute. */ private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(120); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/request/IbmWatsonxChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/request/IbmWatsonxChatCompletionRequest.java new file mode 100644 index 0000000000000..b3d33f1cf211b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/request/IbmWatsonxChatCompletionRequest.java @@ -0,0 +1,79 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.ibmwatsonx.request; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModel; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.Objects; + +public class IbmWatsonxChatCompletionRequest implements IbmWatsonxRequest { + private final IbmWatsonxChatCompletionModel model; + private final UnifiedChatInput chatInput; + + public IbmWatsonxChatCompletionRequest(UnifiedChatInput chatInput, IbmWatsonxChatCompletionModel model) { + this.chatInput = Objects.requireNonNull(chatInput); + this.model = Objects.requireNonNull(model); + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(model.uri()); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString(new IbmWatsonxChatCompletionRequestEntity(chatInput, model)).getBytes(StandardCharsets.UTF_8) + ); + httpPost.setEntity(byteEntity); + + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); + + decorateWithAuth(httpPost); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public URI getURI() { + return model.uri(); + } + + public void decorateWithAuth(HttpPost httpPost) { + IbmWatsonxRequest.decorateWithBearerToken(httpPost, model.getSecretSettings(), model.getInferenceEntityId()); + } + + @Override + public Request truncate() { + // No truncation for IBM watsonx chat completions + return this; + } + + @Override + public boolean[] getTruncationInfo() { + // No truncation for IBM watsonx chat completions + return null; + } + + @Override + public String getInferenceEntityId() { + return model.getInferenceEntityId(); + } + + @Override + public boolean isStreaming() { + return chatInput.stream(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/request/IbmWatsonxChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/request/IbmWatsonxChatCompletionRequestEntity.java new file mode 100644 index 0000000000000..9f4cb9575f3cb --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/request/IbmWatsonxChatCompletionRequestEntity.java @@ -0,0 +1,47 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.ibmwatsonx.request; + +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity; +import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModel; + +import java.io.IOException; +import java.util.Objects; + +/** + * IbmWatsonxChatCompletionRequestEntity is responsible for creating the request entity for Watsonx chat completion. + * It implements ToXContentObject to allow serialization to XContent format. + */ +public class IbmWatsonxChatCompletionRequestEntity implements ToXContentObject { + + private final IbmWatsonxChatCompletionModel model; + private final UnifiedChatCompletionRequestEntity unifiedRequestEntity; + + private static final String PROJECT_ID_FIELD = "project_id"; + + public IbmWatsonxChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, IbmWatsonxChatCompletionModel model) { + this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput); + this.model = Objects.requireNonNull(model); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(PROJECT_ID_FIELD, model.getServiceSettings().projectId()); + unifiedRequestEntity.toXContent( + builder, + UnifiedCompletionRequest.withMaxTokensAndSkipStreamOptionsField(model.getServiceSettings().modelId(), params) + ); + builder.endObject(); + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/request/IbmWatsonxUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/request/IbmWatsonxUtils.java index 7c4dc526ca509..856c91f509b26 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/request/IbmWatsonxUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/request/IbmWatsonxUtils.java @@ -14,6 +14,7 @@ public class IbmWatsonxUtils { public static final String TEXT = "text"; public static final String EMBEDDINGS = "embeddings"; public static final String RERANKS = "reranks"; + public static final String COMPLETIONS = "chat"; private IbmWatsonxUtils() {} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/rerank/IbmWatsonxRerankModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/rerank/IbmWatsonxRerankModel.java index 15dd648c2fa1a..7026bd17db5db 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/rerank/IbmWatsonxRerankModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/rerank/IbmWatsonxRerankModel.java @@ -100,8 +100,8 @@ public URI uri() { /** * Accepts a visitor to create an executable action. The returned action will not return documents in the response. - * @param visitor _ - * @param taskSettings _ + * @param visitor Interface for creating {@link ExecutableAction} instances for IBM watsonx models. + * @param taskSettings Settings in the request to override the model's defaults * @return the rerank action */ @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/rerank/IbmWatsonxRerankServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/rerank/IbmWatsonxRerankServiceSettings.java index 969622f9ba54f..b4b183736e841 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/rerank/IbmWatsonxRerankServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/rerank/IbmWatsonxRerankServiceSettings.java @@ -41,7 +41,7 @@ public class IbmWatsonxRerankServiceSettings extends FilteredXContentObject impl /** * Rate limits are defined at * Watson Machine Learning plans. - * For Lite plan, you've 120 requests per minute. + * For the Lite plan, the limit is 120 requests per minute. */ private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(120); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntity.java index a1e90f7c0ee24..4fda9d5661a2c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntity.java @@ -28,7 +28,7 @@ public class IbmWatsonxEmbeddingsResponseEntity { - private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in IBM Watsonx embeddings response"; + private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in IBM watsonx embeddings response"; public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException { var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxRankedResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxRankedResponseEntity.java index 729e6ef980350..e513d2caab781 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxRankedResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxRankedResponseEntity.java @@ -32,7 +32,7 @@ public class IbmWatsonxRankedResponseEntity { private static final Logger logger = LogManager.getLogger(IbmWatsonxRankedResponseEntity.class); /** - * Parses the Ibm Watsonx ranked response. + * Parses the IBM watsonx ranked response. * * For a request like: * "model": "rerank-english-v2.0", @@ -71,7 +71,7 @@ public class IbmWatsonxRankedResponseEntity { * ], * } * - * @param response the http response from ibm watsonx + * @param response the http response from IBM watsonx * @return the parsed response * @throws IOException if there is an error parsing the response */ diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankModel.java index a1fed50753627..a3404ed308835 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankModel.java @@ -84,8 +84,8 @@ public DefaultSecretSettings getSecretSettings() { /** * Accepts a visitor to create an executable action. The returned action will not return documents in the response. - * @param visitor _ - * @param taskSettings _ + * @param visitor Interface for creating {@link ExecutableAction} instances for Jina AI models. + * @param taskSettings Settings in the request to override the model's defaults * @return the rerank action */ @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionCreator.java index 7c96ccc2c592c..fbf842f4fb789 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionCreator.java @@ -37,8 +37,8 @@ public class MistralActionCreator implements MistralActionVisitor { public static final String COMPLETION_ERROR_PREFIX = "Mistral completions"; - static final String USER_ROLE = "user"; - static final ResponseHandler COMPLETION_HANDLER = new MistralCompletionResponseHandler( + public static final String USER_ROLE = "user"; + public static final ResponseHandler COMPLETION_HANDLER = new MistralCompletionResponseHandler( "mistral completions", OpenAiChatCompletionResponseEntity::fromResponse ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModel.java index 7e58843c78f18..98ace66f74ad4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModel.java @@ -109,8 +109,8 @@ public DefaultSecretSettings getSecretSettings() { /** * Accepts a visitor to create an executable action. The returned action will not return documents in the response. - * @param visitor _ - * @param taskSettings _ + * @param visitor Interface for creating {@link ExecutableAction} instances for Voyage AI models. + * @param taskSettings Settings in the request to override the model's defaults * @return the rerank action */ @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ChatCompletionActionTests.java new file mode 100644 index 0000000000000..ddce5d96c63cf --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ChatCompletionActionTests.java @@ -0,0 +1,154 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.util.List; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; + +public abstract class ChatCompletionActionTests extends ESTestCase { + protected static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + protected final MockWebServer webServer = new MockWebServer(); + protected HttpClientManager clientManager; + protected ThreadPool threadPool; + + protected abstract ExecutableAction createAction(String url, Sender sender) throws URISyntaxException; + + protected abstract String getOneInputError(); + + protected abstract String getFailedToSendError(); + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testExecute_ThrowsElasticsearchException() throws URISyntaxException { + var sender = mock(Sender.class); + doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any()); + + var action = createAction(getUrl(webServer), sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat(thrownException.getMessage(), is("failed")); + } + + public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() throws URISyntaxException { + var sender = mock(Sender.class); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onFailure(new IllegalStateException("failed")); + + return Void.TYPE; + }).when(sender).send(any(), any(), any(), any()); + + var action = createAction(getUrl(webServer), sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat(thrownException.getMessage(), is(getFailedToSendError())); + } + + public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException, URISyntaxException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(getResponseJson())); + + var action = createAction(getUrl(webServer), sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new ChatCompletionInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat(thrownException.getMessage(), is(getOneInputError())); + assertThat(thrownException.status(), is(RestStatus.BAD_REQUEST)); + } + } + + protected String getResponseJson() { + return """ + { + "id": "9d80f26810ac4e9582f927fcf0512ec7", + "object": "chat.completion", + "created": 1748596419, + "model": "modelId", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "tool_calls": null, + "content": "result content" + }, + "finish_reason": "length", + "logprobs": null + } + ], + "usage": { + "prompt_tokens": 10, + "total_tokens": 11, + "completion_tokens": 1 + } + } + """; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java index 35dbcdd6aa99f..3295ecfd4ece5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java @@ -918,8 +918,8 @@ public void testGetConfiguration() throws Exception { String content = XContentHelper.stripWhitespace(""" { "service": "watsonxai", - "name": "IBM Watsonx", - "task_types": ["text_embedding"], + "name": "IBM watsonx", + "task_types": ["text_embedding", "completion", "chat_completion"], "configurations": { "project_id": { "description": "", @@ -928,7 +928,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["text_embedding"] + "supported_task_types": ["text_embedding", "completion", "chat_completion"] }, "model_id": { "description": "The name of the model to use for the inference task.", @@ -937,16 +937,16 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["text_embedding"] + "supported_task_types": ["text_embedding", "completion", "chat_completion"] }, "api_version": { - "description": "The IBM Watsonx API version ID to use.", + "description": "The IBM watsonx API version ID to use.", "label": "API Version", "required": true, "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["text_embedding"] + "supported_task_types": ["text_embedding", "completion", "chat_completion"] }, "max_input_tokens": { "description": "Allows you to specify the maximum number of tokens per input.", @@ -964,7 +964,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["text_embedding"] + "supported_task_types": ["text_embedding", "completion", "chat_completion"] } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxChatCompletionActionTests.java new file mode 100644 index 0000000000000..e3f772bbf1ea7 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxChatCompletionActionTests.java @@ -0,0 +1,50 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.ibmwatsonx.action; + +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.ChatCompletionActionTests; +import org.elasticsearch.xpack.inference.services.ibmwatsonx.request.IbmWatsonxChatCompletionRequest; + +import java.net.URI; +import java.net.URISyntaxException; + +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; +import static org.elasticsearch.xpack.inference.services.ibmwatsonx.action.IbmWatsonxActionCreator.COMPLETION_HANDLER; +import static org.elasticsearch.xpack.inference.services.ibmwatsonx.action.IbmWatsonxActionCreator.USER_ROLE; +import static org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModelTests.createModel; + +public class IbmWatsonxChatCompletionActionTests extends ChatCompletionActionTests { + public static final URI TEST_URI = URI.create("abc.com"); + + protected ExecutableAction createAction(String url, Sender sender) throws URISyntaxException { + var model = createModel(TEST_URI, randomAlphaOfLength(8), randomAlphaOfLength(8), randomAlphaOfLength(8), randomAlphaOfLength(8)); + var manager = new GenericRequestManager<>( + threadPool, + model, + COMPLETION_HANDLER, + inputs -> new IbmWatsonxChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model), + ChatCompletionInput.class + ); + var errorMessage = constructFailedToSendRequestMessage("watsonx chat completions"); + return new SingleInputSenderExecutableAction(sender, manager, errorMessage, "watsonx chat completions"); + } + + protected String getFailedToSendError() { + return "Failed to send watsonx chat completions request. Cause: failed"; + } + + protected String getOneInputError() { + return "watsonx chat completions only accepts 1 input"; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java index 9376e4da76261..f5f4f8bf2a052 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java @@ -180,7 +180,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); - assertThat(thrownException.getMessage(), is("Failed to send IBM Watsonx embeddings request. Cause: failed")); + assertThat(thrownException.getMessage(), is("Failed to send IBM watsonx embeddings request. Cause: failed")); } public void testExecute_ThrowsException() { @@ -204,7 +204,7 @@ public void testExecute_ThrowsException() { var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); - assertThat(thrownException.getMessage(), is("Failed to send IBM Watsonx embeddings request. Cause: failed")); + assertThat(thrownException.getMessage(), is("Failed to send IBM watsonx embeddings request. Cause: failed")); } private ExecutableAction createAction( @@ -218,7 +218,7 @@ private ExecutableAction createAction( ) { var model = createModel(modelName, projectId, uri, apiVersion, apiKey, url); var requestManager = new IbmWatsonxEmbeddingsRequestManagerWithoutAuth(model, TruncatorTests.createTruncator(), threadPool); - var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("IBM Watsonx embeddings"); + var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("IBM watsonx embeddings"); return new SenderExecutableAction(sender, requestManager, failedToSendRequestErrorMessage); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/completion/IbmWatsonxChatCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/completion/IbmWatsonxChatCompletionModelTests.java new file mode 100644 index 0000000000000..683912072714b --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/completion/IbmWatsonxChatCompletionModelTests.java @@ -0,0 +1,107 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.ibmwatsonx.completion; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.List; + +import static org.hamcrest.Matchers.is; + +public class IbmWatsonxChatCompletionModelTests extends ESTestCase { + private static final URI TEST_URI = URI.create("abc.com"); + + public static IbmWatsonxChatCompletionModel createModel(URI uri, String apiVersion, String modelId, String projectId, String apiKey) + throws URISyntaxException { + return new IbmWatsonxChatCompletionModel( + "id", + TaskType.COMPLETION, + "service", + new IbmWatsonxChatCompletionServiceSettings(uri, apiVersion, modelId, projectId, null), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public void testOverrideWith_UnifiedCompletionRequest_OverridesExistingModelId() throws URISyntaxException { + var model = createModel(TEST_URI, "apiVersion", "modelId", "projectId", "apiKey"); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), + "different_model", + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = IbmWatsonxChatCompletionModel.of(model, request); + + assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model")); + } + + public void testOverrideWith_UnifiedCompletionRequest_OverridesNullModelId() throws URISyntaxException { + var model = createModel(TEST_URI, "apiVersion", null, "projectId", "apiKey"); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), + "different_model", + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = IbmWatsonxChatCompletionModel.of(model, request); + + assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model")); + } + + public void testOverrideWith_UnifiedCompletionRequest_KeepsNullIfNoModelIdProvided() throws URISyntaxException { + var model = createModel(TEST_URI, "apiVersion", null, "projectId", "apiKey"); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), + null, + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = IbmWatsonxChatCompletionModel.of(model, request); + + assertNull(overriddenModel.getServiceSettings().modelId()); + } + + public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenRequestDoesNotOverride() throws URISyntaxException { + var model = createModel(TEST_URI, "apiVersion", "modelId", "projectId", "apiKey"); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), + null, // not overriding model + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = IbmWatsonxChatCompletionModel.of(model, request); + + assertThat(overriddenModel.getServiceSettings().modelId(), is("modelId")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/completion/IbmWatsonxChatCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/completion/IbmWatsonxChatCompletionServiceSettingsTests.java new file mode 100644 index 0000000000000..88984b6e2f427 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/completion/IbmWatsonxChatCompletionServiceSettingsTests.java @@ -0,0 +1,173 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.ibmwatsonx.completion; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; + +import java.io.IOException; +import java.net.URI; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; + +public class IbmWatsonxChatCompletionServiceSettingsTests extends AbstractWireSerializingTestCase { + private static final URI TEST_URI = URI.create("abc.com"); + + private static IbmWatsonxChatCompletionServiceSettings createRandom() { + return new IbmWatsonxChatCompletionServiceSettings( + TEST_URI, + randomAlphaOfLength(8), + randomAlphaOfLength(8), + randomAlphaOfLength(8), + randomFrom(RateLimitSettingsTests.createRandom(), null) + ); + } + + private IbmWatsonxChatCompletionServiceSettings getServiceSettings(Map map) { + return IbmWatsonxChatCompletionServiceSettings.fromMap(new HashMap<>(map), ConfigurationParseContext.PERSISTENT); + } + + public void testFromMap_WithAllParameters_CreatesSettingsCorrectly() { + var model = randomAlphaOfLength(8); + var projectId = randomAlphaOfLength(8); + var apiVersion = randomAlphaOfLength(8); + + var serviceSettings = getServiceSettings( + Map.of( + ServiceFields.URL, + TEST_URI.toString(), + IbmWatsonxServiceFields.API_VERSION, + apiVersion, + ServiceFields.MODEL_ID, + model, + IbmWatsonxServiceFields.PROJECT_ID, + projectId + ) + ); + assertThat(serviceSettings, is(new IbmWatsonxChatCompletionServiceSettings(TEST_URI, apiVersion, model, projectId, null))); + } + + public void testFromMap_Fails_WithoutRequiredParam_Url() { + var ex = expectThrows( + ValidationException.class, + () -> getServiceSettings( + Map.of( + IbmWatsonxServiceFields.API_VERSION, + randomAlphaOfLength(8), + ServiceFields.MODEL_ID, + randomAlphaOfLength(8), + IbmWatsonxServiceFields.PROJECT_ID, + randomAlphaOfLength(8) + ) + ) + ); + assertThat(ex.getMessage(), equalTo(generateErrorMessage("url"))); + } + + public void testFromMap_Fails_WithoutRequiredParam_ApiVersion() { + var ex = expectThrows( + ValidationException.class, + () -> getServiceSettings( + Map.of( + ServiceFields.URL, + TEST_URI.toString(), + ServiceFields.MODEL_ID, + randomAlphaOfLength(8), + IbmWatsonxServiceFields.PROJECT_ID, + randomAlphaOfLength(8) + ) + ) + ); + assertThat(ex.getMessage(), equalTo(generateErrorMessage("api_version"))); + } + + public void testFromMap_Fails_WithoutRequiredParam_ModelId() { + var ex = expectThrows( + ValidationException.class, + () -> getServiceSettings( + Map.of( + ServiceFields.URL, + TEST_URI.toString(), + IbmWatsonxServiceFields.API_VERSION, + randomAlphaOfLength(8), + IbmWatsonxServiceFields.PROJECT_ID, + randomAlphaOfLength(8) + ) + ) + ); + assertThat(ex.getMessage(), equalTo(generateErrorMessage("model_id"))); + } + + public void testFromMap_Fails_WithoutRequiredParam_ProjectId() { + var ex = expectThrows( + ValidationException.class, + () -> getServiceSettings( + Map.of( + ServiceFields.URL, + TEST_URI.toString(), + IbmWatsonxServiceFields.API_VERSION, + randomAlphaOfLength(8), + ServiceFields.MODEL_ID, + randomAlphaOfLength(8) + ) + ) + ); + assertThat(ex.getMessage(), equalTo(generateErrorMessage("project_id"))); + } + + private String generateErrorMessage(String field) { + return "Validation Failed: 1: [service_settings] does not contain the required setting [" + field + "];"; + } + + public void testToXContent_WritesAllValues() throws IOException { + var entity = new IbmWatsonxChatCompletionServiceSettings(TEST_URI, "2024-05-02", "model", "project_id", null); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "url":"abc.com", + "api_version":"2024-05-02", + "model_id":"model", + "project_id":"project_id", + "rate_limit": { + "requests_per_minute":120 + } + }""")); + } + + @Override + protected Writeable.Reader instanceReader() { + return IbmWatsonxChatCompletionServiceSettings::new; + } + + @Override + protected IbmWatsonxChatCompletionServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected IbmWatsonxChatCompletionServiceSettings mutateInstance(IbmWatsonxChatCompletionServiceSettings instance) throws IOException { + return randomValueOtherThan(instance, IbmWatsonxChatCompletionServiceSettingsTests::createRandom); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/request/IbmWatsonxChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/request/IbmWatsonxChatCompletionRequestEntityTests.java new file mode 100644 index 0000000000000..39b607db9cacb --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/request/IbmWatsonxChatCompletionRequestEntityTests.java @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.ibmwatsonx.request; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModel; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.ArrayList; + +import static org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModelTests.createModel; + +public class IbmWatsonxChatCompletionRequestEntityTests extends ESTestCase { + + private static final String ROLE = "user"; + + public void testModelUserFieldsSerialization() throws IOException, URISyntaxException { + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("test content"), + ROLE, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + + var unifiedRequest = UnifiedCompletionRequest.of(messageList); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + IbmWatsonxChatCompletionModel model = createModel(new URI("abc.com"), "apiVersion", "modelId", "projectId", "apiKey"); + + IbmWatsonxChatCompletionRequestEntity entity = new IbmWatsonxChatCompletionRequestEntity(unifiedChatInput, model); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + String expectedJson = """ + { + "project_id": "projectId", + "messages": [ + { + "content": "test content", + "role": "user" + } + ], + "model": "modelId", + "n": 1, + "stream": true + } + """; + assertEquals(XContentHelper.stripWhitespace(expectedJson), Strings.toString(builder)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/request/IbmWatsonxChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/request/IbmWatsonxChatCompletionRequestTests.java new file mode 100644 index 0000000000000..c62221c25dccd --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/request/IbmWatsonxChatCompletionRequestTests.java @@ -0,0 +1,106 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.ibmwatsonx.request; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModel; +import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModelTests; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class IbmWatsonxChatCompletionRequestTests extends ESTestCase { + private static final String AUTH_HEADER_VALUE = "foo"; + private static final String API_COMPLETIONS_PATH = "https://abc.com/ml/v1/text/chat?version=apiVersion"; + + public void testCreateRequest_WithStreaming() throws IOException, URISyntaxException { + assertCreateRequestWithStreaming(true); + } + + public void testCreateRequest_WithNoStreaming() throws IOException, URISyntaxException { + assertCreateRequestWithStreaming(false); + } + + public void testTruncate_DoesNotReduceInputTextSize() throws IOException, URISyntaxException { + String input = randomAlphaOfLength(5); + String model = randomAlphaOfLength(5); + + var request = createRequest(randomAlphaOfLength(5), input, model, true); + var truncatedRequest = request.truncate(); + assertThat(request.getURI().toString(), is(API_COMPLETIONS_PATH)); + + var httpRequest = truncatedRequest.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(5)); + + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", input)))); + assertThat(requestMap.get("model"), is(model)); + assertThat(requestMap.get("n"), is(1)); + assertTrue((Boolean) requestMap.get("stream")); + assertNull(requestMap.get("stream_options")); + } + + public void testTruncationInfo_ReturnsNull() throws URISyntaxException { + var request = createRequest(randomAlphaOfLength(5), randomAlphaOfLength(5), randomAlphaOfLength(5), true); + assertNull(request.getTruncationInfo()); + } + + public static IbmWatsonxChatCompletionRequest createRequest(String apiKey, String input, @Nullable String model) + throws URISyntaxException { + return createRequest(apiKey, input, model, false); + } + + public static IbmWatsonxChatCompletionRequest createRequest(String apiKey, String input, @Nullable String model, boolean stream) + throws URISyntaxException { + var chatCompletionModel = IbmWatsonxChatCompletionModelTests.createModel( + new URI("abc.com"), + "apiVersion", + model, + randomAlphaOfLength(5), + apiKey + ); + return new IbmWatsonxChatCompletionWithoutAuthRequest(new UnifiedChatInput(List.of(input), "user", stream), chatCompletionModel); + } + + private static class IbmWatsonxChatCompletionWithoutAuthRequest extends IbmWatsonxChatCompletionRequest { + IbmWatsonxChatCompletionWithoutAuthRequest(UnifiedChatInput input, IbmWatsonxChatCompletionModel model) { + super(input, model); + } + + @Override + public void decorateWithAuth(HttpPost httpPost) { + httpPost.setHeader(HttpHeaders.AUTHORIZATION, AUTH_HEADER_VALUE); + } + } + + private void assertCreateRequestWithStreaming(boolean isStreaming) throws URISyntaxException, IOException { + var request = createRequest(randomAlphaOfLength(5), randomAlphaOfLength(5), randomAlphaOfLength(5), isStreaming); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap.get("stream"), is(isStreaming)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntityTests.java index bafe0fd02e75d..db42e9c49e1e5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntityTests.java @@ -112,6 +112,6 @@ public void testFromResponse_FailsWhenEmbeddingsFieldIsNotPresent() { ) ); - assertThat(thrownException.getMessage(), is("Failed to find required field [results] in IBM Watsonx embeddings response")); + assertThat(thrownException.getMessage(), is("Failed to find required field [results] in IBM watsonx embeddings response")); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralChatCompletionActionTests.java index e585092b1f780..4f8e99ecfe8fc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralChatCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralChatCompletionActionTests.java @@ -8,42 +8,28 @@ package org.elasticsearch.xpack.inference.services.mistral.action; import org.apache.http.HttpHeaders; -import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.http.MockRequest; import org.elasticsearch.test.http.MockResponse; -import org.elasticsearch.test.http.MockWebServer; -import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; -import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; -import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; -import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.ChatCompletionActionTests; import org.elasticsearch.xpack.inference.services.mistral.request.completion.MistralChatCompletionRequest; -import org.junit.After; -import org.junit.Before; import java.io.IOException; +import java.net.URISyntaxException; import java.util.List; import java.util.Map; -import java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; -import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; @@ -56,64 +42,15 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.mock; - -public class MistralChatCompletionActionTests extends ESTestCase { - private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); - private final MockWebServer webServer = new MockWebServer(); - private ThreadPool threadPool; - private HttpClientManager clientManager; - - @Before - public void init() throws Exception { - webServer.start(); - threadPool = createThreadPool(inferenceUtilityPool()); - clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); - } - - @After - public void shutdown() throws IOException { - clientManager.close(); - terminate(threadPool); - webServer.close(); - } - public void testExecute_ReturnsSuccessfulResponse() throws IOException { +public class MistralChatCompletionActionTests extends ChatCompletionActionTests { + public void testExecute_ReturnsSuccessfulResponse() throws IOException, URISyntaxException { var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty()); try (var sender = createSender(senderFactory)) { sender.start(); - String responseJson = """ - { - "id": "9d80f26810ac4e9582f927fcf0512ec7", - "object": "chat.completion", - "created": 1748596419, - "model": "mistral-small-latest", - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "tool_calls": null, - "content": "result content" - }, - "finish_reason": "length", - "logprobs": null - } - ], - "usage": { - "prompt_tokens": 10, - "total_tokens": 11, - "completion_tokens": 1 - } - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(getResponseJson())); var action = createAction(getUrl(webServer), sender); @@ -140,87 +77,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { } } - public void testExecute_ThrowsElasticsearchException() { - var sender = mock(Sender.class); - doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any()); - - var action = createAction(getUrl(webServer), sender); - - PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); - - var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); - - assertThat(thrownException.getMessage(), is("failed")); - } - - public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() { - var sender = mock(Sender.class); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onFailure(new IllegalStateException("failed")); - - return Void.TYPE; - }).when(sender).send(any(), any(), any(), any()); - - var action = createAction(getUrl(webServer), sender); - - PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); - - var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); - - assertThat(thrownException.getMessage(), is("Failed to send mistral chat completions request. Cause: failed")); - } - - public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var sender = createSender(senderFactory)) { - sender.start(); - - String responseJson = """ - { - "id": "9d80f26810ac4e9582f927fcf0512ec7", - "object": "chat.completion", - "created": 1748596419, - "model": "mistral-small-latest", - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "tool_calls": null, - "content": "result content" - }, - "finish_reason": "length", - "logprobs": null - } - ], - "usage": { - "prompt_tokens": 10, - "total_tokens": 11, - "completion_tokens": 1 - } - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var action = createAction(getUrl(webServer), sender); - - PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new ChatCompletionInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); - - var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); - - assertThat(thrownException.getMessage(), is("mistral chat completions only accepts 1 input")); - assertThat(thrownException.status(), is(RestStatus.BAD_REQUEST)); - } - } - - private ExecutableAction createAction(String url, Sender sender) { + protected ExecutableAction createAction(String url, Sender sender) { var model = createCompletionModel("secret", "model"); model.setURI(url); var manager = new GenericRequestManager<>( @@ -233,4 +90,12 @@ private ExecutableAction createAction(String url, Sender sender) { var errorMessage = constructFailedToSendRequestMessage("mistral chat completions"); return new SingleInputSenderExecutableAction(sender, manager, errorMessage, "mistral chat completions"); } + + protected String getFailedToSendError() { + return "Failed to send mistral chat completions request. Cause: failed"; + } + + protected String getOneInputError() { + return "mistral chat completions only accepts 1 input"; + } }