diff --git a/docs/changelog/134080.yaml b/docs/changelog/134080.yaml new file mode 100644 index 0000000000000..d1c22d59c0a56 --- /dev/null +++ b/docs/changelog/134080.yaml @@ -0,0 +1,5 @@ +pr: 134080 +summary: Added Google Model Garden Anthropic Completion and Chat Completion support to the Inference Plugin +area: Machine Learning +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java index b6f724e69d40f..91543710d695e 100644 --- a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java +++ b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java @@ -58,11 +58,11 @@ public record UnifiedCompletionRequest( private static final String ROLE_FIELD = "role"; private static final String CONTENT_FIELD = "content"; private static final String STOP_FIELD = "stop"; - private static final String TEMPERATURE_FIELD = "temperature"; - private static final String TOOL_CHOICE_FIELD = "tool_choice"; - private static final String TOOL_FIELD = "tools"; + public static final String TEMPERATURE_FIELD = "temperature"; + public static final String TOOL_CHOICE_FIELD = "tool_choice"; + public static final String TOOL_FIELD = "tools"; private static final String TEXT_FIELD = "text"; - private static final String TYPE_FIELD = "type"; + public static final String TYPE_FIELD = "type"; private static final String MODEL_FIELD = "model"; private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens"; private static final String MAX_TOKENS_FIELD = "max_tokens"; diff --git a/server/src/main/resources/transport/definitions/referable/ml_inference_google_model_garden_added.csv b/server/src/main/resources/transport/definitions/referable/ml_inference_google_model_garden_added.csv new file mode 100644 index 0000000000000..8e1566a24a944 --- /dev/null +++ 
b/server/src/main/resources/transport/definitions/referable/ml_inference_google_model_garden_added.csv @@ -0,0 +1 @@ +9179000 diff --git a/server/src/main/resources/transport/upper_bounds/9.2.csv b/server/src/main/resources/transport/upper_bounds/9.2.csv index 34a3bd3790d5b..2f27ba13c86cd 100644 --- a/server/src/main/resources/transport/upper_bounds/9.2.csv +++ b/server/src/main/resources/transport/upper_bounds/9.2.csv @@ -1 +1 @@ -no_matching_project_exception,9178000 +ml_inference_google_model_garden_added,9179000 diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 874625c93a528..b8ff560ced006 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -314,6 +314,18 @@ public static URI extractUri(Map map, String fieldName, Validati return convertToUri(parsedUrl, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException); } + /** + * Extracts an optional URI from the map. If the field is not present, null is returned. 
If the field is present but invalid, a validation error is added to validationException. + * @param map the map to extract the URI from + * @param fieldName the field name to extract + * @param validationException the validation exception to add errors to + * @return the extracted URI or null if not present + */ + public static URI extractOptionalUri(Map map, String fieldName, ValidationException validationException) { + String parsedUrl = extractOptionalString(map, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException); + return convertToUri(parsedUrl, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException); + } + public static URI convertToUri(@Nullable String url, String settingName, String settingScope, ValidationException validationException) { try { return createOptionalUri(url); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicChatCompletionResponseHandler.java new file mode 100644 index 0000000000000..8fd6a755c8772 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicChatCompletionResponseHandler.java @@ -0,0 +1,60 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.anthropic; + +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ChatCompletionErrorResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionErrorParserContract; +import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionErrorResponseUtils; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor; +import org.elasticsearch.xpack.inference.services.anthropic.response.AnthropicChatCompletionResponseEntity; + +import java.util.concurrent.Flow; + +/** + * Handles streaming chat completion responses and error parsing for Anthropic inference endpoints. + * Adapts the AnthropicResponseHandler to support chat completion schema. 
+ */ +public class AnthropicChatCompletionResponseHandler extends AnthropicResponseHandler { + private static final String ANTHROPIC_ERROR = "anthropic_error"; + private static final UnifiedChatCompletionErrorParserContract ANTHROPIC_ERROR_PARSER = UnifiedChatCompletionErrorResponseUtils + .createErrorParserWithStringify(ANTHROPIC_ERROR); + + private final ChatCompletionErrorResponseHandler chatCompletionErrorResponseHandler; + + public AnthropicChatCompletionResponseHandler(String requestType) { + this(requestType, AnthropicChatCompletionResponseEntity::fromResponse); + } + + private AnthropicChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, true); + this.chatCompletionErrorResponseHandler = new ChatCompletionErrorResponseHandler(ANTHROPIC_ERROR_PARSER); + } + + @Override + public InferenceServiceResults parseResult(Request request, Flow.Publisher flow) { + var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser()); + var anthropicProcessor = new AnthropicChatCompletionStreamingProcessor( + (m, e) -> chatCompletionErrorResponseHandler.buildMidStreamChatCompletionError(request.getInferenceEntityId(), m, e) + ); + flow.subscribe(serverSentEventProcessor); + serverSentEventProcessor.subscribe(anthropicProcessor); + return new StreamingUnifiedChatCompletionResults(anthropicProcessor); + } + + @Override + protected UnifiedChatCompletionException buildError(String message, Request request, HttpResult result) { + return chatCompletionErrorResponseHandler.buildChatCompletionError(message, request, result); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicChatCompletionStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicChatCompletionStreamingProcessor.java new file mode 100644 index 0000000000000..f70d654054e6f --- /dev/null +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicChatCompletionStreamingProcessor.java @@ -0,0 +1,299 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.anthropic; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; +import org.elasticsearch.xpack.inference.common.DelegatingProcessor; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.List; +import java.util.function.BiFunction; +import java.util.stream.Stream; + +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.elasticsearch.common.xcontent.XContentParserUtils.parseFieldsValue; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; + +/** + * Chat Completions Streaming Processor for Anthropic provider + */ +public class AnthropicChatCompletionStreamingProcessor extends DelegatingProcessor< + Deque, + StreamingUnifiedChatCompletionResults.Results> { + + private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in Anthropic chat completions response"; + private static final Logger logger = 
LogManager.getLogger(AnthropicChatCompletionStreamingProcessor.class); + + // Field names + public static final String ROLE_FIELD = "role"; + public static final String INDEX_FIELD = "index"; + public static final String TYPE_FIELD = "type"; + public static final String MODEL_FIELD = "model"; + public static final String ID_FIELD = "id"; + public static final String NAME_FIELD = "name"; + public static final String INPUT_TOKENS_FIELD = "input_tokens"; + public static final String OUTPUT_TOKENS_FIELD = "output_tokens"; + public static final String STOP_REASON_FIELD = "stop_reason"; + public static final String TEXT_FIELD = "text"; + public static final String INPUT_FIELD = "input"; + public static final String PARTIAL_JSON_FIELD = "partial_json"; + + // Event types + public static final String MESSAGE_DELTA_EVENT_TYPE = "message_delta"; + public static final String CONTENT_BLOCK_START_EVENT_TYPE = "content_block_start"; + public static final String MESSAGE_START_EVENT_TYPE = "message_start"; + public static final String VERTEX_EVENT_EVENT_TYPE = "vertex_event"; + public static final String PING_EVENT_TYPE = "ping"; + public static final String CONTENT_BLOCK_STOP_EVENT_TYPE = "content_block_stop"; + public static final String CONTENT_BLOCK_DELTA_EVENT_TYPE = "content_block_delta"; + public static final String MESSAGE_STOP_EVENT_TYPE = "message_stop"; + public static final String ERROR_TYPE = "error"; + + // Content block types + public static final String TEXT_DELTA_TYPE = "text_delta"; + public static final String INPUT_JSON_DELTA_TYPE = "input_json_delta"; + public static final String TOOL_USE_TYPE = "tool_use"; + public static final String TEXT_TYPE = "text"; + + private final BiFunction errorParser; + + public AnthropicChatCompletionStreamingProcessor(BiFunction errorParser) { + this.errorParser = errorParser; + } + + @Override + protected void next(Deque item) throws Exception { + var parserConfig = 
XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + var results = new ArrayDeque(item.size()); + + for (var event : item) { + if (ERROR_TYPE.equals(event.type()) && event.hasData()) { + throw errorParser.apply(event.data(), null); + } else if (event.hasData()) { + try { + var delta = parse(parserConfig, event); + delta.forEach(results::offer); + } catch (Exception e) { + logger.warn("Failed to parse event from inference provider: {}", event); + throw errorParser.apply(event.data(), e); + } + } + } + + if (results.isEmpty()) { + upstream().request(1); + } else { + downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(results)); + } + } + + /** + * Parse a single ServerSentEvent into zero or more ChatCompletionChunk + * @param parserConfig the parser configuration + * @param event the server sent event + * @return a stream of ChatCompletionChunk + * @throws IOException if parsing fails + */ + public static Stream parse( + XContentParserConfiguration parserConfig, + ServerSentEvent event + ) throws IOException { + // Handle known event types + switch (event.type()) { + case VERTEX_EVENT_EVENT_TYPE, PING_EVENT_TYPE, CONTENT_BLOCK_STOP_EVENT_TYPE: + // No content to parse, just skip + logger.debug("Skipping event type [{}] for line [{}].", event.type(), event.data()); + return Stream.empty(); + case MESSAGE_START_EVENT_TYPE: + return parseMessageStart(parserConfig, event.data()); + case CONTENT_BLOCK_START_EVENT_TYPE: + return parseContentBlockStart(parserConfig, event.data()); + case CONTENT_BLOCK_DELTA_EVENT_TYPE: + return parseContentBlockDelta(parserConfig, event.data()); + case MESSAGE_DELTA_EVENT_TYPE: + return parseMessageDelta(parserConfig, event.data()); + case MESSAGE_STOP_EVENT_TYPE: + return Stream.empty(); + case null, default: + logger.debug("Unknown event type [{}] for line [{}].", event.type(), event.data()); + return Stream.empty(); + } + } + + /** + * Parse a message start event into a 
ChatCompletionChunk stream + * @param parserConfig the parser configuration + * @param data the event data + * @return a stream of ChatCompletionChunk + * @throws IOException if parsing fails + */ + public static Stream parseMessageStart( + XContentParserConfiguration parserConfig, + String data + ) throws IOException { + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, data)) { + var id = parseStringField(jsonParser, ID_FIELD); + var role = parseStringField(jsonParser, ROLE_FIELD); + var model = parseStringField(jsonParser, MODEL_FIELD); + var finishReason = parseStringOrNullField(jsonParser, STOP_REASON_FIELD); + var promptTokens = parseNumberField(jsonParser, INPUT_TOKENS_FIELD); + var completionTokens = parseNumberField(jsonParser, OUTPUT_TOKENS_FIELD); + var totalTokens = completionTokens + promptTokens; + + var usage = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage(completionTokens, promptTokens, totalTokens); + var delta = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(null, null, role, null); + var choice = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(delta, finishReason, 0); + var chunk = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk(id, List.of(choice), model, null, usage); + + return Stream.of(chunk); + } + } + + /** + * Parse a content block start event into a ChatCompletionChunk stream + * @param parserConfig the parser configuration + * @param data the event data + * @return a stream of ChatCompletionChunk + * @throws IOException if parsing fails + */ + public static Stream parseContentBlockStart( + XContentParserConfiguration parserConfig, + String data + ) throws IOException { + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, data)) { + var index = parseNumberField(jsonParser, INDEX_FIELD); + var type = parseStringField(jsonParser, TYPE_FIELD); + 
StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta delta; + if (type.equals(TEXT_TYPE)) { + var text = parseStringField(jsonParser, TEXT_FIELD); + delta = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(text, null, null, null); + } else if (type.equals(TOOL_USE_TYPE)) { + var id = parseStringField(jsonParser, ID_FIELD); + var name = parseStringField(jsonParser, NAME_FIELD); + var input = parseFieldValue(jsonParser, INPUT_FIELD); + var function = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function( + input != null ? input.toString() : null, + name + ); + var toolCall = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall(0, id, function, null); + delta = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(null, null, null, List.of(toolCall)); + } else { + logger.debug("Unknown content block start type [{}] for line [{}].", type, data); + return Stream.empty(); + } + var choice = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(delta, null, index); + var chunk = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk(null, List.of(choice), null, null, null); + return Stream.of(chunk); + } + } + + /** + * Parse a content block delta event into a ChatCompletionChunk stream + * @param parserConfig the parser configuration + * @param data the event data + * @return a stream of ChatCompletionChunk + * @throws IOException if parsing fails + */ + public static Stream parseContentBlockDelta( + XContentParserConfiguration parserConfig, + String data + ) throws IOException { + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, data)) { + var index = parseNumberField(jsonParser, INDEX_FIELD); + var type = parseStringField(jsonParser, TYPE_FIELD); + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta delta; + if (type.equals(TEXT_DELTA_TYPE)) 
{ + var text = parseStringField(jsonParser, TEXT_FIELD); + delta = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(text, null, null, null); + } else if (type.equals(INPUT_JSON_DELTA_TYPE)) { + var partialJson = parseStringField(jsonParser, PARTIAL_JSON_FIELD); + var function = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function( + partialJson, + null + ); + var toolCall = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall(0, null, function, null); + delta = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(null, null, null, List.of(toolCall)); + } else { + logger.debug("Unknown content block delta type [{}] for line [{}].", type, data); + return Stream.empty(); + } + + var choice = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(delta, null, index); + var chunk = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk(null, List.of(choice), null, null, null); + + return Stream.of(chunk); + } + } + + /** + * Parse a message delta event into a ChatCompletionChunk stream + * @param parserConfig the parser configuration + * @param data the event data + * @return a stream of ChatCompletionChunk + * @throws IOException if parsing fails + */ + public static Stream parseMessageDelta( + XContentParserConfiguration parserConfig, + String data + ) throws IOException { + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, data)) { + var finishReason = parseStringOrNullField(jsonParser, STOP_REASON_FIELD); + var totalTokens = parseNumberField(jsonParser, OUTPUT_TOKENS_FIELD); + + var chunk = buildChatCompletionChunk(totalTokens, finishReason); + + return Stream.of(chunk); + } + } + + private static StreamingUnifiedChatCompletionResults.ChatCompletionChunk buildChatCompletionChunk( + int totalTokens, + String finishReason + ) { + var usage = new 
StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage(totalTokens, 0, totalTokens); + var choice = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice( + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(null, null, null, null), + finishReason, + 0 + ); + return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk(null, List.of(choice), null, null, usage); + } + + private static int parseNumberField(XContentParser jsonParser, String fieldName) throws IOException { + positionParserAtTokenAfterField(jsonParser, fieldName, FAILED_TO_FIND_FIELD_TEMPLATE); + ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, jsonParser.currentToken(), jsonParser); + return jsonParser.intValue(); + } + + private static String parseStringField(XContentParser jsonParser, String fieldName) throws IOException { + positionParserAtTokenAfterField(jsonParser, fieldName, FAILED_TO_FIND_FIELD_TEMPLATE); + ensureExpectedToken(XContentParser.Token.VALUE_STRING, jsonParser.currentToken(), jsonParser); + return jsonParser.text(); + } + + private static String parseStringOrNullField(XContentParser jsonParser, String fieldName) throws IOException { + positionParserAtTokenAfterField(jsonParser, fieldName, FAILED_TO_FIND_FIELD_TEMPLATE); + return jsonParser.textOrNull(); + } + + private static Object parseFieldValue(XContentParser jsonParser, String fieldName) throws IOException { + positionParserAtTokenAfterField(jsonParser, fieldName, FAILED_TO_FIND_FIELD_TEMPLATE); + return parseFieldsValue(jsonParser); + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleModelGardenProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleModelGardenProvider.java new file mode 100644 index 0000000000000..9017faf1459fb --- /dev/null +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleModelGardenProvider.java @@ -0,0 +1,29 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googlevertexai; + +import java.util.Locale; + +/** + * Enum representing the supported model garden providers. + */ +public enum GoogleModelGardenProvider { + GOOGLE, + ANTHROPIC; + + public static final String NAME = "google_model_garden_provider"; + + public static GoogleModelGardenProvider fromString(String name) { + return valueOf(name.trim().toUpperCase(Locale.ROOT)); + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 41678689e8b9d..19ab67b920f04 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.services.googlevertexai; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; @@ -44,11 +45,12 @@ import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.ServiceUtils; +import 
org.elasticsearch.xpack.inference.services.anthropic.AnthropicChatCompletionResponseHandler; import org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator; import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings; -import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUnifiedChatCompletionRequest; +import org.elasticsearch.xpack.inference.services.googlevertexai.request.completion.GoogleVertexAiUnifiedChatCompletionRequest; import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -92,8 +94,12 @@ public class GoogleVertexAiService extends SenderService implements RerankingInf InputType.INTERNAL_SEARCH ); - public static final ResponseHandler COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler( - "Google VertexAI chat completion" + public static final ResponseHandler GOOGLE_VERTEX_AI_CHAT_COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler( + "Google Vertex AI chat completion" + ); + + public static final ResponseHandler GOOGLE_MODEL_GARDEN_ANTHROPIC_CHAT_COMPLETION_HANDLER = new AnthropicChatCompletionResponseHandler( + "Google Model Garden Anthropic chat completion" ); @Override @@ -257,20 +263,45 @@ protected void doUnifiedCompletionInfer( listener.onFailure(createInvalidModelException(model)); return; } - var chatCompletionModel = (GoogleVertexAiChatCompletionModel) model; - var updatedChatCompletionModel = GoogleVertexAiChatCompletionModel.of(chatCompletionModel, inputs.getRequest()); + var updatedChatCompletionModel = 
GoogleVertexAiChatCompletionModel.of( + (GoogleVertexAiChatCompletionModel) model, + inputs.getRequest() + ); + try { + var manager = createRequestManager(updatedChatCompletionModel); + var errorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX); + var action = new SenderExecutableAction(getSender(), manager, errorMessage); + action.execute(inputs, timeout, listener); + } catch (ElasticsearchException e) { + listener.onFailure(e); + } + } - var manager = new GenericRequestManager<>( + private GenericRequestManager createRequestManager(GoogleVertexAiChatCompletionModel model) { + switch (model.getServiceSettings().provider()) { + case ANTHROPIC -> { + return createRequestManagerWithHandler(model, GOOGLE_MODEL_GARDEN_ANTHROPIC_CHAT_COMPLETION_HANDLER); + } + case GOOGLE -> { + return createRequestManagerWithHandler(model, GOOGLE_VERTEX_AI_CHAT_COMPLETION_HANDLER); + } + case null, default -> throw new ElasticsearchException( + "Unsupported Google Model Garden provider: " + model.getServiceSettings().provider() + ); + } + } + + private GenericRequestManager createRequestManagerWithHandler( + GoogleVertexAiChatCompletionModel model, + ResponseHandler responseHandler + ) { + return new GenericRequestManager<>( getServiceComponents().threadPool(), - updatedChatCompletionModel, - COMPLETION_HANDLER, - (unifiedChatInput) -> new GoogleVertexAiUnifiedChatCompletionRequest(unifiedChatInput, updatedChatCompletionModel), + model, + responseHandler, + unifiedChatInput -> new GoogleVertexAiUnifiedChatCompletionRequest(unifiedChatInput, model), UnifiedChatInput.class ); - - var errorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX); - var action = new SenderExecutableAction(getSender(), manager, errorMessage); - action.execute(inputs, timeout, listener); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceFields.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceFields.java index b0dab5edf6cf0..152c225742320 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceFields.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceFields.java @@ -12,6 +12,9 @@ public class GoogleVertexAiServiceFields { public static final String LOCATION = "location"; public static final String PROJECT_ID = "project_id"; + public static final String URL_SETTING_NAME = "url"; + public static final String STREAMING_URL_SETTING_NAME = "streaming_url"; + public static final String PROVIDER_SETTING_NAME = "provider"; /** * According to https://cloud.google.com/vertex-ai/docs/quotas#text-embedding-limits the limit is `250`. diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionCreator.java index b0034f587f363..f06dcdc46a926 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionCreator.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.services.googlevertexai.action; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; @@ -16,13 +17,15 @@ import 
org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.anthropic.AnthropicResponseHandler; +import org.elasticsearch.xpack.inference.services.anthropic.response.AnthropicChatCompletionResponseEntity; import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiEmbeddingsRequestManager; import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiRerankRequestManager; import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiResponseHandler; import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiUnifiedChatCompletionResponseHandler; import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel; -import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUnifiedChatCompletionRequest; +import org.elasticsearch.xpack.inference.services.googlevertexai.request.completion.GoogleVertexAiUnifiedChatCompletionRequest; import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel; import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiCompletionResponseEntity; @@ -38,13 +41,19 @@ public class GoogleVertexAiActionCreator implements GoogleVertexAiActionVisitor private final ServiceComponents serviceComponents; - static final ResponseHandler CHAT_COMPLETION_HANDLER = new GoogleVertexAiResponseHandler( - "Google VertexAI completion", + static final ResponseHandler GOOGLE_VERTEX_AI_COMPLETION_HANDLER = new GoogleVertexAiResponseHandler( + "Google Vertex AI completion", GoogleVertexAiCompletionResponseEntity::fromResponse, 
GoogleVertexAiUnifiedChatCompletionResponseHandler.GoogleVertexAiErrorResponse::fromResponse, true ); + static final ResponseHandler GOOGLE_MODEL_GARDEN_ANTHROPIC_COMPLETION_HANDLER = new AnthropicResponseHandler( + "Google Model Garden Anthropic completion", + AnthropicChatCompletionResponseEntity::fromResponse, + true + ); + static final String USER_ROLE = "user"; public GoogleVertexAiActionCreator(Sender sender, ServiceComponents serviceComponents) { @@ -75,15 +84,35 @@ public ExecutableAction create(GoogleVertexAiRerankModel model, Map taskSettings) { var overriddenModel = GoogleVertexAiChatCompletionModel.of(model, taskSettings); var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX); + GenericRequestManager manager = createRequestManager(overriddenModel); + + return new SingleInputSenderExecutableAction(sender, manager, failedToSendRequestErrorMessage, COMPLETION_ERROR_PREFIX); + } + + private GenericRequestManager createRequestManager(GoogleVertexAiChatCompletionModel model) { + switch (model.getServiceSettings().provider()) { + case ANTHROPIC -> { + return createRequestManagerWithHandler(model, GOOGLE_MODEL_GARDEN_ANTHROPIC_COMPLETION_HANDLER); + } + case GOOGLE -> { + return createRequestManagerWithHandler(model, GOOGLE_VERTEX_AI_COMPLETION_HANDLER); + } + case null, default -> throw new ElasticsearchException( + "Unsupported Google Model Garden provider: " + model.getServiceSettings().provider() + ); + } + } - var manager = new GenericRequestManager<>( + private GenericRequestManager createRequestManagerWithHandler( + GoogleVertexAiChatCompletionModel overriddenModel, + ResponseHandler responseHandler + ) { + return new GenericRequestManager<>( serviceComponents.threadPool(), overriddenModel, - CHAT_COMPLETION_HANDLER, + responseHandler, inputs -> new GoogleVertexAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), overriddenModel), ChatCompletionInput.class ); - - return new 
SingleInputSenderExecutableAction(sender, manager, failedToSendRequestErrorMessage, COMPLETION_ERROR_PREFIX); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java index f34472a14157b..8969102b24e31 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java @@ -65,8 +65,24 @@ public GoogleVertexAiChatCompletionModel( serviceSettings ); try { - this.streamingURI = buildUriStreaming(serviceSettings.location(), serviceSettings.projectId(), serviceSettings.modelId()); - this.nonStreamingUri = buildUriNonStreaming(serviceSettings.location(), serviceSettings.projectId(), serviceSettings.modelId()); + var uri = serviceSettings.uri(); + var streamingUri = serviceSettings.streamingUri(); + // For Google Model Garden uri or streamingUri must be set. If not - location, projectId and modelId must be set + if (uri != null || streamingUri != null) { + // If both uris are provided, each will be used as-is (non-streaming vs. streaming). + // If only one is provided, it will be reused for both non-streaming and streaming requests. + // Some providers require both (e.g. Anthropic, Mistral, Ai21). + // Some providers work fine with a single URL (e.g. Meta, Hugging Face). + this.nonStreamingUri = Objects.requireNonNullElse(uri, streamingUri); + this.streamingURI = Objects.requireNonNullElse(streamingUri, uri); + } else { + // If neither uri nor streamingUri is provided, build them from location, projectId, and modelId. 
+ var location = serviceSettings.location(); + var projectId = serviceSettings.projectId(); + var model = serviceSettings.modelId(); + this.streamingURI = buildUriStreaming(location, projectId, model); + this.nonStreamingUri = buildUriNonStreaming(location, projectId, model); + } } catch (URISyntaxException e) { throw new RuntimeException(e); } @@ -86,7 +102,10 @@ public static GoogleVertexAiChatCompletionModel of(GoogleVertexAiChatCompletionM var newServiceSettings = new GoogleVertexAiChatCompletionServiceSettings( originalModelServiceSettings.projectId(), originalModelServiceSettings.location(), - Objects.requireNonNullElse(request.model(), originalModelServiceSettings.modelId()), + request.model() != null ? request.model() : originalModelServiceSettings.modelId(), + originalModelServiceSettings.uri(), + originalModelServiceSettings.streamingUri(), + originalModelServiceSettings.provider(), originalModelServiceSettings.rateLimitSettings() ); @@ -112,11 +131,12 @@ public static GoogleVertexAiChatCompletionModel of(GoogleVertexAiChatCompletionM return model; } - var requestTaskSettings = GoogleVertexAiChatCompletionTaskSettings.fromMap(taskSettingsMap); - if (requestTaskSettings.isEmpty() || model.getTaskSettings().equals(requestTaskSettings)) { + var newTaskSettings = GoogleVertexAiChatCompletionTaskSettings.fromMap(taskSettingsMap); + if (newTaskSettings.isEmpty() || model.getTaskSettings().equals(newTaskSettings)) { return model; } - var combinedTaskSettings = GoogleVertexAiChatCompletionTaskSettings.of(model.getTaskSettings(), requestTaskSettings); + + var combinedTaskSettings = GoogleVertexAiChatCompletionTaskSettings.of(model.getTaskSettings(), newTaskSettings); return new GoogleVertexAiChatCompletionModel(model, combinedTaskSettings); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionServiceSettings.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionServiceSettings.java index 7e27a4baf87ce..280976ea181c7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionServiceSettings.java @@ -20,19 +20,31 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleModelGardenProvider; import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiRateLimitServiceSettings; import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiService; +import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUtils; import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import java.io.IOException; +import java.net.URI; +import java.util.EnumSet; +import java.util.Locale; import java.util.Map; import java.util.Objects; import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.LOCATION; import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.PROJECT_ID; +import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.PROVIDER_SETTING_NAME; +import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.STREAMING_URL_SETTING_NAME; +import static 
org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.URL_SETTING_NAME; +/** + * Settings for the Google Vertex AI chat completion service. + * This class contains the settings required to configure a Google Vertex AI chat completion service. + */ public class GoogleVertexAiChatCompletionServiceSettings extends FilteredXContentObject implements ServiceSettings, @@ -44,20 +56,69 @@ public class GoogleVertexAiChatCompletionServiceSettings extends FilteredXConten private final String modelId; private final String projectId; + private final URI uri; + private final URI streamingUri; + private final GoogleModelGardenProvider provider; + private final RateLimitSettings rateLimitSettings; // https://cloud.google.com/vertex-ai/docs/quotas#eval-quotas private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(1000); public GoogleVertexAiChatCompletionServiceSettings(StreamInput in) throws IOException { - this(in.readString(), in.readString(), in.readString(), new RateLimitSettings(in)); + var version = in.getTransportVersion(); + String projectIdFromStreamInput; + String locationFromStreamInput; + String modelIdFromStreamInput; + URI uriFromStreamInput = null; + URI streamingUriFromStreamInput = null; + GoogleModelGardenProvider providerFromStreamInput = null; + + if (GoogleVertexAiUtils.supportsModelGarden(version)) { + projectIdFromStreamInput = in.readOptionalString(); + locationFromStreamInput = in.readOptionalString(); + modelIdFromStreamInput = in.readOptionalString(); + uriFromStreamInput = ServiceUtils.createOptionalUri(in.readOptionalString()); + streamingUriFromStreamInput = ServiceUtils.createOptionalUri(in.readOptionalString()); + providerFromStreamInput = in.readOptionalEnum(GoogleModelGardenProvider.class); + } else { + projectIdFromStreamInput = in.readString(); + locationFromStreamInput = in.readString(); + modelIdFromStreamInput = in.readString(); + } + RateLimitSettings 
rateLimitSettingsFromStreamInput = new RateLimitSettings(in); + + this.projectId = Strings.isNullOrEmpty(projectIdFromStreamInput) ? null : projectIdFromStreamInput; + this.location = Strings.isNullOrEmpty(locationFromStreamInput) ? null : locationFromStreamInput; + this.modelId = Strings.isNullOrEmpty(modelIdFromStreamInput) ? null : modelIdFromStreamInput; + this.uri = uriFromStreamInput; + this.streamingUri = streamingUriFromStreamInput; + // Default to GOOGLE if not set + this.provider = Objects.requireNonNullElse(providerFromStreamInput, GoogleModelGardenProvider.GOOGLE); + this.rateLimitSettings = rateLimitSettingsFromStreamInput; + } @Override protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, ToXContent.Params params) throws IOException { - builder.field(PROJECT_ID, projectId); - builder.field(LOCATION, location); - builder.field(MODEL_ID, modelId); + if (Strings.isNullOrEmpty(projectId) == false) { + builder.field(PROJECT_ID, projectId); + } + if (Strings.isNullOrEmpty(location) == false) { + builder.field(LOCATION, location); + } + if (Strings.isNullOrEmpty(modelId) == false) { + builder.field(MODEL_ID, modelId); + } + if (uri != null) { + builder.field(URL_SETTING_NAME, uri.toString()); + } + if (streamingUri != null) { + builder.field(STREAMING_URL_SETTING_NAME, streamingUri.toString()); + } + if (provider != null) { + builder.field(PROVIDER_SETTING_NAME, provider.name()); + } rateLimitSettings.toXContent(builder, params); return builder; } @@ -65,10 +126,22 @@ protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder buil public static GoogleVertexAiChatCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) { ValidationException validationException = new ValidationException(); - // Extract required fields - String projectId = ServiceUtils.extractRequiredString(map, PROJECT_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); - String location = 
ServiceUtils.extractRequiredString(map, LOCATION, ModelConfigurations.SERVICE_SETTINGS, validationException); - String modelId = ServiceUtils.extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + // Extract Google Vertex AI fields + String projectId = ServiceUtils.extractOptionalString(map, PROJECT_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + String location = ServiceUtils.extractOptionalString(map, LOCATION, ModelConfigurations.SERVICE_SETTINGS, validationException); + String modelId = ServiceUtils.extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + + // Extract Google Model Garden fields + URI uri = ServiceUtils.extractOptionalUri(map, URL_SETTING_NAME, validationException); + URI streamingUri = ServiceUtils.extractOptionalUri(map, STREAMING_URL_SETTING_NAME, validationException); + GoogleModelGardenProvider provider = ServiceUtils.extractOptionalEnum( + map, + PROVIDER_SETTING_NAME, + ModelConfigurations.SERVICE_SETTINGS, + GoogleModelGardenProvider::fromString, + EnumSet.allOf(GoogleModelGardenProvider.class), + validationException + ); // Extract rate limit settings RateLimitSettings rateLimitSettings = RateLimitSettings.of( @@ -79,22 +152,78 @@ public static GoogleVertexAiChatCompletionServiceSettings fromMap(Map taskSettings) { @@ -48,11 +63,19 @@ public static GoogleVertexAiChatCompletionTaskSettings fromMap(Map { + return new GoogleModelGardenAnthropicChatCompletionRequestEntity(unifiedChatInput, model.getTaskSettings()); + } + case GOOGLE -> { + return new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getTaskSettings().thinkingConfig()); + } + case null, default -> throw new ElasticsearchException( + "Unsupported Google Model Garden provider: " + model.getServiceSettings().provider() + ); + } + } + public void decorateWithAuth(HttpPost httpPost) { GoogleVertexAiRequest.decorateWithBearerToken(httpPost, 
model.getSecretSettings()); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/completion/GoogleVertexAiUnifiedChatCompletionRequestEntity.java similarity index 99% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/completion/GoogleVertexAiUnifiedChatCompletionRequestEntity.java index 7e625530f197a..79df8a2e8b43c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/completion/GoogleVertexAiUnifiedChatCompletionRequestEntity.java @@ -5,7 +5,7 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.inference.services.googlevertexai.request; +package org.elasticsearch.xpack.inference.services.googlevertexai.request.completion; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.Strings; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/StreamingInferenceTestUtils.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/StreamingInferenceTestUtils.java index ba00810d47f4d..e89c5b1e84419 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/StreamingInferenceTestUtils.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/StreamingInferenceTestUtils.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.external.response.streaming; +import org.apache.commons.lang3.tuple.Pair; import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; import org.hamcrest.Matcher; import org.hamcrest.Matchers; @@ -14,6 +15,7 @@ import java.util.ArrayDeque; import java.util.Arrays; import java.util.Deque; +import java.util.List; public class StreamingInferenceTestUtils { @@ -23,6 +25,12 @@ public static Deque events(String... data) { return item; } + public static Deque events(List> data) { + var item = new ArrayDeque(); + data.forEach(pair -> item.offer(new ServerSentEvent(pair.getKey(), pair.getValue()))); + return item; + } + @SuppressWarnings("unchecked") public static Matcher> containsResults(String... 
results) { Matcher[] resultMatcher = Arrays.stream(results) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java index 1d97ef5f40c59..197f8375360fb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java @@ -384,6 +384,35 @@ public void testCreateUri_ThrowsException_WithNullUrl() { expectThrows(NullPointerException.class, () -> createUri(null)); } + public void testExtractOptionalUri_ReturnsUri_WhenFieldIsValid() { + var validation = new ValidationException(); + Map map = Map.of("url", "www.elastic.co"); + var uri = ServiceUtils.extractOptionalUri(new HashMap<>(map), "url", validation); + + assertNotNull(uri); + assertTrue(validation.validationErrors().isEmpty()); + assertThat(uri.toString(), is("www.elastic.co")); + } + + public void testExtractOptionalUri_ReturnsNull_WhenFieldIsMissing() { + var validation = new ValidationException(); + Map map = Map.of("other", "www.elastic.co"); + var uri = ServiceUtils.extractOptionalUri(new HashMap<>(map), "url", validation); + + assertNull(uri); + assertTrue(validation.validationErrors().isEmpty()); + } + + public void testExtractOptionalUri_ReturnsNullAndAddsValidationError_WhenFieldIsInvalid() { + var validation = new ValidationException(); + Map map = Map.of("url", "^^"); + var uri = ServiceUtils.extractOptionalUri(new HashMap<>(map), "url", validation); + + assertNull(uri); + assertThat(validation.validationErrors().size(), is(1)); + assertThat(validation.validationErrors().get(0), containsString("[service_settings] Invalid url [^^] received for field [url]")); + } + public void testExtractRequiredSecureString_CreatesSecureString() { var validation = new ValidationException(); Map map = 
modifiableMap(Map.of("key", "value")); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicChatCompletionResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicChatCompletionResponseHandlerTests.java new file mode 100644 index 0000000000000..b903d4cd7b084 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicChatCompletionResponseHandlerTests.java @@ -0,0 +1,108 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.anthropic; + +import org.apache.http.HttpResponse; +import org.apache.http.StatusLine; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +import static org.elasticsearch.ExceptionsHelper.unwrapCause; +import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.isA; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class AnthropicChatCompletionResponseHandlerTests extends ESTestCase { + private static final String 
INFERENCE_ID = "anthropic_inference_id"; + + private final AnthropicChatCompletionResponseHandler responseHandler = new AnthropicChatCompletionResponseHandler("chat_completion"); + + public void testFailValidation() throws IOException { + var responseJson = """ + { + "type": "error", + "error": { + "type": "not_found_error", + "message": "The requested resource could not be found." + }, + "request_id": "req_011CSHoEeqs5C35K2UUqR7Fy" + } + """; + + var errorJson = invalidResponseJson(responseJson); + + assertThat(errorJson, is(Strings.format(""" + {"error":{"code":"not_found","message":"Received an unsuccessful status code for request from inference entity id [anthropic_i\ + nference_id] status [404]. Error message: [{\\n \\"type\\": \\"error\\",\\n \\"error\\": {\\n \\"type\\": \\"not_found_er\ + ror\\",\\n \\"message\\": \\"The requested resource could not be found.\\"\\n },\\n \\"request_id\\": \\"req_011CSHoEeqs5\ + C35K2UUqR7Fy\\"\\n}\\n]","type":"anthropic_error"}}\ + """, INFERENCE_ID))); + } + + private static Request mockRequest() { + var request = mock(Request.class); + when(request.getInferenceEntityId()).thenReturn(INFERENCE_ID); + when(request.isStreaming()).thenReturn(true); + return request; + } + + private static HttpResponse mockHttpResponse(int statusCode) { + var statusLine = mock(StatusLine.class); + when(statusLine.getStatusCode()).thenReturn(statusCode); + + var response = mock(HttpResponse.class); + when(response.getStatusLine()).thenReturn(statusLine); + + return response; + } + + private String invalidResponseJson(String responseJson) throws IOException { + var exception = invalidResponse(responseJson); + assertThat(exception, isA(RetryException.class)); + assertThat(unwrapCause(exception), isA(UnifiedChatCompletionException.class)); + return toJson((UnifiedChatCompletionException) unwrapCause(exception)); + } + + private String toJson(UnifiedChatCompletionException e) throws IOException { + try (var builder = XContentFactory.jsonBuilder()) 
{ + e.toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, EMPTY_PARAMS); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }); + return XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType()); + } + } + + private Exception invalidResponse(String responseJson) { + return expectThrows( + RetryException.class, + () -> responseHandler.validateResponse( + mock(), + mock(), + mockRequest(), + new HttpResult(mockHttpResponse(404), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicChatCompletionStreamingProcessorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicChatCompletionStreamingProcessorTests.java new file mode 100644 index 0000000000000..689532fd99046 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicChatCompletionStreamingProcessorTests.java @@ -0,0 +1,226 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.anthropic; + +import org.apache.commons.lang3.tuple.Pair; +import org.elasticsearch.common.xcontent.ChunkedToXContent; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; +import org.hamcrest.Matchers; + +import java.util.ArrayDeque; +import java.util.List; +import java.util.concurrent.Flow; + +import static org.elasticsearch.xpack.inference.common.DelegatingProcessorTests.onError; +import static org.elasticsearch.xpack.inference.common.DelegatingProcessorTests.onNext; +import static org.elasticsearch.xpack.inference.external.response.streaming.StreamingInferenceTestUtils.events; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class AnthropicChatCompletionStreamingProcessorTests extends ESTestCase { + + public void testParseSuccess() { + var item = events( + List.of( + Pair.of("message_start", """ + { + "type": "message_start", + "message": { + "id": "msg_vrtx_01F9nngkx9PojtBCkhj9xP2v", + "type": "message", + "role": "assistant", + "model": "claude-3-5-haiku-20241022", + "content": [], + "stop_reason": null, + "stop_sequence": null, + "usage": { + "input_tokens": 393, + "cache_creation_input_tokens": 0, + "cache_read_input_tokens": 0, + "output_tokens": 1 + } + } + } + """), + Pair.of("content_block_start", """ + {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}"""), + Pair.of("ping", """ + {"type": "ping"}"""), + Pair.of("content_block_delta", """ + 
{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}"""), + Pair.of("content_block_delta", """ + {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"World"}}"""), + Pair.of("content_block_stop", """ + {"type":"content_block_stop","index":0}"""), + Pair.of("content_block_start", """ + { + "type": "content_block_start", + "index": 1, + "content_block": { + "type": "tool_use", + "id": "toolu_vrtx_01GooUb1exnL7s8QrUgAQvQj", + "name": "get_weather", + "input": {} + } + } + """), + Pair.of("content_block_delta", """ + {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"Hello"}}"""), + Pair.of("content_block_delta", """ + {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"World"}}"""), + Pair.of("content_block_stop", """ + {"type":"content_block_stop","index":1}"""), + Pair.of("message_delta", """ + { + "type": "message_delta", + "delta": { + "stop_reason": "tool_use", + "stop_sequence": null + }, + "usage": { + "output_tokens": 99 + } + } + """), + Pair.of("message_stop", """ + {"type":"message_stop"}""") + ) + ); + + var response = onNext(new AnthropicChatCompletionStreamingProcessor((noOp1, noOp2) -> { + fail("This should not be called"); + return null; + }), item); + assertThat(response.chunks().size(), equalTo(8)); + { + assertMessageStartBlock(response); + } + { + assertContent(response, ""); + } + { + assertContent(response, "Hello"); + } + { + assertContent(response, "World"); + } + { + assertToolUseContentStartBlock(response); + } + { + assertToolUseArguments(response, "Hello"); + } + { + assertToolUseArguments(response, "World"); + } + { + assertMessageDeltaBlock(response); + } + } + + private static void assertMessageDeltaBlock(StreamingUnifiedChatCompletionResults.Results response) { + var chatCompletionChunk = response.chunks().remove(); + var choices = chatCompletionChunk.choices(); + assertThat(choices.size(), is(1)); + 
assertThat(choices.getFirst().index(), is(0)); + assertNull(choices.getFirst().delta().toolCalls()); + assertNull(choices.getFirst().delta().content()); + assertThat(choices.getFirst().finishReason(), is("tool_use")); + assertUsage(chatCompletionChunk.usage(), 99, 0, 99); + } + + private static void assertMessageStartBlock(StreamingUnifiedChatCompletionResults.Results response) { + var chatCompletionChunk = response.chunks().remove(); + assertThat(chatCompletionChunk.id(), is("msg_vrtx_01F9nngkx9PojtBCkhj9xP2v")); + assertThat(chatCompletionChunk.model(), is("claude-3-5-haiku-20241022")); + assertUsage(chatCompletionChunk.usage(), 1, 393, 394); + assertThat(chatCompletionChunk.choices().size(), is(1)); + var choice = chatCompletionChunk.choices().getFirst(); + assertThat(choice.index(), is(0)); + assertThat(choice.delta().role(), is("assistant")); + } + + private static void assertToolUseContentStartBlock(StreamingUnifiedChatCompletionResults.Results response) { + var choices = response.chunks().remove().choices(); + assertThat(choices.size(), is(1)); + assertThat(choices.getFirst().index(), is(1)); + var toolCalls = choices.getFirst().delta().toolCalls(); + assertThat(toolCalls.size(), is(1)); + assertThat(toolCalls.getFirst().index(), is(0)); + assertThat(toolCalls.getFirst().id(), is("toolu_vrtx_01GooUb1exnL7s8QrUgAQvQj")); + var function = toolCalls.getFirst().function(); + assertThat(function.arguments(), is("{}")); + assertThat(function.name(), is("get_weather")); + } + + private static void assertToolUseArguments(StreamingUnifiedChatCompletionResults.Results response, String arguments) { + var choices = response.chunks().remove().choices(); + assertThat(choices.size(), is(1)); + assertThat(choices.getFirst().index(), is(1)); + var toolCalls = choices.getFirst().delta().toolCalls(); + assertThat(toolCalls.size(), is(1)); + assertThat(toolCalls.getFirst().index(), is(0)); + assertNull(toolCalls.getFirst().id()); + var function = 
toolCalls.getFirst().function(); + assertThat(function.arguments(), Matchers.is(arguments)); + assertNull(function.name()); + } + + private static void assertContent(StreamingUnifiedChatCompletionResults.Results response, String content) { + var choices = response.chunks().remove().choices(); + assertThat(choices.size(), is(1)); + assertThat(choices.getFirst().index(), is(0)); + assertThat(choices.getFirst().delta().content(), Matchers.is(content)); + } + + private static void assertUsage( + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage usage, + int completion, + int prompt, + int total + ) { + assertThat(usage.completionTokens(), is(completion)); + assertThat(usage.promptTokens(), is(prompt)); + assertThat(usage.totalTokens(), is(total)); + } + + public void testEmptyResultsRequestsMoreData() throws Exception { + var emptyDeque = new ArrayDeque(); + + var processor = new AnthropicChatCompletionStreamingProcessor((noOp1, noOp2) -> { + fail("This should not be called"); + return null; + }); + + Flow.Subscriber downstream = mock(); + processor.subscribe(downstream); + + Flow.Subscription upstream = mock(); + processor.onSubscribe(upstream); + + processor.next(emptyDeque); + + verify(upstream, times(1)).request(1); + verify(downstream, times(0)).onNext(any()); + } + + public void testOnError() { + var expectedException = new RuntimeException("hello"); + + var processor = new AnthropicChatCompletionStreamingProcessor((noOp1, noOp2) -> { throw expectedException; }); + + assertThat(onError(processor, events(List.of(Pair.of("error", "error")))), sameInstance(expectedException)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java index fdcd65c2cbe64..8724afa707914 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java @@ -47,6 +47,8 @@ import org.junit.Before; import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; import java.util.HashMap; import java.util.Map; import java.util.concurrent.TimeUnit; @@ -92,6 +94,9 @@ public void testParseRequestConfig_CreateGoogleVertexAiChatCompletionModel() thr var projectId = "project"; var location = "location"; var modelId = "model"; + var provider = GoogleModelGardenProvider.ANTHROPIC.name(); + var url = "https://non-streaming.url"; + var streamingUrl = "https://streaming.url"; var serviceAccountJson = """ { "some json" @@ -107,6 +112,11 @@ public void testParseRequestConfig_CreateGoogleVertexAiChatCompletionModel() thr assertThat(vertexAIModel.getServiceSettings().modelId(), is(modelId)); assertThat(vertexAIModel.getServiceSettings().location(), is(location)); assertThat(vertexAIModel.getServiceSettings().projectId(), is(projectId)); + + assertThat(vertexAIModel.getServiceSettings().provider(), is(GoogleModelGardenProvider.ANTHROPIC)); + assertThat(vertexAIModel.getServiceSettings().uri(), is(new URI(url))); + assertThat(vertexAIModel.getServiceSettings().streamingUri(), is(new URI(streamingUrl))); + assertThat(vertexAIModel.getSecretSettings().serviceAccountJson().toString(), is(serviceAccountJson)); assertThat(vertexAIModel.getConfigurations().getTaskType(), equalTo(CHAT_COMPLETION)); assertThat(vertexAIModel.getServiceSettings().rateLimitSettings().requestsPerTimeUnit(), equalTo(1000L)); @@ -125,7 +135,13 @@ public void testParseRequestConfig_CreateGoogleVertexAiChatCompletionModel() thr GoogleVertexAiServiceFields.LOCATION, location, GoogleVertexAiServiceFields.PROJECT_ID, - projectId + projectId, + 
GoogleVertexAiServiceFields.PROVIDER_SETTING_NAME, + provider, + GoogleVertexAiServiceFields.URL_SETTING_NAME, + url, + GoogleVertexAiServiceFields.STREAMING_URL_SETTING_NAME, + streamingUrl ) ), getTaskSettingsMapEmpty(), @@ -485,10 +501,13 @@ public void testParsePersistedConfigWithSecrets_CreatesGoogleVertexAiEmbeddingsM } } - public void testParsePersistedConfigWithSecrets_CreatesGoogleVertexAiChatCompletionModel() throws IOException { + public void testParsePersistedConfigWithSecrets_CreatesGoogleVertexAiChatCompletionModel() throws IOException, URISyntaxException { var projectId = "project"; var location = "location"; var modelId = "model"; + var provider = GoogleModelGardenProvider.ANTHROPIC.name(); + var url = "https://non-streaming.url"; + var streamingUrl = "https://streaming.url"; var autoTruncate = true; var serviceAccountJson = """ { @@ -505,7 +524,13 @@ public void testParsePersistedConfigWithSecrets_CreatesGoogleVertexAiChatComplet GoogleVertexAiServiceFields.LOCATION, location, GoogleVertexAiServiceFields.PROJECT_ID, - projectId + projectId, + GoogleVertexAiServiceFields.PROVIDER_SETTING_NAME, + provider, + GoogleVertexAiServiceFields.URL_SETTING_NAME, + url, + GoogleVertexAiServiceFields.STREAMING_URL_SETTING_NAME, + streamingUrl ) ), getTaskSettingsMap(autoTruncate, InputType.INGEST), @@ -529,6 +554,10 @@ public void testParsePersistedConfigWithSecrets_CreatesGoogleVertexAiChatComplet assertThat(chatCompletionModel.getConfigurations().getTaskType(), equalTo(CHAT_COMPLETION)); assertThat(chatCompletionModel.getServiceSettings().rateLimitSettings().requestsPerTimeUnit(), equalTo(1000L)); assertThat(chatCompletionModel.getServiceSettings().rateLimitSettings().timeUnit(), equalTo(MINUTES)); + + assertThat(chatCompletionModel.getServiceSettings().provider(), is(GoogleModelGardenProvider.ANTHROPIC)); + assertThat(chatCompletionModel.getServiceSettings().uri(), is(new URI(url))); + assertThat(chatCompletionModel.getServiceSettings().streamingUri(), 
is(new URI(streamingUrl))); } } @@ -966,8 +995,6 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si } } - // testInfer tested via end-to-end notebook tests in AppEx repo - @SuppressWarnings("checkstyle:LineLength") public void testGetConfiguration() throws Exception { try (var service = createGoogleVertexAiService()) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiUnifiedChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiUnifiedChatCompletionActionTests.java index 0e720f60dfe2d..4d1c590390648 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiUnifiedChatCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiUnifiedChatCompletionActionTests.java @@ -20,26 +20,32 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleModelGardenProvider; +import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiService; import 
org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModelTests; import org.elasticsearch.xpack.inference.services.googlevertexai.completion.ThinkingConfig; -import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUnifiedChatCompletionRequest; +import org.elasticsearch.xpack.inference.services.googlevertexai.request.completion.GoogleVertexAiUnifiedChatCompletionRequest; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import org.junit.After; import org.junit.Before; import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; import java.util.List; import java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; -import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiService.COMPLETION_HANDLER; +import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiService.GOOGLE_VERTEX_AI_CHAT_COMPLETION_HANDLER; +import static org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator.GOOGLE_MODEL_GARDEN_ANTHROPIC_COMPLETION_HANDLER; import static org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator.USER_ROLE; import static org.hamcrest.Matchers.is; import static org.mockito.ArgumentMatchers.any; @@ -68,71 +74,178 @@ public void shutdown() throws IOException { webServer.close(); } - private static UnifiedChatInput createUnifiedChatInput(List messages, String role) { + private static UnifiedChatInput createUnifiedChatInput(List messages) { boolean stream = true; - return new UnifiedChatInput(messages, role, stream); + return new UnifiedChatInput(messages, "user", stream); } // 
Successful case would typically be tested via end-to-end notebook tests in AppEx repo - public void testExecute_ThrowsElasticsearchException() { + public void testExecute_ThrowsElasticsearchExceptionGoogleVertexAi() { + testExecute_ThrowsElasticsearchException( + "us-central1", + "test-project-id", + "chat-bison", + null, + null, + GoogleVertexAiService.GOOGLE_VERTEX_AI_CHAT_COMPLETION_HANDLER + ); + } + + public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalledGoogleVertexAi() { + testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled( + "us-central1", + "test-project-id", + "chat-bison", + null, + null, + GoogleVertexAiService.GOOGLE_VERTEX_AI_CHAT_COMPLETION_HANDLER + ); + } + + public void testExecute_ThrowsExceptionGoogleVertexAi() { + testExecute_ThrowsException("us-central1", "test-project-id", "chat-bison", null, null, GOOGLE_VERTEX_AI_CHAT_COMPLETION_HANDLER); + } + + public void testExecute_ThrowsElasticsearchExceptionAnthropic() throws URISyntaxException { + testExecute_ThrowsElasticsearchException( + null, + null, + null, + GoogleModelGardenProvider.ANTHROPIC, + new URI("http://localhost:9200"), + GOOGLE_MODEL_GARDEN_ANTHROPIC_COMPLETION_HANDLER + ); + } + + public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalledAnthropic() throws URISyntaxException { + testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled( + null, + null, + null, + GoogleModelGardenProvider.ANTHROPIC, + new URI("http://localhost:9200"), + GOOGLE_MODEL_GARDEN_ANTHROPIC_COMPLETION_HANDLER + ); + } + + public void testExecute_ThrowsExceptionAnthropic() throws URISyntaxException { + testExecute_ThrowsException( + null, + null, + null, + GoogleModelGardenProvider.ANTHROPIC, + new URI("http://localhost:9200"), + GoogleVertexAiActionCreator.GOOGLE_MODEL_GARDEN_ANTHROPIC_COMPLETION_HANDLER + ); + } + + private void testExecute_ThrowsException( + String location, + String projectId, + String actualModelId, + 
GoogleModelGardenProvider provider, + URI uri, + ResponseHandler handler + ) { var sender = mock(Sender.class); - doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any()); + doThrow(new IllegalArgumentException("failed")).when(sender).send(any(), any(), any(), any()); - var action = createAction("us-central1", "test-project-id", "chat-bison", sender); + var action = createAction(location, projectId, actualModelId, sender, provider, uri, handler); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(createUnifiedChatInput(List.of("test query"), "user"), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(createUnifiedChatInput(List.of("test query")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); - assertThat(thrownException.getMessage(), is("failed")); + assertThat(thrownException.getMessage(), is("Failed to send Google Vertex AI chat completion request. 
Cause: failed")); } - public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() { + private void testExecute_ThrowsElasticsearchException( + String location, + String projectId, + String actualModelId, + GoogleModelGardenProvider googleModelGardenProvider, + URI uri, + ResponseHandler googleModelGardenAnthropicCompletionHandler + ) { var sender = mock(Sender.class); + doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any()); - doAnswer(invocation -> { - ActionListener listenerArg = invocation.getArgument(3); - listenerArg.onFailure(new IllegalStateException("failed")); - return Void.TYPE; - }).when(sender).send(any(), any(), any(), any()); - - var action = createAction("us-central1", "test-project-id", "chat-bison", sender); + var action = createAction( + location, + projectId, + actualModelId, + sender, + googleModelGardenProvider, + uri, + googleModelGardenAnthropicCompletionHandler + ); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(createUnifiedChatInput(List.of("test query"), "user"), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(createUnifiedChatInput(List.of("test query")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); - assertThat(thrownException.getMessage(), is("Failed to send Google Vertex AI chat completion request. 
Cause: failed")); + assertThat(thrownException.getMessage(), is("failed")); } - public void testExecute_ThrowsException() { + private void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled( + String location, + String projectId, + String actualModelId, + GoogleModelGardenProvider googleModelGardenProvider, + URI uri, + ResponseHandler googleModelGardenAnthropicCompletionHandler + ) { var sender = mock(Sender.class); - doThrow(new IllegalArgumentException("failed")).when(sender).send(any(), any(), any(), any()); - var action = createAction("us-central1", "test-project-id", "chat-bison", sender); + doAnswer(invocation -> { + ActionListener listenerArg = invocation.getArgument(3); + listenerArg.onFailure(new IllegalStateException("failed")); + return Void.TYPE; + }).when(sender).send(any(), any(), any(), any()); + + var action = createAction( + location, + projectId, + actualModelId, + sender, + googleModelGardenProvider, + uri, + googleModelGardenAnthropicCompletionHandler + ); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(createUnifiedChatInput(List.of("test query"), "user"), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(createUnifiedChatInput(List.of("test query")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); assertThat(thrownException.getMessage(), is("Failed to send Google Vertex AI chat completion request. 
Cause: failed")); } - private ExecutableAction createAction(String location, String projectId, String actualModelId, Sender sender) { + private ExecutableAction createAction( + String location, + String projectId, + String actualModelId, + Sender sender, + GoogleModelGardenProvider provider, + URI uri, + ResponseHandler handler + ) { var model = GoogleVertexAiChatCompletionModelTests.createCompletionModel( projectId, location, actualModelId, "api-key", new RateLimitSettings(100), - new ThinkingConfig(256) + new ThinkingConfig(256), + provider, + uri, + 123 ); var manager = new GenericRequestManager<>( threadPool, model, - COMPLETION_HANDLER, + handler, inputs -> new GoogleVertexAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model), ChatCompletionInput.class ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAIChatCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAIChatCompletionServiceSettingsTests.java index 13a60670f1bdf..57584ae617953 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAIChatCompletionServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAIChatCompletionServiceSettingsTests.java @@ -7,13 +7,24 @@ package org.elasticsearch.xpack.inference.services.googlevertexai.completion; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.InferenceSettingsTestCase; +import 
org.elasticsearch.xpack.inference.services.googlevertexai.GoogleModelGardenProvider; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.hamcrest.Matchers; +import java.net.URI; +import java.util.HashMap; import java.util.Map; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createOptionalUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; +import static org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUtils.ML_INFERENCE_GOOGLE_MODEL_GARDEN_ADDED; +import static org.hamcrest.Matchers.is; + public class GoogleVertexAIChatCompletionServiceSettingsTests extends InferenceSettingsTestCase< GoogleVertexAiChatCompletionServiceSettings> { @Override @@ -27,12 +38,180 @@ protected GoogleVertexAiChatCompletionServiceSettings fromMutableMap(Map settingsMap) { + GoogleVertexAiChatCompletionServiceSettings settings = GoogleVertexAiChatCompletionServiceSettings.fromMap( + new HashMap<>(settingsMap), + ConfigurationParseContext.REQUEST + ); + assertThat(settings.projectId(), is("my-project")); + assertThat(settings.location(), is("us-central1")); + assertThat(settings.modelId(), is("my-model")); + assertThat(settings.provider(), is(GoogleModelGardenProvider.GOOGLE)); + assertNull(settings.streamingUri()); + assertNull(settings.uri()); + assertThat(settings.rateLimitSettings(), is(new RateLimitSettings(1000))); + } + + public void testFromMapGoogleVertexAi_UrlPresent_Failure() { + testValidationFailure(Map.of("url", "url", "project_id", "my-project", "location", "us-central1", "model_id", "my-model"), """ + Validation Failed: 1: 'provider' is either GOOGLE or null. For Google Vertex AI models 'uri' and 'streaming_uri' must \ + not be provided. Remove 'url' and 'streaming_url' fields. 
Provided values: uri=url, streaming_uri=null;"""); + } + + public void testFromMapGoogleVertexAi_StreamingUrlPresent_Failure() { + testValidationFailure( + Map.of("streaming_url", "streaming_url", "project_id", "my-project", "location", "us-central1", "model_id", "my-model"), + """ + Validation Failed: 1: 'provider' is either GOOGLE or null. For Google Vertex AI models 'uri' and 'streaming_uri' must \ + not be provided. Remove 'url' and 'streaming_url' fields. Provided values: uri=null, streaming_uri=streaming_url;""" + ); + } + + public void testFromMapGoogleModelGarden_Success() { + GoogleVertexAiChatCompletionServiceSettings settings = GoogleVertexAiChatCompletionServiceSettings.fromMap( + new HashMap<>(Map.of("url", "url", "streaming_url", "streaming_url", "provider", "anthropic")), + ConfigurationParseContext.REQUEST + ); + assertNull(settings.projectId()); + assertNull(settings.location()); + assertNull(settings.modelId()); + assertThat(settings.provider(), is(GoogleModelGardenProvider.ANTHROPIC)); + assertThat(settings.uri().toString(), is("url")); + assertThat(settings.streamingUri().toString(), is("streaming_url")); + assertThat(settings.rateLimitSettings(), is(new RateLimitSettings(1000))); + } + + public void testFromMapGoogleModelGarden_NoProvider_Failure() { + testValidationFailure(Map.of("url", "url", "streaming_url", "streaming_url"), """ + Validation Failed: 1: 'provider' is either GOOGLE or null. For Google Vertex AI models 'uri' and 'streaming_uri' must \ + not be provided. Remove 'url' and 'streaming_url' fields. Provided values: uri=url, streaming_uri=streaming_url;"""); + } + + public void testFromMapGoogleModelGarden_GoogleProvider_Failure() { + testValidationFailure(Map.of("url", "url", "streaming_url", "streaming_url", "provider", "google"), """ + Validation Failed: 1: 'provider' is either GOOGLE or null. For Google Vertex AI models 'uri' and 'streaming_uri' must \ + not be provided. Remove 'url' and 'streaming_url' fields. 
Provided values: uri=url, streaming_uri=streaming_url;"""); + } + + public void testFromMapGoogleModelGarden_NoUrl_Success() { + GoogleVertexAiChatCompletionServiceSettings settings = GoogleVertexAiChatCompletionServiceSettings.fromMap( + new HashMap<>(Map.of("streaming_url", "streaming_url", "provider", "anthropic")), + ConfigurationParseContext.REQUEST + ); + assertNull(settings.projectId()); + assertNull(settings.location()); + assertNull(settings.modelId()); + assertThat(settings.provider(), is(GoogleModelGardenProvider.ANTHROPIC)); + assertNull(settings.uri()); + assertThat(settings.streamingUri().toString(), is("streaming_url")); + assertThat(settings.rateLimitSettings(), is(new RateLimitSettings(1000))); + } + + public void testFromMapGoogleModelGarden_NoStreamingUrl_Success() { + GoogleVertexAiChatCompletionServiceSettings settings = GoogleVertexAiChatCompletionServiceSettings.fromMap( + new HashMap<>(Map.of("url", "url", "provider", "anthropic")), + ConfigurationParseContext.REQUEST + ); + assertNull(settings.projectId()); + assertNull(settings.location()); + assertNull(settings.modelId()); + assertThat(settings.provider(), is(GoogleModelGardenProvider.ANTHROPIC)); + assertNull(settings.streamingUri()); + assertThat(settings.uri().toString(), is("url")); + assertThat(settings.rateLimitSettings(), is(new RateLimitSettings(1000))); + } + + public void testFromMapGoogleModelGarden_NoUrls_Failure() { + testValidationFailure(Map.of("provider", "anthropic"), """ + Validation Failed: 1: Google Model Garden provider=anthropic selected. Either 'uri' or 'streaming_uri' must be provided;"""); + } + + public void testFromMapGoogleVertexAi_NoModel_Failure() { + testValidationFailure(Map.of("project_id", "my-project", "location", "us-central1"), """ + Validation Failed: 1: For Google Vertex AI models, you must provide 'location', 'project_id', and 'model_id'. 
\ + Provided values: location=us-central1, project_id=my-project, model_id=null;"""); + } + + public void testFromMapGoogleVertexAi_NoLocation_Failure() { + testValidationFailure(Map.of("project_id", "my-project", "model_id", "my-model"), """ + Validation Failed: 1: For Google Vertex AI models, you must provide 'location', 'project_id', and 'model_id'. \ + Provided values: location=null, project_id=my-project, model_id=my-model;"""); + } + + public void testFromMapGoogleVertexAi_NoProject_Failure() { + testValidationFailure(Map.of("location", "us-central1", "model_id", "my-model"), """ + Validation Failed: 1: For Google Vertex AI models, you must provide 'location', 'project_id', and 'model_id'. \ + Provided values: location=us-central1, project_id=null, model_id=my-model;"""); + } + + private static void testValidationFailure(Map taskSettingsMap, String expectedErrorMessage) { + var thrownException = expectThrows( + ValidationException.class, + () -> GoogleVertexAiChatCompletionServiceSettings.fromMap(new HashMap<>(taskSettingsMap), ConfigurationParseContext.REQUEST) + ); + assertThat(thrownException.getMessage(), Matchers.is(expectedErrorMessage)); + } + @Override protected GoogleVertexAiChatCompletionServiceSettings createTestInstance() { + return createRandom(); + } + + private static GoogleVertexAiChatCompletionServiceSettings createRandom() { + return randomBoolean() ? 
createRandomWithGoogleVertexAiSettings() : createRandomWithGoogleModelGardenSettings(); + } + + private static GoogleVertexAiChatCompletionServiceSettings createRandomWithGoogleVertexAiSettings() { return new GoogleVertexAiChatCompletionServiceSettings( randomString(), randomString(), randomString(), + null, + null, + randomFrom(GoogleModelGardenProvider.GOOGLE, null), + new RateLimitSettings(randomIntBetween(1, 1000)) + ); + } + + private static GoogleVertexAiChatCompletionServiceSettings createRandomWithGoogleModelGardenSettings() { + URI optionalUri = createOptionalUri(randomOptionalString()); + return new GoogleVertexAiChatCompletionServiceSettings( + randomOptionalString(), + randomOptionalString(), + randomOptionalString(), + optionalUri, + optionalUri == null ? createUri(randomString()) : createOptionalUri(randomOptionalString()), + randomFrom(GoogleModelGardenProvider.ANTHROPIC), new RateLimitSettings(randomIntBetween(1, 1000)) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModelTests.java index cb9fd803047bc..b74ff256659a5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModelTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleModelGardenProvider; import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings; import 
org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -43,7 +44,10 @@ public void testOverrideWith_UnifiedCompletionRequest_OverridesModelId() { DEFAULT_MODEL_ID, DEFAULT_API_KEY, DEFAULT_RATE_LIMIT, - EMPTY_THINKING_CONFIG + EMPTY_THINKING_CONFIG, + null, + null, + null ); var request = new UnifiedCompletionRequest( List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)), @@ -75,7 +79,10 @@ public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenReques DEFAULT_MODEL_ID, DEFAULT_API_KEY, DEFAULT_RATE_LIMIT, - EMPTY_THINKING_CONFIG + EMPTY_THINKING_CONFIG, + null, + null, + 123 ); var request = new UnifiedCompletionRequest( List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)), @@ -120,11 +127,14 @@ public void testOf_overridesTaskSettings_whenPresent() { DEFAULT_MODEL_ID, DEFAULT_API_KEY, DEFAULT_RATE_LIMIT, - new ThinkingConfig(123) + new ThinkingConfig(123), + null, + null, + 123 ); int newThinkingBudget = 456; Map taskSettings = new HashMap<>( - Map.of(THINKING_CONFIG_FIELD, new HashMap<>(Map.of(THINKING_BUDGET_FIELD, newThinkingBudget))) + Map.of(THINKING_CONFIG_FIELD, new HashMap<>(Map.of(THINKING_BUDGET_FIELD, newThinkingBudget)), "max_tokens", 456) ); var overriddenModel = GoogleVertexAiChatCompletionModel.of(model, taskSettings); @@ -135,6 +145,7 @@ public void testOf_overridesTaskSettings_whenPresent() { assertThat(overriddenModel.getSecretSettings().serviceAccountJson(), equalTo(new SecureString(DEFAULT_API_KEY.toCharArray()))); assertThat(overriddenModel.getTaskSettings().thinkingConfig(), is(new ThinkingConfig(newThinkingBudget))); + assertThat(overriddenModel.getTaskSettings().maxTokens(), is(456)); } public void testOf_doesNotOverrideTaskSettings_whenNotPresent() { @@ -145,7 +156,10 @@ public void testOf_doesNotOverrideTaskSettings_whenNotPresent() { DEFAULT_MODEL_ID, DEFAULT_API_KEY, 
DEFAULT_RATE_LIMIT, - originalThinkingConfig + originalThinkingConfig, + null, + null, + 123 ); Map taskSettings = new HashMap<>(Map.of(THINKING_CONFIG_FIELD, new HashMap<>())); var overriddenModel = GoogleVertexAiChatCompletionModel.of(model, taskSettings); @@ -157,6 +171,61 @@ public void testOf_doesNotOverrideTaskSettings_whenNotPresent() { assertThat(overriddenModel.getSecretSettings().serviceAccountJson(), equalTo(new SecureString(DEFAULT_API_KEY.toCharArray()))); assertThat(overriddenModel.getTaskSettings().thinkingConfig(), is(originalThinkingConfig)); + assertThat(overriddenModel.getTaskSettings().maxTokens(), is(123)); + } + + public void testModelCreationForAnthropicBothUrls() throws URISyntaxException { + var uri = new URI("http://example.com"); + var streamingUri = new URI("http://example-streaming.com"); + testModelCreationForAnthropic(uri, streamingUri, uri, streamingUri); + } + + public void testModelCreationForAnthropicOnlyNonStreamingUrl() throws URISyntaxException { + var uri = new URI("http://example.com"); + testModelCreationForAnthropic(uri, null, uri, uri); + } + + public void testModelCreationForAnthropicOnlyStreamingUrl() throws URISyntaxException { + var streamingUri = new URI("http://example-streaming.com"); + testModelCreationForAnthropic(null, streamingUri, streamingUri, streamingUri); + } + + private static void testModelCreationForAnthropic(URI uri, URI streamingUri, URI expectedNonStreamingUri, URI expectedStreamingUri) { + var model = createAnthropicChatCompletionModel( + DEFAULT_API_KEY, + DEFAULT_RATE_LIMIT, + EMPTY_THINKING_CONFIG, + GoogleModelGardenProvider.ANTHROPIC, + uri, + streamingUri, + 123 + ); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)), + null, + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = GoogleVertexAiChatCompletionModel.of(model, request); + + 
assertNull(overriddenModel.getServiceSettings().modelId()); + assertThat(overriddenModel, not(sameInstance(model))); + assertNull(overriddenModel.getServiceSettings().projectId()); + assertNull(overriddenModel.getServiceSettings().location()); + assertThat(overriddenModel.getServiceSettings().rateLimitSettings(), is(DEFAULT_RATE_LIMIT)); + assertThat(overriddenModel.getServiceSettings().uri(), is(uri)); + assertThat(overriddenModel.getServiceSettings().streamingUri(), is(streamingUri)); + assertThat(overriddenModel.getServiceSettings().provider(), is(GoogleModelGardenProvider.ANTHROPIC)); + assertThat(overriddenModel.getSecretSettings().serviceAccountJson(), equalTo(new SecureString(DEFAULT_API_KEY.toCharArray()))); + assertThat(overriddenModel.getTaskSettings().thinkingConfig(), is(EMPTY_THINKING_CONFIG)); + assertThat(overriddenModel.getTaskSettings().maxTokens(), is(123)); + assertThat(overriddenModel.nonStreamingUri(), is(expectedNonStreamingUri)); + assertThat(overriddenModel.streamingURI(), is(expectedStreamingUri)); } public static GoogleVertexAiChatCompletionModel createCompletionModel( @@ -165,14 +234,36 @@ public static GoogleVertexAiChatCompletionModel createCompletionModel( String modelId, String apiKey, RateLimitSettings rateLimitSettings, - ThinkingConfig thinkingConfig + ThinkingConfig thinkingConfig, + GoogleModelGardenProvider provider, + URI uri, + Integer maxTokens + ) { + return new GoogleVertexAiChatCompletionModel( + "google-vertex-ai-chat-test-id", + TaskType.CHAT_COMPLETION, + "google_vertex_ai", + new GoogleVertexAiChatCompletionServiceSettings(projectId, location, modelId, uri, uri, provider, rateLimitSettings), + new GoogleVertexAiChatCompletionTaskSettings(thinkingConfig, maxTokens), + new GoogleVertexAiSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static GoogleVertexAiChatCompletionModel createAnthropicChatCompletionModel( + String apiKey, + RateLimitSettings rateLimitSettings, + ThinkingConfig 
thinkingConfig, + GoogleModelGardenProvider provider, + URI uri, + URI streamingUri, + int maxTokens ) { return new GoogleVertexAiChatCompletionModel( "google-vertex-ai-chat-test-id", TaskType.CHAT_COMPLETION, "google_vertex_ai", - new GoogleVertexAiChatCompletionServiceSettings(projectId, location, modelId, rateLimitSettings), - new GoogleVertexAiChatCompletionTaskSettings(thinkingConfig), + new GoogleVertexAiChatCompletionServiceSettings(null, null, null, uri, streamingUri, provider, rateLimitSettings), + new GoogleVertexAiChatCompletionTaskSettings(thinkingConfig, maxTokens), new GoogleVertexAiSecretSettings(new SecureString(apiKey.toCharArray())) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionTaskSettingsTests.java index cc567b24fe773..1d02e1f41cfcb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionTaskSettingsTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.services.googlevertexai.completion; +import org.elasticsearch.TransportVersion; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.xpack.inference.services.InferenceSettingsTestCase; @@ -16,12 +17,13 @@ import static org.elasticsearch.xpack.inference.services.googlevertexai.completion.ThinkingConfig.THINKING_BUDGET_FIELD; import static org.elasticsearch.xpack.inference.services.googlevertexai.completion.ThinkingConfig.THINKING_CONFIG_FIELD; +import static 
org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUtils.ML_INFERENCE_GOOGLE_MODEL_GARDEN_ADDED; import static org.hamcrest.Matchers.is; public class GoogleVertexAiChatCompletionTaskSettingsTests extends InferenceSettingsTestCase { public void testUpdatedTaskSettings_updatesTaskSettingsWhenDifferent() { - var initialSettings = new GoogleVertexAiChatCompletionTaskSettings(new ThinkingConfig(123)); + var initialSettings = new GoogleVertexAiChatCompletionTaskSettings(new ThinkingConfig(123), 123); int updatedThinkingBudget = 456; Map newSettingsMap = new HashMap<>( Map.of(THINKING_CONFIG_FIELD, new HashMap<>(Map.of(THINKING_BUDGET_FIELD, updatedThinkingBudget))) @@ -33,7 +35,7 @@ public void testUpdatedTaskSettings_updatesTaskSettingsWhenDifferent() { } public void testUpdatedTaskSettings_doesNotUpdateTaskSettingsWhenNewSettingsAreEmpty() { - var initialSettings = new GoogleVertexAiChatCompletionTaskSettings(new ThinkingConfig(123)); + var initialSettings = new GoogleVertexAiChatCompletionTaskSettings(new ThinkingConfig(123), 123); Map emptySettingsMap = new HashMap<>(Map.of(THINKING_CONFIG_FIELD, new HashMap<>())); GoogleVertexAiChatCompletionTaskSettings updatedSettings = (GoogleVertexAiChatCompletionTaskSettings) initialSettings @@ -43,15 +45,17 @@ public void testUpdatedTaskSettings_doesNotUpdateTaskSettingsWhenNewSettingsAreE public void testFromMap_returnsSettings() { int thinkingBudget = 256; + int maxTokens = 256; Map settings = new HashMap<>( - Map.of(THINKING_CONFIG_FIELD, new HashMap<>(Map.of(THINKING_BUDGET_FIELD, thinkingBudget))) + Map.of(THINKING_CONFIG_FIELD, new HashMap<>(Map.of(THINKING_BUDGET_FIELD, thinkingBudget)), "max_tokens", maxTokens) ); var result = GoogleVertexAiChatCompletionTaskSettings.fromMap(settings); assertThat(result.thinkingConfig().getThinkingBudget(), is(thinkingBudget)); + assertThat(result.maxTokens(), is(maxTokens)); } - public void testFromMap_throwsWhenValidationErrorEncountered() { + public void 
testFromMap_throwsWhenValidationErrorEncounteredThinkingConfig() { Map settings = new HashMap<>( Map.of(THINKING_CONFIG_FIELD, new HashMap<>(Map.of(THINKING_BUDGET_FIELD, "not_an_int"))) ); @@ -59,30 +63,46 @@ public void testFromMap_throwsWhenValidationErrorEncountered() { expectThrows(ValidationException.class, () -> GoogleVertexAiChatCompletionTaskSettings.fromMap(settings)); } + public void testFromMap_throwsWhenValidationErrorEncounteredMaxTokens() { + Map settings = new HashMap<>(Map.of("max_tokens", "not_an_int")); + + expectThrows(ValidationException.class, () -> GoogleVertexAiChatCompletionTaskSettings.fromMap(settings)); + } + public void testOf_overridesOriginalSettings_whenNewSettingsPresent() { // Confirm we can overwrite empty settings var originalSettings = new GoogleVertexAiChatCompletionTaskSettings(); int newThinkingBudget = 123; - var newSettings = new GoogleVertexAiChatCompletionTaskSettings(new ThinkingConfig(newThinkingBudget)); + int newMaxTokens = 123; + var newSettings = new GoogleVertexAiChatCompletionTaskSettings(new ThinkingConfig(newThinkingBudget), newMaxTokens); var updatedSettings = GoogleVertexAiChatCompletionTaskSettings.of(originalSettings, newSettings); assertThat(updatedSettings.thinkingConfig().getThinkingBudget(), is(newThinkingBudget)); + assertThat(updatedSettings.maxTokens(), is(newMaxTokens)); + // Confirm we can overwrite existing settings int secondNewThinkingBudget = 456; - var secondNewSettings = new GoogleVertexAiChatCompletionTaskSettings(new ThinkingConfig(secondNewThinkingBudget)); + int secondNewMaxTokens = 456; + var secondNewSettings = new GoogleVertexAiChatCompletionTaskSettings( + new ThinkingConfig(secondNewThinkingBudget), + secondNewMaxTokens + ); var secondUpdatedSettings = GoogleVertexAiChatCompletionTaskSettings.of(updatedSettings, secondNewSettings); assertThat(secondUpdatedSettings.thinkingConfig().getThinkingBudget(), is(secondNewThinkingBudget)); + assertThat(secondUpdatedSettings.maxTokens(), 
is(secondNewMaxTokens)); } - public void testOf_doesNotOverrideOriginalSettings_whenNewSettingsNotPresent() { + public void testOf_doesNotOverrideOriginalThinkingSettings_whenNewSettingsNotPresent() { int originalThinkingBudget = 123; - var originalSettings = new GoogleVertexAiChatCompletionTaskSettings(new ThinkingConfig(originalThinkingBudget)); + int originalMaxTokens = 123; + var originalSettings = new GoogleVertexAiChatCompletionTaskSettings(new ThinkingConfig(originalThinkingBudget), originalMaxTokens); var emptySettings = new GoogleVertexAiChatCompletionTaskSettings(); var updatedSettings = GoogleVertexAiChatCompletionTaskSettings.of(originalSettings, emptySettings); assertThat(updatedSettings.thinkingConfig().getThinkingBudget(), is(originalThinkingBudget)); + assertThat(updatedSettings.maxTokens(), is(originalMaxTokens)); } @Override @@ -90,6 +110,18 @@ protected GoogleVertexAiChatCompletionTaskSettings fromMutableMap(Map instanceReader() { return GoogleVertexAiChatCompletionTaskSettings::new; @@ -97,6 +129,6 @@ protected Writeable.Reader instanceRea @Override protected GoogleVertexAiChatCompletionTaskSettings createTestInstance() { - return new GoogleVertexAiChatCompletionTaskSettings(new ThinkingConfig(randomInt())); + return new GoogleVertexAiChatCompletionTaskSettings(new ThinkingConfig(randomInt()), randomNonNegativeIntOrNull()); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/completion/GoogleModelGardenAnthropicChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/completion/GoogleModelGardenAnthropicChatCompletionRequestEntityTests.java new file mode 100644 index 0000000000000..3efb64885416f --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/completion/GoogleModelGardenAnthropicChatCompletionRequestEntityTests.java @@ -0,0
+1,151 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googlevertexai.request.completion; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionTaskSettings; +import org.elasticsearch.xpack.inference.services.googlevertexai.completion.ThinkingConfig; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +public class GoogleModelGardenAnthropicChatCompletionRequestEntityTests extends ESTestCase { + + public void testModelUserFieldsSerializationStreamingWithTemperatureAndTopK() throws IOException { + XContentBuilder builder = setUpXContentBuilder(0.2F, 0.2F, 100L, true, GoogleVertexAiChatCompletionTaskSettings.EMPTY_SETTINGS); + String expectedJson = """ + { + "anthropic_version": "vertex-2023-10-16", + "messages": [{ + "content": "Hello, world!", + "role": "user" + } + ], + "temperature": 0.2, + "tool_choice": { + "type": "auto" + }, + "tools": [{ + "name": "name", + "description": "description", + "input_schema": { + "parameterName": "parameterValue" + } + } + ], + "top_p": 0.2, + "stream": true, + "max_tokens": 100 + } + """; + assertEquals(XContentHelper.stripWhitespace(expectedJson), Strings.toString(builder)); + } + + public void 
testModelUserFieldsSerializationNonStreamDefaultMaxTokens() throws IOException { + XContentBuilder builder = setUpXContentBuilder(null, null, null, false, GoogleVertexAiChatCompletionTaskSettings.EMPTY_SETTINGS); + String expectedJson = """ + { + "anthropic_version": "vertex-2023-10-16", + "messages": [{ + "content": "Hello, world!", + "role": "user" + } + ], + "tool_choice": { + "type": "auto" + }, + "tools": [{ + "name": "name", + "description": "description", + "input_schema": { + "parameterName": "parameterValue" + } + } + ], + "stream": false, + "max_tokens": 1024 + } + """; + assertEquals(XContentHelper.stripWhitespace(expectedJson), Strings.toString(builder)); + } + + public void testModelUserFieldsSerializationNonStreamWithMaxTokensFromTaskSettings() throws IOException { + XContentBuilder builder = setUpXContentBuilder( + null, + null, + null, + false, + new GoogleVertexAiChatCompletionTaskSettings(new ThinkingConfig(123), 123) + ); + String expectedJson = """ + { + "anthropic_version": "vertex-2023-10-16", + "messages": [{ + "content": "Hello, world!", + "role": "user" + } + ], + "tool_choice": { + "type": "auto" + }, + "tools": [{ + "name": "name", + "description": "description", + "input_schema": { + "parameterName": "parameterValue" + } + } + ], + "stream": false, + "max_tokens": 123 + } + """; + assertEquals(XContentHelper.stripWhitespace(expectedJson), Strings.toString(builder)); + } + + private static XContentBuilder setUpXContentBuilder( + Float topP, + Float temperature, + Long maxCompletionTokens, + boolean stream, + GoogleVertexAiChatCompletionTaskSettings taskSettings + ) throws IOException { + var message = new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("Hello, world!"), "user", null, null); + var messageList = new ArrayList(); + messageList.add(message); + var unifiedRequest = new UnifiedCompletionRequest( + messageList, + null, + maxCompletionTokens, + null, + temperature, + new 
UnifiedCompletionRequest.ToolChoiceObject("auto", new UnifiedCompletionRequest.ToolChoiceObject.FunctionField("name")), + List.of( + new UnifiedCompletionRequest.Tool( + "function", + new UnifiedCompletionRequest.Tool.FunctionField("description", "name", Map.of("parameterName", "parameterValue"), null) + ) + ), + topP + ); + var unifiedChatInput = new UnifiedChatInput(unifiedRequest, stream); + var entity = new GoogleModelGardenAnthropicChatCompletionRequestEntity(unifiedChatInput.getRequest(), stream, taskSettings); + var builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + return builder; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/completion/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java similarity index 99% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/completion/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java index d33fba0c31806..006bf397ce37c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/completion/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java @@ -5,7 +5,7 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.inference.services.googlevertexai.request; +package org.elasticsearch.xpack.inference.services.googlevertexai.request.completion; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.ParsingException; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/completion/GoogleVertexAiUnifiedChatCompletionRequestTests.java similarity index 74% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/completion/GoogleVertexAiUnifiedChatCompletionRequestTests.java index aa4dff03a962a..9b8427e03ba4e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/completion/GoogleVertexAiUnifiedChatCompletionRequestTests.java @@ -5,7 +5,7 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.inference.services.googlevertexai.request; +package org.elasticsearch.xpack.inference.services.googlevertexai.request.completion; import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpPost; @@ -14,9 +14,11 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleModelGardenProvider; import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel; import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModelTests; import org.elasticsearch.xpack.inference.services.googlevertexai.completion.ThinkingConfig; +import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiRequest; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import java.io.IOException; @@ -35,13 +37,44 @@ public class GoogleVertexAiUnifiedChatCompletionRequestTests extends ESTestCase private static final String AUTH_HEADER_VALUE = "Bearer foo"; public void testCreateRequest_Default() throws IOException { + var requestMap = testCreateRequest(null); + assertThat(requestMap, aMapWithSize(1)); + assertThat( + requestMap, + equalTo(Map.of("contents", List.of(Map.of("role", "user", "parts", List.of(Map.of("text", "Hello Gemini!")))))) + ); + + } + + public void testCreateRequest_Anthropic() throws IOException { + var requestMap = testCreateRequest(GoogleModelGardenProvider.ANTHROPIC); + assertThat(requestMap, aMapWithSize(4)); + assertThat( + requestMap, + equalTo( + Map.of( + "stream", + true, + "max_tokens", + 1024, + "messages", + List.of(Map.of("role", "user", "content", "Hello Gemini!")), + "anthropic_version", + "vertex-2023-10-16" + ) + ) + ); + + } + + private static Map 
testCreateRequest(GoogleModelGardenProvider googleModelGardenProvider) throws IOException { var modelId = "gemini-pro"; var projectId = "test-project"; var location = "us-central1"; var messages = List.of("Hello Gemini!"); - var request = createRequest(projectId, location, modelId, messages, null, null, null); + var request = createRequest(projectId, location, modelId, messages, null, null, null, googleModelGardenProvider); var httpRequest = request.createHttpRequest(); var httpPost = (HttpPost) httpRequest.httpRequestBase(); @@ -59,13 +92,7 @@ public void testCreateRequest_Default() throws IOException { assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE)); - var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, aMapWithSize(1)); - assertThat( - requestMap, - equalTo(Map.of("contents", List.of(Map.of("role", "user", "parts", List.of(Map.of("text", messages.getFirst())))))) - ); - + return entityAsMap(httpPost.getEntity().getContent()); } public static GoogleVertexAiUnifiedChatCompletionRequest createRequest( @@ -75,7 +102,8 @@ public static GoogleVertexAiUnifiedChatCompletionRequest createRequest( List messages, @Nullable String apiKey, @Nullable RateLimitSettings rateLimitSettings, - @Nullable ThinkingConfig thinkingConfig + @Nullable ThinkingConfig thinkingConfig, + @Nullable GoogleModelGardenProvider provider ) { var model = GoogleVertexAiChatCompletionModelTests.createCompletionModel( projectId, @@ -83,7 +111,10 @@ public static GoogleVertexAiUnifiedChatCompletionRequest createRequest( modelId, Objects.requireNonNullElse(apiKey, "default-api-key"), Objects.requireNonNullElse(rateLimitSettings, new RateLimitSettings(100)), - thinkingConfig + thinkingConfig, + provider, + null, + null ); var unifiedChatInput = new UnifiedChatInput(messages, "user", true);