diff --git a/docs/changelog/128105.yaml b/docs/changelog/128105.yaml new file mode 100644 index 0000000000000..2dd6b55f54d24 --- /dev/null +++ b/docs/changelog/128105.yaml @@ -0,0 +1,5 @@ +pr: 128105 +summary: "Adding Google VertexAI chat completion integration" +area: Inference +type: enhancement +issues: [ ] diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 163121f7ac498..ca42706b7caea 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -227,6 +227,7 @@ static TransportVersion def(int id) { public static final TransportVersion RERANKER_FAILURES_ALLOWED_8_19 = def(8_841_0_35); public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED_8_19 = def(8_841_0_36); public static final TransportVersion ML_INFERENCE_SAGEMAKER_CHAT_COMPLETION_8_19 = def(8_841_0_37); + public static final TransportVersion ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED_8_19 = def(8_841_0_38); /* * STOP! READ THIS FIRST! 
No, really, diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index ff9bc83f741f1..9ed1b2a642f4c 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -151,14 +151,22 @@ public void testGetServicesWithCompletionTaskType() throws IOException { public void testGetServicesWithChatCompletionTaskType() throws IOException { List services = getServices(TaskType.CHAT_COMPLETION); - assertThat(services.size(), equalTo(6)); + assertThat(services.size(), equalTo(7)); var providers = providers(services); assertThat( providers, containsInAnyOrder( - List.of("deepseek", "elastic", "openai", "streaming_completion_test_service", "hugging_face", "amazon_sagemaker").toArray() + List.of( + "deepseek", + "elastic", + "openai", + "streaming_completion_test_service", + "hugging_face", + "amazon_sagemaker", + "googlevertexai" + ).toArray() ) ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 9d730daea6f46..900b017d81dd2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -73,6 +73,7 @@ import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionServiceSettings; import 
org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings; +import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankServiceSettings; @@ -453,6 +454,15 @@ private static void addGoogleVertexAiNamedWriteables(List namedWriteables) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiModel.java index ea01c309253f4..60cd2faa7155b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiModel.java @@ -7,19 +7,20 @@ package org.elasticsearch.xpack.inference.services.googlevertexai; -import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel; import org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionVisitor; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import 
java.net.URI;
 import java.util.Map;
 import java.util.Objects;
 
-public abstract class GoogleVertexAiModel extends Model {
+public abstract class GoogleVertexAiModel extends RateLimitGroupingModel {
 
     private final GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings;
 
@@ -58,4 +59,18 @@ public GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings() {
     public URI uri() {
         return uri;
     }
+
+    @Override
+    public int rateLimitGroupingHash() {
+        // Vertex AI scopes rate limits to the project, region and model. The request URI already encodes all three, so hashing it
+        // is enough to group models that share a quota. The API key does not affect the quota.
+        // https://ai.google.dev/gemini-api/docs/rate-limits
+        // https://cloud.google.com/vertex-ai/docs/quotas
+        return Objects.hash(uri);
+    }
+
+    @Override
+    public RateLimitSettings rateLimitSettings() {
+        return rateLimitServiceSettings().rateLimitSettings();
+    }
 }
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiResponseHandler.java
index 1349a65ce4fef..9adefd19ef6d5 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiResponseHandler.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiResponseHandler.java
@@ -9,11 +9,14 @@
 
 import org.elasticsearch.xpack.inference.external.http.HttpResult;
 import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
+import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
 import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
 import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
 import org.elasticsearch.xpack.inference.external.request.Request;
 import
org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiErrorResponseEntity;
 
+import java.util.function.Function;
+
 import static org.elasticsearch.core.Strings.format;
 
 public class GoogleVertexAiResponseHandler extends BaseResponseHandler {
@@ -24,6 +27,24 @@ public GoogleVertexAiResponseHandler(String requestType, ResponseParser parseFun
         super(requestType, parseFunction, GoogleVertexAiErrorResponseEntity::fromResponse);
     }
 
+    /**
+     * Creates a handler that parses error bodies with a caller-supplied function instead of
+     * {@link GoogleVertexAiErrorResponseEntity#fromResponse}.
+     *
+     * @param requestType                 passed through to {@link BaseResponseHandler}
+     * @param parseFunction               passed through to {@link BaseResponseHandler}
+     * @param errorParseFunction          converts a failed {@link HttpResult} into an {@link ErrorResponse}
+     * @param canHandleStreamingResponses passed through to {@link BaseResponseHandler}
+     */
+    public GoogleVertexAiResponseHandler(
+        String requestType,
+        ResponseParser parseFunction,
+        Function<HttpResult, ErrorResponse> errorParseFunction,
+        boolean canHandleStreamingResponses
+    ) {
+        super(requestType, parseFunction, errorParseFunction, canHandleStreamingResponses);
+    }
+
     @Override
     protected void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
         int statusCode = result.response().getStatusLine().getStatusCode();
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettings.java
index 9a39e200368cf..1abf1db642932 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettings.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettings.java
@@ -124,9 +124,8 @@ public static Map<String, SettingsConfiguration> get() {
                 var configurationMap = new HashMap<String, SettingsConfiguration>();
                 configurationMap.put(
                     SERVICE_ACCOUNT_JSON,
-                    new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK)).setDescription(
-                        "API Key for the provider you're connecting to."
- ) + new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.CHAT_COMPLETION)) + .setDescription("API Key for the provider you're connecting to.") .setLabel("Credentials JSON") .setRequired(true) .setSensitive(true) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index b849ed2e5cc9c..dc91e01322e6e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -29,7 +29,10 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; @@ -38,8 +41,10 @@ import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator; +import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel; import 
org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUnifiedChatCompletionRequest; import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -47,7 +52,9 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; @@ -55,17 +62,21 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.EMBEDDING_MAX_BATCH_SIZE; import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.LOCATION; import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.PROJECT_ID; +import static org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator.COMPLETION_ERROR_PREFIX; public class GoogleVertexAiService extends SenderService { public static final 
String NAME = "googlevertexai";
 
     private static final String SERVICE_NAME = "Google Vertex AI";
-    private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK);
+    private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
+        TaskType.TEXT_EMBEDDING,
+        TaskType.RERANK,
+        TaskType.CHAT_COMPLETION
+    );
 
     public static final EnumSet<InputType> VALID_INPUT_TYPE_VALUES = EnumSet.of(
         InputType.INGEST,
@@ -76,6 +87,18 @@ public class GoogleVertexAiService extends SenderService {
         InputType.INTERNAL_SEARCH
     );
 
+    private static final ResponseHandler COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler(
+        "Google VertexAI chat completion"
+    );
+
+    /**
+     * @return the task types that can be streamed; only {@link TaskType#CHAT_COMPLETION} for this service
+     */
+    @Override
+    public Set<TaskType> supportedStreamingTasks() {
+        return EnumSet.of(TaskType.CHAT_COMPLETION);
+    }
+
     public GoogleVertexAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
         super(factory, serviceComponents);
     }
@@ -220,7 +243,26 @@ protected void doUnifiedCompletionInfer(
         TimeValue timeout,
         ActionListener<InferenceServiceResults> listener
     ) {
-        throwUnsupportedUnifiedCompletionOperation(NAME);
+        // Unified (chat) completion is only implemented for chat completion models
+        if (model instanceof GoogleVertexAiChatCompletionModel == false) {
+            listener.onFailure(createInvalidModelException(model));
+            return;
+        }
+        var chatCompletionModel = (GoogleVertexAiChatCompletionModel) model;
+        // NOTE(review): of(...) presumably layers the request-level settings over the stored model - confirm
+        var updatedChatCompletionModel = GoogleVertexAiChatCompletionModel.of(chatCompletionModel, inputs.getRequest());
+
+        var manager = new GenericRequestManager<>(
+            getServiceComponents().threadPool(),
+            updatedChatCompletionModel,
+            COMPLETION_HANDLER,
+            (unifiedChatInput) -> new GoogleVertexAiUnifiedChatCompletionRequest(unifiedChatInput, updatedChatCompletionModel),
+            UnifiedChatInput.class
+        );
+
+        var errorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX);
+        var action = new SenderExecutableAction(getSender(), manager, errorMessage);
+        action.execute(inputs, timeout, listener);
     }
 
     @Override
@@ -320,6 +362,17 @@ private static GoogleVertexAiModel createModel(
secretSettings,
                 context
             );
+
+            case CHAT_COMPLETION -> new GoogleVertexAiChatCompletionModel(
+                inferenceEntityId,
+                taskType,
+                NAME,
+                serviceSettings,
+                taskSettings,
+                secretSettings,
+                context
+            );
+
             default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
         };
     }
@@ -348,7 +396,7 @@ public static InferenceServiceConfiguration get() {
 
                 configurationMap.put(
                     LOCATION,
-                    new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING)).setDescription(
+                    new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.CHAT_COMPLETION)).setDescription(
                         "Please provide the GCP region where the Vertex AI API(s) is enabled. "
                             + "For more information, refer to the {geminiVertexAIDocs}."
                     )
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java
new file mode 100644
index 0000000000000..8c355c9f67f18
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java
@@ -0,0 +1,222 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.googlevertexai;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.logging.LogManager;
+import org.elasticsearch.logging.Logger;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xcontent.ConstructingObjectParser;
+import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xcontent.XContentParserConfiguration;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
+import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
+import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
+import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
+import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor;
+
+import java.nio.charset.StandardCharsets;
+import java.util.Locale;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.concurrent.Flow;
+
+import static org.elasticsearch.core.Strings.format;
+
+/**
+ * Response handler for streaming Google Vertex AI chat completion requests.
+ * <p>
+ * Successful responses are consumed as server sent events and converted by {@link GoogleVertexAiUnifiedStreamingProcessor};
+ * failures are surfaced as {@link UnifiedChatCompletionException}s.
+ */
+public class GoogleVertexAiUnifiedChatCompletionResponseHandler extends GoogleVertexAiResponseHandler {
+
+    private static final String ERROR_FIELD = "error";
+    private static final String ERROR_CODE_FIELD = "code";
+    private static final String ERROR_MESSAGE_FIELD = "message";
+    private static final String ERROR_STATUS_FIELD = "status";
+
+    // This handler only supports streaming requests, so the non-streaming parse function is a no-op.
+    private static final ResponseParser noopParseFunction = (a, b) -> null;
+
+    public GoogleVertexAiUnifiedChatCompletionResponseHandler(String requestType) {
+        super(requestType, noopParseFunction, GoogleVertexAiErrorResponse::fromResponse, true);
+    }
+
+    /**
+     * Wires the raw HTTP response flow through a server sent event parser and the Vertex AI chunk processor.
+     *
+     * @param request the originating request; must be a streaming request
+     * @param flow    publisher of the raw HTTP response parts
+     * @return streaming results backed by the Vertex AI processor
+     */
+    @Override
+    public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
+        assert request.isStreaming() : "GoogleVertexAiUnifiedChatCompletionResponseHandler only supports streaming requests";
+
+        var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
+        var googleVertexAiProcessor = new GoogleVertexAiUnifiedStreamingProcessor((m, e) -> buildMidStreamError(request, m, e));
+
+        flow.subscribe(serverSentEventProcessor);
+        serverSentEventProcessor.subscribe(googleVertexAiProcessor);
+        return new StreamingUnifiedChatCompletionResults(googleVertexAiProcessor);
+    }
+
+    /**
+     * Builds a {@link UnifiedChatCompletionException} for a failed HTTP response. When the body was parsed as a Vertex AI
+     * error, its status and code are forwarded to the exception; otherwise the exception carries the simple class name of
+     * the parsed error (or {@code "unknown"}) and the lower-cased REST status name.
+     */
+    @Override
+    protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
+        assert request.isStreaming() : "Only streaming requests support this format";
+        var responseStatusCode = result.response().getStatusLine().getStatusCode();
+        var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode);
+        var restStatus = toRestStatus(responseStatusCode);
+
+        return errorResponse instanceof GoogleVertexAiErrorResponse vertexAIErrorResponse
+            ? new UnifiedChatCompletionException(
+                restStatus,
+                errorMessage,
+                vertexAIErrorResponse.status(),
+                String.valueOf(vertexAIErrorResponse.code()),
+                null
+            )
+            : new UnifiedChatCompletionException(
+                restStatus,
+                errorMessage,
+                errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown",
+                restStatus.name().toLowerCase(Locale.ROOT)
+            );
+    }
+
+    /**
+     * Builds the exception for a failure that happens after streaming has started.
+     *
+     * @param request the originating request, used for the inference entity id in the message
+     * @param message the raw event payload, which is parsed as a Vertex AI error if possible
+     * @param e       the exception that triggered the error handling, may be {@code null}
+     */
+    private static Exception buildMidStreamError(Request request, String message, Exception e) {
+        var errorResponse = GoogleVertexAiErrorResponse.fromString(message);
+        if (errorResponse instanceof GoogleVertexAiErrorResponse gver) {
+            return new UnifiedChatCompletionException(
+                RestStatus.INTERNAL_SERVER_ERROR,
+                format(
+                    "%s for request from inference entity id [%s]. Error message: [%s]",
+                    SERVER_ERROR_OBJECT,
+                    request.getInferenceEntityId(),
+                    errorResponse.getErrorMessage()
+                ),
+                gver.status(),
+                String.valueOf(gver.code()),
+                null
+            );
+        } else if (e != null) {
+            return UnifiedChatCompletionException.fromThrowable(e);
+        } else {
+            return new UnifiedChatCompletionException(
+                RestStatus.INTERNAL_SERVER_ERROR,
+                format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()),
+                errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown",
+                "stream_error"
+            );
+        }
+    }
+
+    /**
+     * Error payload of the form {@code {"error": {"code": ..., "message": ..., "status": ...}}}.
+     */
+    private static class GoogleVertexAiErrorResponse extends ErrorResponse {
+        private static final Logger logger = LogManager.getLogger(GoogleVertexAiErrorResponse.class);
+        private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>(
+            "google_vertex_ai_error_wrapper",
+            true,
+            args -> Optional.ofNullable((GoogleVertexAiErrorResponse) args[0])
+        );
+
+        private static final ConstructingObjectParser<GoogleVertexAiErrorResponse, Void> ERROR_BODY_PARSER = new ConstructingObjectParser<>(
+            "google_vertex_ai_error_body",
+            true,
+            args -> new GoogleVertexAiErrorResponse((Integer) args[0], (String) args[1], (String) args[2])
+        );
+
+        static {
+            ERROR_BODY_PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ERROR_CODE_FIELD));
+            ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(ERROR_MESSAGE_FIELD));
+            ERROR_BODY_PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ERROR_STATUS_FIELD));
+
+            ERROR_PARSER.declareObjectOrNull(
+                ConstructingObjectParser.optionalConstructorArg(),
+                ERROR_BODY_PARSER,
+                null,
+                new ParseField(ERROR_FIELD)
+            );
+        }
+
+        /**
+         * Parses the body of a failed HTTP response. Returns {@link ErrorResponse#UNDEFINED_ERROR} when the body has no
+         * {@code error} object, and a generic {@link ErrorResponse} describing the body when it cannot be parsed.
+         */
+        static ErrorResponse fromResponse(HttpResult response) {
+            try (
+                XContentParser parser = XContentFactory.xContent(XContentType.JSON)
+                    .createParser(XContentParserConfiguration.EMPTY, response.body())
+            ) {
+                return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
+            } catch (Exception e) {
+                var resultAsString = new String(response.body(), StandardCharsets.UTF_8);
+                return new ErrorResponse(Strings.format("Unable to parse the Google Vertex AI error, response body: [%s]", resultAsString));
+            }
+        }
+
+        /**
+         * Same as {@link #fromResponse(HttpResult)}, for an error payload that is already available as a string.
+         */
+        static ErrorResponse fromString(String response) {
+            try (
+                XContentParser parser = XContentFactory.xContent(XContentType.JSON)
+                    .createParser(XContentParserConfiguration.EMPTY, response)
+            ) {
+                return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
+            } catch (Exception e) {
+                return new ErrorResponse(Strings.format("Unable to parse the Google Vertex AI error, response body: [%s]", response));
+            }
+        }
+
+        private final int code;
+        @Nullable
+        private final String status;
+
+        GoogleVertexAiErrorResponse(Integer code, String errorMessage, @Nullable String status) {
+            super(Objects.requireNonNull(errorMessage));
+            // code is optional in the payload; 0 stands in for "not provided"
+            this.code = code == null ? 0 : code;
+            this.status = status;
+        }
+
+        public int code() {
+            return code;
+        }
+
+        /**
+         * Returns the status reported by Vertex AI, or {@code "google_vertex_ai_error"} when none was provided.
+         */
+        public String status() {
+            return status != null ? status : "google_vertex_ai_error";
+        }
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessor.java
new file mode 100644
index 0000000000000..48bcc7845d657
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessor.java
@@ -0,0 +1,350 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.googlevertexai;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.xcontent.ConstructingObjectParser;
+import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xcontent.XContentParserConfiguration;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
+import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
+import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
+
+import java.io.IOException;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Deque;
+import java.util.Iterator;
+import java.util.List;
+import
java.util.Map; +import java.util.Objects; +import java.util.function.BiFunction; + +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; + +public class GoogleVertexAiUnifiedStreamingProcessor extends DelegatingProcessor< + Deque, + StreamingUnifiedChatCompletionResults.Results> { + + private static final Logger logger = LogManager.getLogger(GoogleVertexAiUnifiedStreamingProcessor.class); + + private static final String CANDIDATES_FIELD = "candidates"; + private static final String CONTENT_FIELD = "content"; + private static final String ROLE_FIELD = "role"; + private static final String PARTS_FIELD = "parts"; + private static final String TEXT_FIELD = "text"; + private static final String FINISH_REASON_FIELD = "finishReason"; + private static final String INDEX_FIELD = "index"; + private static final String USAGE_METADATA_FIELD = "usageMetadata"; + private static final String PROMPT_TOKEN_COUNT_FIELD = "promptTokenCount"; + private static final String CANDIDATES_TOKEN_COUNT_FIELD = "candidatesTokenCount"; + private static final String TOTAL_TOKEN_COUNT_FIELD = "totalTokenCount"; + private static final String MODEL_VERSION_FIELD = "modelVersion"; + private static final String RESPONSE_ID_FIELD = "responseId"; + private static final String FUNCTION_CALL_FIELD = "functionCall"; + private static final String FUNCTION_NAME_FIELD = "name"; + private static final String FUNCTION_ARGS_FIELD = "args"; + + private static final String CHAT_COMPLETION_CHUNK = "chat.completion.chunk"; + private static final String FUNCTION_TYPE = "function"; + + private final BiFunction errorParser; + + public GoogleVertexAiUnifiedStreamingProcessor(BiFunction errorParser) { + this.errorParser = errorParser; + } + + @Override + protected void next(Deque events) throws Exception { + + var parserConfig = 
XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + var results = new ArrayDeque(events.size()); + + for (var event : events) { + try { + var completionChunk = parse(parserConfig, event.data()); + completionChunk.forEachRemaining(results::offer); + } catch (Exception e) { + var eventString = event.data(); + logger.warn("Failed to parse event from Google Vertex AI provider: {}", eventString); + throw errorParser.apply(eventString, e); + } + } + + if (results.isEmpty()) { + upstream().request(1); + } else { + downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(results)); + } + } + + private Iterator parse( + XContentParserConfiguration parserConfig, + String event + ) throws IOException { + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event)) { + moveToFirstToken(jsonParser); + ensureExpectedToken(XContentParser.Token.START_OBJECT, jsonParser.currentToken(), jsonParser); + + StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = GoogleVertexAiChatCompletionChunkParser.parse(jsonParser); + return Collections.singleton(chunk).iterator(); + } + } + + public static class GoogleVertexAiChatCompletionChunkParser { + private static @Nullable StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage usageMetadataToChunk( + @Nullable UsageMetadata usage + ) { + if (usage == null) { + return null; + } + return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage( + usage.candidatesTokenCount(), + usage.promptTokenCount(), + usage.totalTokenCount() + ); + } + + private static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice candidateToChoice(Candidate candidate) { + StringBuilder contentTextBuilder = new StringBuilder(); + List toolCalls = new ArrayList<>(); + + String role = null; + + var contentAndPartsAreNotEmpty = candidate.content() != null + && candidate.content().parts() != null + && 
candidate.content().parts().isEmpty() == false; + + if (contentAndPartsAreNotEmpty) { + role = candidate.content().role(); // Role is at the content level + for (Part part : candidate.content().parts()) { + if (part.text() != null) { + contentTextBuilder.append(part.text()); + } + if (part.functionCall() != null) { + FunctionCall fc = part.functionCall(); + var function = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function( + fc.args(), + fc.name() + ); + toolCalls.add( + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall( + 0, // No explicit ID from VertexAI so we use 0 + function.name(), // VertexAI does not provide an id for the function call so we use the name + function, + FUNCTION_TYPE + ) + ); + } + } + } + + List finalToolCalls = toolCalls.isEmpty() + ? null + : toolCalls; + + var delta = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta( + contentTextBuilder.isEmpty() ? null : contentTextBuilder.toString(), + null, + role, + finalToolCalls + ); + + return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(delta, candidate.finishReason(), candidate.index()); + } + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>("google_vertexai_chat_completion_chunk", true, args -> { + List candidates = (List) args[0]; + UsageMetadata usage = (UsageMetadata) args[1]; + String modelversion = (String) args[2]; + String responseId = (String) args[3]; + + boolean candidatesIsEmpty = candidates == null || candidates.isEmpty(); + List choices = candidatesIsEmpty + ? 
Collections.emptyList() + : candidates.stream().map(GoogleVertexAiChatCompletionChunkParser::candidateToChoice).toList(); + + return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk( + responseId, + choices, + modelversion, + CHAT_COMPLETION_CHUNK, + usageMetadataToChunk(usage) + ); + }); + + static { + PARSER.declareObjectArray( + ConstructingObjectParser.constructorArg(), + (p, c) -> CandidateParser.parse(p), + new ParseField(CANDIDATES_FIELD) + ); + PARSER.declareObject( + ConstructingObjectParser.constructorArg(), + (p, c) -> UsageMetadataParser.parse(p), + new ParseField(USAGE_METADATA_FIELD) + ); + PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(MODEL_VERSION_FIELD)); + PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(RESPONSE_ID_FIELD)); + } + + public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk parse(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + } + + // --- Nested Parsers for Google Vertex AI structure --- + + private record Candidate(Content content, String finishReason, int index) {} + + private static class CandidateParser { + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("candidate", true, args -> { + var content = (Content) args[0]; + var finishReason = (String) args[1]; + var index = args[2] == null ? 
0 : (int) args[2]; + return new Candidate(content, finishReason, index); + }); + + static { + PARSER.declareObject( + ConstructingObjectParser.constructorArg(), + (p, c) -> ContentParser.parse(p), + new ParseField(CONTENT_FIELD) + ); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(FINISH_REASON_FIELD)); + PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField(INDEX_FIELD)); + } + + public static Candidate parse(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + } + + private record Content(String role, List parts) {} + + private static class ContentParser { + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + CONTENT_FIELD, + true, + args -> new Content((String) args[0], (List) args[1]) + ); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(ROLE_FIELD)); + PARSER.declareObjectArray( + ConstructingObjectParser.constructorArg(), + (p, c) -> PartParser.parse(p), + new ParseField(PARTS_FIELD) + ); + } + + public static Content parse(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + } + + private record Part(@Nullable String text, @Nullable FunctionCall functionCall) {} // Modified + + private static class PartParser { + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "part", + true, + args -> new Part((String) args[0], (FunctionCall) args[1]) + ); + + static { + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(TEXT_FIELD)); + PARSER.declareObject( + ConstructingObjectParser.optionalConstructorArg(), + (p, c) -> FunctionCallParser.parse(p), + new ParseField(FUNCTION_CALL_FIELD) + ); + } + + public static Part parse(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + } + + private record 
FunctionCall(String name, String args) {} + + private static class FunctionCallParser { + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + FUNCTION_CALL_FIELD, + true, + args -> { + var name = (String) args[0]; + + @SuppressWarnings("unchecked") + var argsMap = (Map) args[1]; + if (argsMap == null) { + return new FunctionCall(name, null); + } + try { + var builder = XContentFactory.jsonBuilder().map(argsMap); + var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON); + return new FunctionCall(name, json); + } catch (IOException e) { + logger.warn("Failed to parse and convert VertexAI function args to json", e); + return new FunctionCall(name, null); + } + } + ); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(FUNCTION_NAME_FIELD)); + PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), (p, c) -> p.map(), new ParseField(FUNCTION_ARGS_FIELD)); + } + + public static FunctionCall parse(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + } + + private record UsageMetadata(int promptTokenCount, int candidatesTokenCount, int totalTokenCount) {} + + private static class UsageMetadataParser { + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + USAGE_METADATA_FIELD, + true, + args -> { + if (Objects.isNull(args[0]) && Objects.isNull(args[1]) && Objects.isNull(args[2])) { + return null; + } + return new UsageMetadata( + args[0] == null ? 0 : (int) args[0], + args[1] == null ? 0 : (int) args[1], + args[2] == null ? 
0 : (int) args[2] + ); + } + ); + + static { + PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField(PROMPT_TOKEN_COUNT_FIELD)); + PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField(CANDIDATES_TOKEN_COUNT_FIELD)); + PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField(TOTAL_TOKEN_COUNT_FIELD)); + } + + public static UsageMetadata parse(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionCreator.java index 627580facee72..2aa42a8ae69c2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionCreator.java @@ -9,11 +9,19 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiEmbeddingsRequestManager; import 
org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiRerankRequestManager; +import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiUnifiedChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUnifiedChatCompletionRequest; import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel; import java.util.Map; @@ -23,10 +31,16 @@ public class GoogleVertexAiActionCreator implements GoogleVertexAiActionVisitor { + public static final String COMPLETION_ERROR_PREFIX = "Google VertexAI chat completion"; private final Sender sender; private final ServiceComponents serviceComponents; + static final ResponseHandler COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler( + "Google VertexAI chat completion" + ); + static final String USER_ROLE = "user"; + public GoogleVertexAiActionCreator(Sender sender, ServiceComponents serviceComponents) { this.sender = Objects.requireNonNull(sender); this.serviceComponents = Objects.requireNonNull(serviceComponents); @@ -50,4 +64,19 @@ public ExecutableAction create(GoogleVertexAiRerankModel model, Map taskSettings) { + + var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX); + var manager = new GenericRequestManager<>( + serviceComponents.threadPool(), + model, + COMPLETION_HANDLER, + inputs -> new GoogleVertexAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model), + ChatCompletionInput.class + ); + + return new SingleInputSenderExecutableAction(sender, manager, failedToSendRequestErrorMessage, COMPLETION_ERROR_PREFIX); + } } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionVisitor.java index 7ae0eaa9d8bfb..eaa71f2646efe 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionVisitor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionVisitor.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.services.googlevertexai.action; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel; import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel; @@ -18,4 +19,6 @@ public interface GoogleVertexAiActionVisitor { ExecutableAction create(GoogleVertexAiEmbeddingsModel model, Map taskSettings); ExecutableAction create(GoogleVertexAiRerankModel model, Map taskSettings); + + ExecutableAction create(GoogleVertexAiChatCompletionModel model, Map taskSettings); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java new file mode 100644 index 0000000000000..301d8f1075502 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java @@ -0,0 +1,134 @@ +/* + * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googlevertexai.completion; + +import org.apache.http.client.utils.URIBuilder; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiModel; +import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings; +import org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionVisitor; +import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUtils; +import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleDiscoveryEngineRateLimitServiceSettings; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.core.Strings.format; + +public class GoogleVertexAiChatCompletionModel extends GoogleVertexAiModel { + public GoogleVertexAiChatCompletionModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettings, + Map taskSettings, + Map secrets, + ConfigurationParseContext context + ) { + this( + inferenceEntityId, + taskType, + service, + GoogleVertexAiChatCompletionServiceSettings.fromMap(serviceSettings, context), + new EmptyTaskSettings(), + GoogleVertexAiSecretSettings.fromMap(secrets) + ); + } + + 
GoogleVertexAiChatCompletionModel( + String inferenceEntityId, + TaskType taskType, + String service, + GoogleVertexAiChatCompletionServiceSettings serviceSettings, + EmptyTaskSettings taskSettings, + @Nullable GoogleVertexAiSecretSettings secrets + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), + new ModelSecrets(secrets), + serviceSettings + ); + try { + this.uri = buildUri(serviceSettings.location(), serviceSettings.projectId(), serviceSettings.modelId()); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + public static GoogleVertexAiChatCompletionModel of(GoogleVertexAiChatCompletionModel model, UnifiedCompletionRequest request) { + var originalModelServiceSettings = model.getServiceSettings(); + + var newServiceSettings = new GoogleVertexAiChatCompletionServiceSettings( + originalModelServiceSettings.projectId(), + originalModelServiceSettings.location(), + Objects.requireNonNullElse(request.model(), originalModelServiceSettings.modelId()), + originalModelServiceSettings.rateLimitSettings() + ); + + return new GoogleVertexAiChatCompletionModel( + model.getInferenceEntityId(), + model.getTaskType(), + model.getConfigurations().getService(), + newServiceSettings, + model.getTaskSettings(), + model.getSecretSettings() + ); + } + + @Override + public ExecutableAction accept(GoogleVertexAiActionVisitor visitor, Map taskSettings) { + return visitor.create(this, taskSettings); + } + + @Override + public GoogleDiscoveryEngineRateLimitServiceSettings rateLimitServiceSettings() { + return (GoogleDiscoveryEngineRateLimitServiceSettings) super.rateLimitServiceSettings(); + } + + @Override + public GoogleVertexAiChatCompletionServiceSettings getServiceSettings() { + return (GoogleVertexAiChatCompletionServiceSettings) super.getServiceSettings(); + } + + @Override + public EmptyTaskSettings getTaskSettings() { + return (EmptyTaskSettings) super.getTaskSettings(); + } + + @Override + 
public GoogleVertexAiSecretSettings getSecretSettings() { + return (GoogleVertexAiSecretSettings) super.getSecretSettings(); + } + + public static URI buildUri(String location, String projectId, String model) throws URISyntaxException { + return new URIBuilder().setScheme("https") + .setHost(format("%s%s", location, GoogleVertexAiUtils.GOOGLE_VERTEX_AI_HOST_SUFFIX)) + .setPathSegments( + GoogleVertexAiUtils.V1, + GoogleVertexAiUtils.PROJECTS, + projectId, + GoogleVertexAiUtils.LOCATIONS, + GoogleVertexAiUtils.GLOBAL, + GoogleVertexAiUtils.PUBLISHERS, + GoogleVertexAiUtils.PUBLISHER_GOOGLE, + GoogleVertexAiUtils.MODELS, + format("%s:%s", model, GoogleVertexAiUtils.STREAM_GENERATE_CONTENT) + ) + .setCustomQuery(GoogleVertexAiUtils.QUERY_PARAM_ALT_SSE) + .build(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionServiceSettings.java new file mode 100644 index 0000000000000..3e6f2c5d45238 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionServiceSettings.java @@ -0,0 +1,160 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.googlevertexai.completion; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiService; +import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleDiscoveryEngineRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.LOCATION; +import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.PROJECT_ID; + +public class GoogleVertexAiChatCompletionServiceSettings extends FilteredXContentObject + implements + ServiceSettings, + GoogleDiscoveryEngineRateLimitServiceSettings { + + public static final String NAME = "google_vertex_ai_chatcompletion_service_settings"; + + private final String location; + private final String modelId; + private final String projectId; + + private final RateLimitSettings rateLimitSettings; + + // https://cloud.google.com/vertex-ai/docs/quotas#eval-quotas + private static 
final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(1000); + + public GoogleVertexAiChatCompletionServiceSettings(StreamInput in) throws IOException { + this(in.readString(), in.readString(), in.readString(), new RateLimitSettings(in)); + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.field(PROJECT_ID, projectId); + builder.field(LOCATION, location); + builder.field(MODEL_ID, modelId); + rateLimitSettings.toXContent(builder, params); + return builder; + } + + public static GoogleVertexAiChatCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + // Extract required fields + String projectId = ServiceUtils.extractRequiredString(map, PROJECT_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + String location = ServiceUtils.extractRequiredString(map, LOCATION, ModelConfigurations.SERVICE_SETTINGS, validationException); + String modelId = ServiceUtils.extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + + // Extract rate limit settings + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + GoogleVertexAiService.NAME, + context + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new GoogleVertexAiChatCompletionServiceSettings(projectId, location, modelId, rateLimitSettings); + } + + public GoogleVertexAiChatCompletionServiceSettings( + String projectId, + String location, + String modelId, + @Nullable RateLimitSettings rateLimitSettings + ) { + this.projectId = projectId; + this.location = location; + this.modelId = modelId; + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + public 
String location() { + return location; + } + + @Override + public String modelId() { + return modelId; + } + + @Override + public String projectId() { + return projectId; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED_8_19; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(projectId); + out.writeString(location); + out.writeString(modelId); + rateLimitSettings.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + toXContentFragmentOfExposedFields(builder, params); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + GoogleVertexAiChatCompletionServiceSettings that = (GoogleVertexAiChatCompletionServiceSettings) o; + return Objects.equals(location, that.location) + && Objects.equals(modelId, that.modelId) + && Objects.equals(projectId, that.projectId) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(location, modelId, projectId, rateLimitSettings); + } + + @Override + public RateLimitSettings rateLimitSettings() { + return rateLimitSettings; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequest.java new file mode 100644 index 0000000000000..7b20e71099e66 --- /dev/null +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequest.java @@ -0,0 +1,79 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googlevertexai.request; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.Objects; + +public class GoogleVertexAiUnifiedChatCompletionRequest implements GoogleVertexAiRequest { + + private final GoogleVertexAiChatCompletionModel model; + private final UnifiedChatInput unifiedChatInput; + + public GoogleVertexAiUnifiedChatCompletionRequest(UnifiedChatInput unifiedChatInput, GoogleVertexAiChatCompletionModel model) { + this.model = Objects.requireNonNull(model); + this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput); + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(model.uri()); + + var requestEntity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); + + ByteArrayEntity byteEntity = new ByteArrayEntity(Strings.toString(requestEntity).getBytes(StandardCharsets.UTF_8)); + httpPost.setEntity(byteEntity); + + 
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); + + decorateWithAuth(httpPost); + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + public void decorateWithAuth(HttpPost httpPost) { + GoogleVertexAiRequest.decorateWithBearerToken(httpPost, model.getSecretSettings()); + } + + @Override + public URI getURI() { + return model.uri(); + } + + @Override + public Request truncate() { + // No truncation for Google VertexAI Chat completions + return this; + } + + @Override + public boolean[] getTruncationInfo() { + // No truncation for Google VertexAI Chat completions + return null; + } + + @Override + public boolean isStreaming() { + return unifiedChatInput.stream(); + } + + @Override + public String getInferenceEntityId() { + return model.getInferenceEntityId(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java new file mode 100644 index 0000000000000..7b8f75b2853bb --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java @@ -0,0 +1,356 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.googlevertexai.request; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; + +import java.io.IOException; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.elasticsearch.core.Strings.format; + +public class GoogleVertexAiUnifiedChatCompletionRequestEntity implements ToXContentObject { + private static final String CONTENTS = "contents"; + private static final String ROLE = "role"; + private static final String PARTS = "parts"; + private static final String TEXT = "text"; + private static final String GENERATION_CONFIG = "generationConfig"; + private static final String TEMPERATURE = "temperature"; + private static final String MAX_OUTPUT_TOKENS = "maxOutputTokens"; + private static final String TOP_P = "topP"; + + private static final String TOOLS = "tools"; + private static final String FUNCTION_DECLARATIONS = "functionDeclarations"; + private static final String FUNCTION_NAME = "name"; + private static final String FUNCTION_DESCRIPTION = "description"; + private static final String FUNCTION_PARAMETERS = "parameters"; + private static final String FUNCTION_TYPE = "function"; + private static final String TOOL_CONFIG = "toolConfig"; + private static final 
String FUNCTION_CALLING_CONFIG = "functionCallingConfig";
    private static final String TOOL_MODE = "mode";
    private static final String TOOL_MODE_ANY = "ANY";
    private static final String TOOL_MODE_AUTO = "auto";
    private static final String ALLOWED_FUNCTION_NAMES = "allowedFunctionNames";

    private static final String FUNCTION_CALL = "functionCall";
    private static final String FUNCTION_CALL_NAME = "name";
    private static final String FUNCTION_CALL_ARGS = "args";

    private final UnifiedChatInput unifiedChatInput;

    private static final String USER_ROLE = "user";
    private static final String MODEL_ROLE = "model";
    private static final String ASSISTANT_ROLE = "assistant";
    private static final String SYSTEM_ROLE = "system";
    private static final String TOOL_ROLE = "tool";
    private static final String STOP_SEQUENCES = "stopSequences";

    private static final String SYSTEM_INSTRUCTION = "systemInstruction";

    /**
     * @throws NullPointerException if {@code unifiedChatInput} is {@code null}
     */
    public GoogleVertexAiUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput) {
        this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
    }

    /**
     * Maps a unified chat message role onto a role name used in the Vertex AI request.
     * Matching is case-insensitive for every role: "user" stays "user", while "assistant"
     * and "tool" both map to "model".
     *
     * @param messageRole the role from the incoming message
     * @return the role to write into the request
     * @throws ElasticsearchStatusException with status 400 for any other role
     */
    private String messageRoleToGoogleVertexAiSupportedRole(String messageRole) {
        // Compare the lowered form for ALL roles; previously only "user" was matched
        // case-insensitively, so "Assistant" or "TOOL" were rejected.
        var messageRoleLowered = messageRole.toLowerCase(Locale.ROOT);

        switch (messageRoleLowered) {
            case USER_ROLE:
                return USER_ROLE;
            case ASSISTANT_ROLE:
                // Gemini VertexAI API does not use "assistant". Instead, it uses "model"
                return MODEL_ROLE;
            case TOOL_ROLE:
                // Gemini VertexAI does not have the tool role, so we map it to "model"
                return MODEL_ROLE;
            default:
                var errorMessage = format(
                    "Role [%s] not supported by Google VertexAI ChatCompletion. Supported roles: [%s, %s]",
                    messageRole,
                    USER_ROLE,
                    ASSISTANT_ROLE
                );
                throw new ElasticsearchStatusException(errorMessage, RestStatus.BAD_REQUEST);
        }
    }

    /**
     * Writes one {@code {"text": ...}} object per content object into the currently open array.
     * Only content objects of type "text" are accepted. Parts with empty text are skipped
     * individually; later parts are still written.
     *
     * @throws ElasticsearchStatusException with status 400 if a content object has a non-text type
     */
    private void validateAndAddContentObjectsToBuilder(XContentBuilder builder, UnifiedCompletionRequest.ContentObjects contentObjects)
        throws IOException {

        for (var contentObject : contentObjects.contentObjects()) {
            if (contentObject.type().equals(TEXT) == false) {
                var errorMessage = format(
                    "Type [%s] not supported by Google VertexAI ChatCompletion. Supported types: [text]",
                    contentObject.type()
                );
                throw new ElasticsearchStatusException(errorMessage, RestStatus.BAD_REQUEST);
            }

            if (contentObject.text().isEmpty()) {
                // VertexAI API does not support empty text parts. Skip only this part:
                // returning here would silently drop every part that follows it.
                continue;
            }

            // We are only supporting Text messages for now
            builder.startObject();
            builder.field(TEXT, contentObject.text());
            builder.endObject();
        }

    }

    /**
     * Parses a JSON object string into a map.
     *
     * @return {@code null} when {@code jsonString} is {@code null} or empty, otherwise the parsed map
     * @throws IOException if the input cannot be parsed or does not start with a JSON object
     */
    private static Map jsonStringToMap(String jsonString) throws IOException {
        if (jsonString == null || jsonString.isEmpty()) {
            return null;
        }
        XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(
            LoggingDeprecationHandler.INSTANCE
        );

        try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, jsonString)) {
            XContentParser.Token token = parser.nextToken();
            ensureExpectedToken(XContentParser.Token.START_OBJECT, token, parser);
            // NOTE(review): mapStrings() presumably yields string values only; confirm how numeric or
            // nested tool-call arguments should be handled here.
            return parser.mapStrings();
        }
    }

    private void buildSystemInstruction(XContentBuilder builder) throws IOException {
        var messages = unifiedChatInput.getRequest().messages();
        var systemMessages = messages.stream().filter(message -> message.role().equalsIgnoreCase(SYSTEM_ROLE)).toList();

        if (systemMessages.isEmpty()) {
            return;
        }

        builder.startObject(SYSTEM_INSTRUCTION);
        {
            builder.startArray(PARTS);
            for (var systemMessage : systemMessages) {
                if (systemMessage.content() instanceof
UnifiedCompletionRequest.ContentString contentString) { + if (contentString.content().isEmpty()) { + var errorMessage = "System message cannot be empty for Google Vertex AI"; + throw new ElasticsearchStatusException(errorMessage, RestStatus.BAD_REQUEST); + } + builder.startObject(); + builder.field(TEXT, contentString.content()); + builder.endObject(); + } else if (systemMessage.content() instanceof UnifiedCompletionRequest.ContentObjects contentObjects) { + for (var contentObject : contentObjects.contentObjects()) { + builder.startObject(); + builder.field(TEXT, contentObject.text()); + builder.endObject(); + } + } else { + var errorMessage = "Only text system instructions are supported for Vertex AI"; + throw new ElasticsearchStatusException(errorMessage, RestStatus.BAD_REQUEST); + } + } + builder.endArray(); + } + builder.endObject(); + + } + + private void buildContents(XContentBuilder builder) throws IOException { + var messages = unifiedChatInput.getRequest().messages(); + + builder.startArray(CONTENTS); + for (UnifiedCompletionRequest.Message message : messages) { + if (message.role().equalsIgnoreCase(SYSTEM_ROLE)) { + // System messages are built in another method + continue; + } + + builder.startObject(); + builder.field(ROLE, messageRoleToGoogleVertexAiSupportedRole(message.role())); + builder.startArray(PARTS); + { + if (message.content() instanceof UnifiedCompletionRequest.ContentString) { + UnifiedCompletionRequest.ContentString contentString = (UnifiedCompletionRequest.ContentString) message.content(); + // VertexAI does not support empty text parts + if (contentString.content().isEmpty() == false) { + builder.startObject(); + builder.field(TEXT, contentString.content()); + builder.endObject(); + } + } else if (message.content() instanceof UnifiedCompletionRequest.ContentObjects) { + UnifiedCompletionRequest.ContentObjects contentObjects = (UnifiedCompletionRequest.ContentObjects) message.content(); + validateAndAddContentObjectsToBuilder(builder, 
contentObjects); + } + + if (message.toolCalls() != null && message.toolCalls().isEmpty() == false) { + var toolCalls = message.toolCalls(); + for (var toolCall : toolCalls) { + builder.startObject(); + { + builder.startObject(FUNCTION_CALL); + builder.field(FUNCTION_CALL_NAME, toolCall.function().name()); + builder.field(FUNCTION_CALL_ARGS, jsonStringToMap(toolCall.function().arguments())); + builder.endObject(); + } + builder.endObject(); + } + } + } + builder.endArray(); + builder.endObject(); + } + builder.endArray(); + } + + private void buildTools(XContentBuilder builder) throws IOException { + var request = unifiedChatInput.getRequest(); + + var tools = request.tools(); + if (tools == null || tools.isEmpty()) { + return; + } + + builder.startArray(TOOLS); + { + builder.startObject(); + builder.startArray(FUNCTION_DECLARATIONS); + for (var tool : tools) { + if (FUNCTION_TYPE.equals(tool.type()) == false) { + var errorMessage = format( + "Tool type [%s] not supported by Google VertexAI ChatCompletion. 
Supported types: [%s]", + tool.type(), + FUNCTION_TYPE + ); + throw new ElasticsearchStatusException(errorMessage, RestStatus.BAD_REQUEST); + } + var function = tool.function(); + if (function == null) { + var errorMessage = format("Tool of type [%s] must have a function definition", tool.type()); + throw new ElasticsearchStatusException(errorMessage, RestStatus.BAD_REQUEST); + } + + builder.startObject(); + builder.field(FUNCTION_NAME, function.name()); + if (Strings.hasText(function.description())) { + builder.field(FUNCTION_DESCRIPTION, function.description()); + } + + if (function.parameters() != null && function.parameters().isEmpty() == false) { + builder.field(FUNCTION_PARAMETERS, function.parameters()); + } + builder.endObject(); + } + builder.endArray(); + builder.endObject(); + } + builder.endArray(); + } + + private void buildToolConfig(XContentBuilder builder) throws IOException { + var request = unifiedChatInput.getRequest(); + + UnifiedCompletionRequest.ToolChoiceObject toolChoice; + if (request.toolChoice() instanceof UnifiedCompletionRequest.ToolChoiceObject) { + UnifiedCompletionRequest.ToolChoiceObject toolChoiceObject = (UnifiedCompletionRequest.ToolChoiceObject) request.toolChoice(); + toolChoice = toolChoiceObject; + } else if (request.toolChoice() instanceof UnifiedCompletionRequest.ToolChoiceString) { + UnifiedCompletionRequest.ToolChoiceString toolChoiceString = (UnifiedCompletionRequest.ToolChoiceString) request.toolChoice(); + if (toolChoiceString.value().equals(TOOL_MODE_AUTO)) { + return; + } + throw new ElasticsearchStatusException( + format( + "Tool choice value [%s] not supported by Google VertexAI ChatCompletion. Supported values: [%s]", + toolChoiceString.value(), + TOOL_MODE_AUTO + ), + RestStatus.BAD_REQUEST + ); + } else { + return; + } + if (FUNCTION_TYPE.equals(toolChoice.type()) == false) { + var errorMessage = format( + "Tool choice type [%s] not supported by Google VertexAI ChatCompletion. 
Supported types: [%s]", + toolChoice.type(), + FUNCTION_TYPE + ); + throw new ElasticsearchStatusException(errorMessage, RestStatus.BAD_REQUEST); + } + + builder.startObject(TOOL_CONFIG); + builder.startObject(FUNCTION_CALLING_CONFIG); + + var chosenFunction = toolChoice.function(); + if (chosenFunction != null) { + // If we are using toolChoice we set the API to use the 'ANY', meaning that the model will call this tool + // We do that since it's the only supported way right now to make compatible the OpenAi spec with VertexAI spec + builder.field(TOOL_MODE, TOOL_MODE_ANY); + if (Strings.hasText(chosenFunction.name())) { + builder.startArray(ALLOWED_FUNCTION_NAMES); + builder.value(chosenFunction.name()); + builder.endArray(); + } + + builder.endObject(); + builder.endObject(); + } + } + + private void buildGenerationConfig(XContentBuilder builder) throws IOException { + var request = unifiedChatInput.getRequest(); + + boolean hasAnyConfig = request.stop() != null + || request.temperature() != null + || request.maxCompletionTokens() != null + || request.topP() != null; + + if (hasAnyConfig == false) { + return; + } + + builder.startObject(GENERATION_CONFIG); + + if (request.stop() != null) { + builder.stringListField(STOP_SEQUENCES, request.stop()); + } + if (request.temperature() != null) { + builder.field(TEMPERATURE, request.temperature()); + } + if (request.maxCompletionTokens() != null) { + builder.field(MAX_OUTPUT_TOKENS, request.maxCompletionTokens()); + } + if (request.topP() != null) { + builder.field(TOP_P, request.topP()); + } + + builder.endObject(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + buildContents(builder); + buildGenerationConfig(builder); + buildTools(builder); + buildToolConfig(builder); + buildSystemInstruction(builder); + + builder.endObject(); + return builder; + } +} diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUtils.java index 79335014007ac..7eda9c8b01cae 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUtils.java @@ -35,6 +35,10 @@ public final class GoogleVertexAiUtils { public static final String RANK = "rank"; + public static final String STREAM_GENERATE_CONTENT = "streamGenerateContent"; + + public static final String QUERY_PARAM_ALT_SSE = "alt=sse"; + private GoogleVertexAiUtils() {} } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java index 8ea1c12ea9e4a..1e22ae20b3537 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InputType; @@ -27,9 +28,10 @@ import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; 
-import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModelTests; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings; @@ -45,14 +47,19 @@ import java.io.IOException; import java.util.HashMap; import java.util.Map; +import java.util.concurrent.TimeUnit; +import static java.util.concurrent.TimeUnit.MINUTES; import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; +import static org.elasticsearch.inference.TaskType.CHAT_COMPLETION; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.mockito.Mockito.mock; @@ -63,6 +70,7 @@ public class GoogleVertexAiServiceTests 
extends ESTestCase { private ThreadPool threadPool; private HttpClientManager clientManager; + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); @Before public void init() throws Exception { @@ -78,6 +86,54 @@ public void shutdown() throws IOException { webServer.close(); } + public void testParseRequestConfig_CreateGoogleVertexAiChatCompletionModel() throws IOException { + var projectId = "project"; + var location = "location"; + var modelId = "model"; + var serviceAccountJson = """ + { + "some json" + } + """; + + try (var service = createGoogleVertexAiService()) { + ActionListener modelListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(GoogleVertexAiChatCompletionModel.class)); + + var vertexAIModel = (GoogleVertexAiChatCompletionModel) model; + + assertThat(vertexAIModel.getServiceSettings().modelId(), is(modelId)); + assertThat(vertexAIModel.getServiceSettings().location(), is(location)); + assertThat(vertexAIModel.getServiceSettings().projectId(), is(projectId)); + assertThat(vertexAIModel.getSecretSettings().serviceAccountJson().toString(), is(serviceAccountJson)); + assertThat(vertexAIModel.getConfigurations().getTaskType(), equalTo(CHAT_COMPLETION)); + assertThat(vertexAIModel.getServiceSettings().rateLimitSettings().requestsPerTimeUnit(), equalTo(1000L)); + assertThat(vertexAIModel.getServiceSettings().rateLimitSettings().timeUnit(), equalTo(MINUTES)); + + }, e -> fail("Model parsing should succeeded, but failed: " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.CHAT_COMPLETION, + getRequestConfigMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + modelId, + GoogleVertexAiServiceFields.LOCATION, + location, + GoogleVertexAiServiceFields.PROJECT_ID, + projectId + ) + ), + getTaskSettingsMapEmpty(), + getSecretSettingsMap(serviceAccountJson) + ), + modelListener + ); + } + } + public void testParseRequestConfig_CreatesGoogleVertexAiEmbeddingsModel() throws IOException { var 
projectId = "project"; var location = "location"; @@ -427,6 +483,53 @@ public void testParsePersistedConfigWithSecrets_CreatesGoogleVertexAiEmbeddingsM } } + public void testParsePersistedConfigWithSecrets_CreatesGoogleVertexAiChatCompletionModel() throws IOException { + var projectId = "project"; + var location = "location"; + var modelId = "model"; + var autoTruncate = true; + var serviceAccountJson = """ + { + "some json" + } + """; + + try (var service = createGoogleVertexAiService()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + modelId, + GoogleVertexAiServiceFields.LOCATION, + location, + GoogleVertexAiServiceFields.PROJECT_ID, + projectId + ) + ), + getTaskSettingsMap(autoTruncate, InputType.INGEST), + getSecretSettingsMap(serviceAccountJson) + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.CHAT_COMPLETION, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(GoogleVertexAiChatCompletionModel.class)); + + var chatCompletionModel = (GoogleVertexAiChatCompletionModel) model; + assertThat(chatCompletionModel.getServiceSettings().modelId(), is(modelId)); + assertThat(chatCompletionModel.getServiceSettings().location(), is(location)); + assertThat(chatCompletionModel.getServiceSettings().projectId(), is(projectId)); + assertThat(chatCompletionModel.getSecretSettings().serviceAccountJson().toString(), is(serviceAccountJson)); + assertThat(chatCompletionModel.getConfigurations().getTaskType(), equalTo(CHAT_COMPLETION)); + assertThat(chatCompletionModel.getServiceSettings().rateLimitSettings().requestsPerTimeUnit(), equalTo(1000L)); + assertThat(chatCompletionModel.getServiceSettings().rateLimitSettings().timeUnit(), equalTo(MINUTES)); + } + } + public void testParsePersistedConfigWithSecrets_CreatesAGoogleVertexAiEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { var projectId = "project"; var location = "location"; @@ 
-871,7 +974,7 @@ public void testGetConfiguration() throws Exception { { "service": "googlevertexai", "name": "Google Vertex AI", - "task_types": ["text_embedding", "rerank"], + "task_types": ["text_embedding", "rerank", "chat_completion"], "configurations": { "service_account_json": { "description": "API Key for the provider you're connecting to.", @@ -880,7 +983,7 @@ public void testGetConfiguration() throws Exception { "sensitive": true, "updatable": true, "type": "str", - "supported_task_types": ["text_embedding", "rerank"] + "supported_task_types": ["text_embedding", "rerank", "chat_completion"] }, "project_id": { "description": "The GCP Project ID which has Vertex AI API(s) enabled. For more information on the URL, refer to the {geminiVertexAIDocs}.", @@ -889,7 +992,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["text_embedding", "rerank"] + "supported_task_types": ["text_embedding", "rerank", "chat_completion"] }, "location": { "description": "Please provide the GCP region where the Vertex AI API(s) is enabled. 
For more information, refer to the {geminiVertexAIDocs}.", @@ -898,7 +1001,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["text_embedding"] + "supported_task_types": ["text_embedding", "chat_completion"] }, "rate_limit.requests_per_minute": { "description": "Minimize the number of rate limit errors.", @@ -907,7 +1010,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "int", - "supported_task_types": ["text_embedding", "rerank"] + "supported_task_types": ["text_embedding", "rerank", "chat_completion"] }, "model_id": { "description": "ID of the LLM you're using.", @@ -916,7 +1019,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["text_embedding", "rerank"] + "supported_task_types": ["text_embedding", "rerank", "chat_completion"] } } } @@ -938,7 +1041,9 @@ public void testGetConfiguration() throws Exception { } private GoogleVertexAiService createGoogleVertexAiService() { - return new GoogleVertexAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + return new GoogleVertexAiService(senderFactory, createWithEmptySettings(threadPool)); } private Map getRequestConfigMap( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandlerTests.java new file mode 100644 index 0000000000000..a82484e3fef24 --- /dev/null +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandlerTests.java @@ -0,0 +1,129 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googlevertexai; + +import org.apache.http.HttpResponse; +import org.apache.http.StatusLine; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +import static org.elasticsearch.ExceptionsHelper.unwrapCause; +import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.isA; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class GoogleVertexAiUnifiedChatCompletionResponseHandlerTests extends ESTestCase { + private static final String INFERENCE_ID = "vertexAiInference"; + + private final GoogleVertexAiUnifiedChatCompletionResponseHandler responseHandler = + new GoogleVertexAiUnifiedChatCompletionResponseHandler("chat_completion"); + + public void testFailValidationWithAllErrorFields() throws IOException { + var responseJson = """ + { + "error": { + "code": 400, + "message": "Invalid JSON payload 
received.", + "status": "INVALID_ARGUMENT" + } + } + """; + + var errorJson = invalidResponseJson(responseJson); + + assertThat(errorJson, is(Strings.format(""" + {"error":{"code":"400","message":"Received a server error status code for request from inference entity id [%s] \ + status [500]. Error message: [Invalid JSON payload received.]","type":"INVALID_ARGUMENT"}}\ + """, INFERENCE_ID))); + } + + public void testFailValidationWithAllErrorFieldsAndDetails() throws IOException { + var responseJson = """ + { + "error": { + "code": 400, + "message": "Invalid JSON payload received.", + "status": "INVALID_ARGUMENT", + "details": [ + { "some":"value" } + ] + } + } + """; + + var errorJson = invalidResponseJson(responseJson); + + assertThat(errorJson, is(Strings.format(""" + {"error":{"code":"400","message":"Received a server error status code for request from inference entity id [%s] \ + status [500]. Error message: [Invalid JSON payload received.]","type":"INVALID_ARGUMENT"}}\ + """, INFERENCE_ID))); + } + + private static Request mockRequest() { + var request = mock(Request.class); + when(request.getInferenceEntityId()).thenReturn(INFERENCE_ID); + when(request.isStreaming()).thenReturn(true); + return request; + } + + private static HttpResponse mockHttpResponse(int statusCode) { + var statusLine = mock(StatusLine.class); + when(statusLine.getStatusCode()).thenReturn(statusCode); + + var response = mock(HttpResponse.class); + when(response.getStatusLine()).thenReturn(statusLine); + + return response; + } + + private String invalidResponseJson(String responseJson) throws IOException { + var exception = invalidResponse(responseJson); + assertThat(exception, isA(RetryException.class)); + assertThat(unwrapCause(exception), isA(UnifiedChatCompletionException.class)); + return toJson((UnifiedChatCompletionException) unwrapCause(exception)); + } + + private String toJson(UnifiedChatCompletionException e) throws IOException { + try (var builder = 
XContentFactory.jsonBuilder()) { + e.toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, EMPTY_PARAMS); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }); + return XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType()); + } + } + + private Exception invalidResponse(String responseJson) { + return expectThrows( + RetryException.class, + () -> responseHandler.validateResponse( + mock(), + mock(), + mockRequest(), + new HttpResult(mockHttpResponse(500), responseJson.getBytes(StandardCharsets.UTF_8)), + true + ) + ); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessorTests.java new file mode 100644 index 0000000000000..db8259a1dba55 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessorTests.java @@ -0,0 +1,218 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.googlevertexai; + +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; + +import java.io.IOException; + +public class GoogleVertexAiUnifiedStreamingProcessorTests extends ESTestCase { + + public void testJsonLiteral() { + String json = """ + { + "candidates" : [ { + "content" : { + "role" : "model", + "parts" : [ + { "text" : "Elastic" }, + { + "functionCall": { + "name": "getWeatherData", + "args": { "unit": "celsius", "location": "buenos aires, argentina" } + } + } + ] + }, + "finishReason": "MAXTOKENS" + } ], + "usageMetadata" : { + "promptTokenCount": 10, + "candidatesTokenCount": 20, + "totalTokenCount": 30, + "trafficType" : "ON_DEMAND" + }, + "modelVersion" : "gemini-2.0-flash-lite", + "createTime" : "2025-05-07T14:36:16.122336Z", + "responseId" : "responseId" + } + """; + + XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler( + LoggingDeprecationHandler.INSTANCE + ); + + try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, json)) { + var chunk = GoogleVertexAiUnifiedStreamingProcessor.GoogleVertexAiChatCompletionChunkParser.parse(parser); + + assertEquals("responseId", chunk.id()); + assertEquals(1, chunk.choices().size()); + assertEquals("chat.completion.chunk", chunk.object()); + + var choice = chunk.choices().get(0); + assertEquals("Elastic", choice.delta().content()); + assertEquals("model", choice.delta().role()); + assertEquals("gemini-2.0-flash-lite", chunk.model()); + assertEquals(0, choice.index()); // VertexAI response does not have Index. 
Use 0 as default + assertEquals("MAXTOKENS", choice.finishReason()); + + assertEquals(1, choice.delta().toolCalls().size()); + var toolCall = choice.delta().toolCalls().get(0); + assertEquals("getWeatherData", toolCall.function().name()); + assertEquals("{\"unit\":\"celsius\",\"location\":\"buenos aires, argentina\"}", toolCall.function().arguments()); + + assertNotNull(chunk.usage()); + assertEquals(20, chunk.usage().completionTokens()); + assertEquals(10, chunk.usage().promptTokens()); + assertEquals(30, chunk.usage().totalTokens()); + + } catch (IOException e) { + fail("IOException during test: " + e.getMessage()); + } + } + + public void testJsonLiteral_usageMetadataTokenCountMissing() { + String json = """ + { + "candidates" : [ { + "content" : { + "role" : "model", + "parts" : [ { "text" : "Hello" } ] + }, + "finishReason": "STOP" + } ], + "usageMetadata" : { + "trafficType" : "ON_DEMAND" + }, + "modelVersion": "gemini-2.0-flash-001", + "responseId": "responseId" + } + """; + + XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler( + LoggingDeprecationHandler.INSTANCE + ); + + try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, json)) { + var chunk = GoogleVertexAiUnifiedStreamingProcessor.GoogleVertexAiChatCompletionChunkParser.parse(parser); + + assertEquals("responseId", chunk.id()); + assertEquals(1, chunk.choices().size()); + var choice = chunk.choices().get(0); + assertEquals("Hello", choice.delta().content()); + assertEquals("model", choice.delta().role()); + assertEquals("STOP", choice.finishReason()); + assertEquals(0, choice.index()); + assertNull(choice.delta().toolCalls()); + + } catch (IOException e) { + fail("IOException during test: " + e.getMessage()); + } + } + + public void testJsonLiteral_functionCallArgsMissing() { + String json = """ + { + "candidates" : [ { + "content" : { + "role" : "model", + "parts" : [ + { + "functionCall": { + "name": 
"getLocation" + } + } + ] + } + } ], + "responseId" : "resId789", + "modelVersion": "gemini-2.0-flash-00", + "usageMetadata" : { + "promptTokenCount": 10, + "candidatesTokenCount": 20, + "totalTokenCount": 30, + "trafficType" : "ON_DEMAND" + } + } + """; + XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler( + LoggingDeprecationHandler.INSTANCE + ); + + try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, json)) { + var chunk = GoogleVertexAiUnifiedStreamingProcessor.GoogleVertexAiChatCompletionChunkParser.parse(parser); + + assertEquals("resId789", chunk.id()); + assertEquals(1, chunk.choices().size()); + var choice = chunk.choices().get(0); + assertEquals("model", choice.delta().role()); + assertNull(choice.delta().content()); + + assertNotNull(choice.delta().toolCalls()); + assertEquals(1, choice.delta().toolCalls().size()); + var toolCall = choice.delta().toolCalls().get(0); + assertEquals("getLocation", toolCall.function().name()); + assertNull(toolCall.function().arguments()); + + } catch (IOException e) { + fail("IOException during test: " + e.getMessage()); + } + } + + public void testJsonLiteral_multipleTextParts() { + String json = """ + { + "candidates" : [ { + "content" : { + "role" : "model", + "parts" : [ + { "text" : "This is the first part. " }, + { "text" : "This is the second part." 
} + ] + }, + "finishReason": "STOP" + } ], + "responseId" : "multiTextId", + "usageMetadata" : { + "promptTokenCount": 10, + "candidatesTokenCount": 20, + "totalTokenCount": 30, + "trafficType" : "ON_DEMAND" + }, + "modelVersion": "gemini-2.0-flash-001" + } + """; + + XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler( + LoggingDeprecationHandler.INSTANCE + ); + + try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, json)) { + var chunk = GoogleVertexAiUnifiedStreamingProcessor.GoogleVertexAiChatCompletionChunkParser.parse(parser); + + assertEquals("multiTextId", chunk.id()); + assertEquals(1, chunk.choices().size()); + + var choice = chunk.choices().get(0); + assertEquals("model", choice.delta().role()); + // Verify that the text from multiple parts is concatenated + assertEquals("This is the first part. This is the second part.", choice.delta().content()); + assertEquals("STOP", choice.finishReason()); + assertEquals(0, choice.index()); + assertNull(choice.delta().toolCalls()); + assertEquals("gemini-2.0-flash-001", chunk.model()); + } catch (IOException e) { + fail("IOException during test: " + e.getMessage()); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiUnifiedChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiUnifiedChatCompletionActionTests.java new file mode 100644 index 0000000000000..58072b747a0aa --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiUnifiedChatCompletionActionTests.java @@ -0,0 +1,141 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googlevertexai.action; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModelTests; +import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUnifiedChatCompletionRequest; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static 
org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; +import static org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator.COMPLETION_HANDLER; +import static org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator.USER_ROLE; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; + +public class GoogleVertexAiUnifiedChatCompletionActionTests extends ESTestCase { + + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + private static UnifiedChatInput createUnifiedChatInput(List messages, String role) { + boolean stream = true; + return new UnifiedChatInput(messages, role, stream); + } + + // Successful case would typically be tested via end-to-end notebook tests in AppEx repo + + public void testExecute_ThrowsElasticsearchException() { + var sender = mock(Sender.class); + doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any()); + + var action = createAction("us-central1", "test-project-id", "chat-bison", sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(createUnifiedChatInput(List.of("test query"), "user"), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + 
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(thrownException.getMessage(), is("failed")); + } + + public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() { + var sender = mock(Sender.class); + + doAnswer(invocation -> { + ActionListener listenerArg = invocation.getArgument(3); + listenerArg.onFailure(new IllegalStateException("failed")); + return Void.TYPE; + }).when(sender).send(any(), any(), any(), any()); + + var action = createAction("us-central1", "test-project-id", "chat-bison", sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(createUnifiedChatInput(List.of("test query"), "user"), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(thrownException.getMessage(), is("Failed to send Google Vertex AI chat completion request. Cause: failed")); + } + + public void testExecute_ThrowsException() { + var sender = mock(Sender.class); + doThrow(new IllegalArgumentException("failed")).when(sender).send(any(), any(), any(), any()); + + var action = createAction("us-central1", "test-project-id", "chat-bison", sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(createUnifiedChatInput(List.of("test query"), "user"), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(thrownException.getMessage(), is("Failed to send Google Vertex AI chat completion request. 
Cause: failed")); + } + + private ExecutableAction createAction(String location, String projectId, String actualModelId, Sender sender) { + var model = GoogleVertexAiChatCompletionModelTests.createCompletionModel( + projectId, + location, + actualModelId, + "api-key", + new RateLimitSettings(100) + ); + + var manager = new GenericRequestManager<>( + threadPool, + model, + COMPLETION_HANDLER, + inputs -> new GoogleVertexAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model), + ChatCompletionInput.class + ); + var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Google Vertex AI chat completion"); + return new SenderExecutableAction(sender, manager, failedToSendRequestErrorMessage); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAIChatCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAIChatCompletionServiceSettingsTests.java new file mode 100644 index 0000000000000..13a60670f1bdf --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAIChatCompletionServiceSettingsTests.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.googlevertexai.completion; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.InferenceSettingsTestCase; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.util.Map; + +public class GoogleVertexAIChatCompletionServiceSettingsTests extends InferenceSettingsTestCase< + GoogleVertexAiChatCompletionServiceSettings> { + @Override + protected Writeable.Reader instanceReader() { + return GoogleVertexAiChatCompletionServiceSettings::new; + } + + @Override + protected GoogleVertexAiChatCompletionServiceSettings fromMutableMap(Map mutableMap) { + return GoogleVertexAiChatCompletionServiceSettings.fromMap(mutableMap, ConfigurationParseContext.PERSISTENT); + + } + + @Override + protected GoogleVertexAiChatCompletionServiceSettings createTestInstance() { + return new GoogleVertexAiChatCompletionServiceSettings( + randomString(), + randomString(), + randomString(), + new RateLimitSettings(randomIntBetween(1, 1000)) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModelTests.java new file mode 100644 index 0000000000000..fb5dccf89aa57 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModelTests.java @@ -0,0 +1,118 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.googlevertexai.completion; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.sameInstance; + +public class GoogleVertexAiChatCompletionModelTests extends ESTestCase { + + private static final String DEFAULT_PROJECT_ID = "test-project"; + private static final String DEFAULT_LOCATION = "us-central1"; + private static final String DEFAULT_MODEL_ID = "gemini-pro"; + private static final String DEFAULT_API_KEY = "test-api-key"; + private static final RateLimitSettings DEFAULT_RATE_LIMIT = new RateLimitSettings(100); + + public void testOverrideWith_UnifiedCompletionRequest_OverridesModelId() { + var model = createCompletionModel(DEFAULT_PROJECT_ID, DEFAULT_LOCATION, DEFAULT_MODEL_ID, DEFAULT_API_KEY, DEFAULT_RATE_LIMIT); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)), + "gemini-flash", + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = GoogleVertexAiChatCompletionModel.of(model, request); + + assertThat(overriddenModel.getServiceSettings().modelId(), is("gemini-flash")); + + assertThat(overriddenModel, not(sameInstance(model))); + assertThat(overriddenModel.getServiceSettings().projectId(), is(DEFAULT_PROJECT_ID)); + 
assertThat(overriddenModel.getServiceSettings().location(), is(DEFAULT_LOCATION)); + assertThat(overriddenModel.getServiceSettings().rateLimitSettings(), is(DEFAULT_RATE_LIMIT)); + assertThat(overriddenModel.getSecretSettings().serviceAccountJson(), equalTo(new SecureString(DEFAULT_API_KEY.toCharArray()))); + assertThat(overriddenModel.getTaskSettings(), is(model.getTaskSettings())); + } + + public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenRequestDoesNotOverride() { + var model = createCompletionModel(DEFAULT_PROJECT_ID, DEFAULT_LOCATION, DEFAULT_MODEL_ID, DEFAULT_API_KEY, DEFAULT_RATE_LIMIT); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)), + null, + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = GoogleVertexAiChatCompletionModel.of(model, request); + + assertThat(overriddenModel.getServiceSettings().modelId(), is(DEFAULT_MODEL_ID)); + + assertThat(overriddenModel.getServiceSettings().projectId(), is(DEFAULT_PROJECT_ID)); + assertThat(overriddenModel.getServiceSettings().location(), is(DEFAULT_LOCATION)); + assertThat(overriddenModel.getServiceSettings().rateLimitSettings(), is(DEFAULT_RATE_LIMIT)); + assertThat(overriddenModel.getSecretSettings().serviceAccountJson(), equalTo(new SecureString(DEFAULT_API_KEY.toCharArray()))); + assertThat(overriddenModel.getTaskSettings(), is(model.getTaskSettings())); + + assertThat(overriddenModel, not(sameInstance(model))); + } + + public void testBuildUri() throws URISyntaxException { + String location = "us-east1"; + String projectId = "my-gcp-project"; + String model = "gemini-1.5-flash-001"; + URI expectedUri = new URI( + "https://us-east1-aiplatform.googleapis.com/v1/projects/my-gcp-project" + + "/locations/global/publishers/google/models/gemini-1.5-flash-001:streamGenerateContent?alt=sse" + ); + URI actualUri = 
GoogleVertexAiChatCompletionModel.buildUri(location, projectId, model); + assertThat(actualUri, is(expectedUri)); + } + + public static GoogleVertexAiChatCompletionModel createCompletionModel( + String projectId, + String location, + String modelId, + String apiKey, + RateLimitSettings rateLimitSettings + ) { + return new GoogleVertexAiChatCompletionModel( + "google-vertex-ai-chat-test-id", + TaskType.CHAT_COMPLETION, + "google_vertex_ai", + new GoogleVertexAiChatCompletionServiceSettings(projectId, location, modelId, rateLimitSettings), + new EmptyTaskSettings(), + new GoogleVertexAiSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static URI buildDefaultUri() throws URISyntaxException { + return GoogleVertexAiChatCompletionModel.buildUri(DEFAULT_LOCATION, DEFAULT_PROJECT_ID, DEFAULT_MODEL_ID); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java new file mode 100644 index 0000000000000..261a6c2153b04 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java @@ -0,0 +1,998 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.googlevertexai.request; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.ParsingException; +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.Utils.assertJsonEquals; +import static org.hamcrest.Matchers.containsString; + +public class GoogleVertexAiUnifiedChatCompletionRequestEntityTests extends ESTestCase { + + private static final String USER_ROLE = "user"; + private static final String ASSISTANT_ROLE = "assistant"; + + public void testBasicSerialization_SingleMessage() throws IOException { + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, Vertex AI!"), + USER_ROLE, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + + var unifiedRequest = UnifiedCompletionRequest.of(messageList); + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); // stream doesn't affect VertexAI request body + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "contents": [ + { + "role": "user", + "parts": [ + { + "text": "Hello, Vertex AI!" 
+ } + ] + } + ] + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + public void testSerialization_MultipleMessages() throws IOException { + var messages = List.of( + new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Previous user message."), + USER_ROLE, + null, + null + ), + new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Previous model response."), + ASSISTANT_ROLE, + null, + null + ), + new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("Current user query."), USER_ROLE, null, null) + ); + + var unifiedRequest = UnifiedCompletionRequest.of(messages); + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, false); + + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "contents": [ + { + "role": "user", + "parts": [ { "text": "Previous user message." } ] + }, + { + "role": "model", + "parts": [ { "text": "Previous model response." } ] + }, + { + "role": "user", + "parts": [ { "text": "Current user query." 
} ] + } + ] + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + public void testSerialization_Tools() throws IOException { + var request = new UnifiedCompletionRequest( + List.of( + new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentObjects(List.of(new UnifiedCompletionRequest.ContentObject("some text", "text"))), + "user", + null, + null + ) + ), + "gemini-2.0", + null, + null, + null, + null, + List.of( + new UnifiedCompletionRequest.Tool( + "function", + new UnifiedCompletionRequest.Tool.FunctionField( + "Get the current weather in a given location", + "get_current_weather", + Map.of("type", "object", "description", "a description"), + null + ) + ) + ), + null + ); + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, false); + + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "contents": [ + { + "role": "user", + "parts": [ { "text": "some text" } ] + } + ], + "tools": [ + { + "functionDeclarations": [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "description": "a description" + } + } + ] + } + ] + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + public void testSerialization_ToolsChoice() throws IOException { + var request = new UnifiedCompletionRequest( + List.of( + new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentObjects(List.of(new UnifiedCompletionRequest.ContentObject("some text", "text"))), + "user", + null, + null + ) + ), + "gemini-2.0", + null, + null, + null, + new UnifiedCompletionRequest.ToolChoiceObject( + "function", + new UnifiedCompletionRequest.ToolChoiceObject.FunctionField("some 
function") + ), + List.of( + new UnifiedCompletionRequest.Tool( + "function", + new UnifiedCompletionRequest.Tool.FunctionField( + "Get the current weather in a given location", + "get_current_weather", + Map.of("type", "object", "description", "a description"), + null + ) + ) + ), + null + ); + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, false); + + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "contents": [ + { + "role": "user", + "parts": [ { "text": "some text" } ] + } + ], + "tools": [ + { + "functionDeclarations": [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "description": "a description" + } + } + ] + } + ], + "toolConfig": { + "functionCallingConfig" : { + "mode": "ANY", + "allowedFunctionNames": [ "some function" ] + } + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + public void testSerialization_WithAllGenerationConfig() throws IOException { + List messages = List.of( + new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("Hello Gemini!"), USER_ROLE, null, null) + ); + var completionRequestWithGenerationConfig = new UnifiedCompletionRequest( + messages, + "modelId", + 100L, + List.of("stop1", "stop2"), + 0.5f, + null, + null, + 0.9F + ); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(completionRequestWithGenerationConfig, true); + + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = 
Strings.toString(builder); + String expectedJson = """ + { + "contents": [ + { + "role": "user", + "parts": [ { "text": "Hello Gemini!" } ] + } + ], + "generationConfig": { + "stopSequences": ["stop1", "stop2"], + "temperature": 0.5, + "maxOutputTokens": 100, + "topP": 0.9 + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + public void testSerialization_WithSomeGenerationConfig() throws IOException { + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Partial config."), + USER_ROLE, + null, + null + ); + var completionRequestWithGenerationConfig = new UnifiedCompletionRequest( + List.of(message), + "modelId", + 50L, + null, + 0.7f, + null, + null, + null + ); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(completionRequestWithGenerationConfig, true); + + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "contents": [ + { + "role": "user", + "parts": [ { "text": "Partial config." 
} ] + } + ], + "generationConfig": { + "temperature": 0.7, + "maxOutputTokens": 50 + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + public void testSerialization_NoGenerationConfig() throws IOException { + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("No extra config."), + USER_ROLE, + null, + null + ); + // No generation config fields set on unifiedRequest + var unifiedRequest = UnifiedCompletionRequest.of(List.of(message)); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "contents": [ + { + "role": "user", + "parts": [ { "text": "No extra config." } ] + } + ] + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + public void testSerialization_WithContentObjects() throws IOException { + var contentObjects = List.of( + new UnifiedCompletionRequest.ContentObject("First part. 
", "text"), + new UnifiedCompletionRequest.ContentObject("Second part.", "text") + ); + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentObjects(contentObjects), + USER_ROLE, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + + var unifiedRequest = UnifiedCompletionRequest.of(messageList); + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "contents": [ + { + "role": "user", + "parts": [ + { "text": "First part. " }, + { "text": "Second part." } + ] + } + ] + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + public void testError_UnsupportedRole() throws IOException { + var unsupportedRole = "someUnexpectedRole"; + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Test"), + unsupportedRole, + null, + null + ); + var unifiedRequest = UnifiedCompletionRequest.of(List.of(message)); + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, false); + + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); + + var builder = JsonXContent.contentBuilder(); + var statusException = assertThrows(ElasticsearchStatusException.class, () -> entity.toXContent(builder, ToXContent.EMPTY_PARAMS)); + + assertEquals(RestStatus.BAD_REQUEST, statusException.status()); + var errorMessage = Strings.format("Role [%s] not supported by Google VertexAI ChatCompletion", unsupportedRole); + assertThat(statusException.toString(), 
containsString(errorMessage)); + } + + public void testError_UnsupportedContentObjectType() throws IOException { + var contentObjects = List.of(new UnifiedCompletionRequest.ContentObject("http://example.com/image.png", "image_url")); + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentObjects(contentObjects), + USER_ROLE, + null, + null + ); + var unifiedRequest = UnifiedCompletionRequest.of(List.of(message)); + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, false); + + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); + + var builder = JsonXContent.contentBuilder(); + var statusException = assertThrows(ElasticsearchStatusException.class, () -> entity.toXContent(builder, ToXContent.EMPTY_PARAMS)); + + assertEquals(RestStatus.BAD_REQUEST, statusException.status()); + assertThat(statusException.toString(), containsString("Type [image_url] not supported by Google VertexAI ChatCompletion")); + } + + public void testParseAllFields() throws IOException { + String requestJson = """ + { + "contents": [ + { + "role": "user", + "parts": [ + { + "text": "some text" + }, + { + "functionCall": { + "name": "get_delivery_date", + "args": { + "order_id": "order_12345" + } + } + } + ] + } + ], + "generationConfig": { + "stopSequences": [ + "stop" + ], + "temperature": 0.1, + "maxOutputTokens": 100, + "topP": 0.2 + }, + "tools": [ + { + "functionDeclarations": [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object" + } + } + ] + } + ], + "toolConfig": { + "functionCallingConfig": { + "mode": "ANY", + "allowedFunctionNames": [ + "some function" + ] + } + } + } + """; + + var request = new UnifiedCompletionRequest( + List.of( + new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentObjects(List.of(new 
UnifiedCompletionRequest.ContentObject("some text", "text"))), + "user", + "100", + List.of( + new UnifiedCompletionRequest.ToolCall( + "call_62136354", + new UnifiedCompletionRequest.ToolCall.FunctionField("{\"order_id\": \"order_12345\"}", "get_delivery_date"), + "function" + ) + ) + ) + ), + "gemini-2.0", + 100L, + List.of("stop"), + 0.1F, + new UnifiedCompletionRequest.ToolChoiceObject( + "function", + new UnifiedCompletionRequest.ToolChoiceObject.FunctionField("some function") + ), + List.of( + new UnifiedCompletionRequest.Tool( + "function", + new UnifiedCompletionRequest.Tool.FunctionField( + "Get the current weather in a given location", + "get_current_weather", + Map.of("type", "object"), + null + ) + ) + ), + 0.2F + ); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, true); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + assertJsonEquals(jsonString, requestJson); + } + + public void testParseFunctionCallNoContent() throws IOException { + String requestJson = """ + { + "contents": [ + { + "role": "model", + "parts": [ + { "functionCall" : { + "name": "get_delivery_date", + "args": { + "order_id" : "order_12345" + } + } + } + ] + } + ] + } + """; + + var request = new UnifiedCompletionRequest( + List.of( + new UnifiedCompletionRequest.Message( + null, + "tool", + "100", + List.of( + new UnifiedCompletionRequest.ToolCall( + "call_62136354", + new UnifiedCompletionRequest.ToolCall.FunctionField("{\"order_id\": \"order_12345\"}", "get_delivery_date"), + "function" + ) + ) + ) + ), + "gemini-2.0", + null, + null, + null, + null, + null, + null + ); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, true); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new 
GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + assertJsonEquals(jsonString, requestJson); + } + + public void testParseFunctionCallWithBadJson() throws IOException { + int someNumber = 1; + var illegalArguments = List.of("\"order_id\": \"order_12345\"}", "[]", Integer.toString(someNumber), "\"a\""); + for (var illegalArgument : illegalArguments) { + + var requestContentObject = new UnifiedCompletionRequest( + List.of( + new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentObjects(List.of(new UnifiedCompletionRequest.ContentObject("", "text"))), + "assistant", + null, + List.of( + new UnifiedCompletionRequest.ToolCall( + "call_62136354", + new UnifiedCompletionRequest.ToolCall.FunctionField(illegalArgument, "get_delivery_date"), + "function" + ) + ) + ) + ), + "gemini-2.0", + null, + null, + null, + null, + null, + null + ); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(requestContentObject, true); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput + ); + + XContentBuilder builder = JsonXContent.contentBuilder(); + + assertThrows(ParsingException.class, () -> entity.toXContent(builder, ToXContent.EMPTY_PARAMS)); + } + + } + + public void testParseFunctionCallWithEmptyStringContent() throws IOException { + String requestJson = """ + { + "contents": [ + { + "role": "model", + "parts": [ + { "functionCall" : { + "name": "get_delivery_date", + "args": { + "order_id" : "order_12345" + } + } + } + ] + } + ] + } + """; + + var requestContentObject = new UnifiedCompletionRequest( + List.of( + new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentObjects(List.of(new UnifiedCompletionRequest.ContentObject("", "text"))), + "assistant", + null, + 
List.of(
+                        new UnifiedCompletionRequest.ToolCall(
+                            "call_62136354",
+                            new UnifiedCompletionRequest.ToolCall.FunctionField("{\"order_id\": \"order_12345\"}", "get_delivery_date"),
+                            "function"
+                        )
+                    )
+                )
+            ),
+            "gemini-2.0",
+            null,
+            null,
+            null,
+            null,
+            null,
+            null
+        );
+        // Second variant: an assistant message whose content is an empty ContentString, carrying the tool call below.
+        var requestContentString = new UnifiedCompletionRequest(
+            List.of(
+                new UnifiedCompletionRequest.Message(
+                    new UnifiedCompletionRequest.ContentString(""),
+                    "assistant",
+                    null,
+                    List.of(
+                        new UnifiedCompletionRequest.ToolCall(
+                            "call_62136354",
+                            new UnifiedCompletionRequest.ToolCall.FunctionField("{\"order_id\": \"order_12345\"}", "get_delivery_date"),
+                            "function"
+                        )
+                    )
+                )
+            ),
+            "gemini-2.0",
+            null,
+            null,
+            null,
+            null,
+            null,
+            null
+        );
+        var requests = List.of(requestContentObject, requestContentString);
+        // Every variant is serialized on its own and compared with the same expected JSON.
+        for (var request : requests) {
+            UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, true);
+            GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(
+                unifiedChatInput
+            );
+
+            XContentBuilder builder = JsonXContent.contentBuilder();
+            entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
+
+            String jsonString = Strings.toString(builder);
+            assertJsonEquals(jsonString, requestJson);
+        }
+    }
+    /**
+     * Serializes a single user text message together with the tool choice string {@code "auto"}.
+     * The expected JSON contains only the {@code contents} array.
+     */
+    public void testParseToolChoiceString() throws IOException {
+        String requestJson = """
+            {
+                "contents": [
+                    {
+                        "role": "user",
+                        "parts": [
+                            { "text": "some text" }
+                        ]
+                    }
+                ]
+            }
+            """;
+
+        var request = new UnifiedCompletionRequest(
+            List.of(
+                new UnifiedCompletionRequest.Message(
+                    new UnifiedCompletionRequest.ContentObjects(List.of(new UnifiedCompletionRequest.ContentObject("some text", "text"))),
+                    "user",
+                    null,
+                    null
+                )
+            ),
+            "gemini-2.0",
+            null,
+            null,
+            null,
+            new UnifiedCompletionRequest.ToolChoiceString("auto"),
+            null,
+            null
+        );
+
+        UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, true);
+        GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput);
+
+        XContentBuilder builder = JsonXContent.contentBuilder();
+        entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
+
+        String jsonString = Strings.toString(builder);
+        assertJsonEquals(jsonString, requestJson);
+    }
+    /**
+     * Serializes two {@code "system"} messages followed by one {@code "user"} message.
+     * The expected JSON lists both system texts as parts of a single {@code systemInstruction}
+     * and keeps only the user message under {@code contents}.
+     */
+    public void testBuildSystemMessage_MultipleParts() throws IOException {
+        String requestJson = """
+            {
+                "systemInstruction": {
+                    "parts": [
+                        { "text": "instruction text" },
+                        { "text": "instruction text2" }
+                    ]
+                },
+                "contents": [
+                    {
+                        "role": "user",
+                        "parts": [
+                            { "text": "some text" }
+                        ]
+                    }
+                ]
+            }
+            """;
+
+        var request = new UnifiedCompletionRequest(
+            List.of(
+                new UnifiedCompletionRequest.Message(
+                    new UnifiedCompletionRequest.ContentObjects(
+                        List.of(new UnifiedCompletionRequest.ContentObject("instruction text", "text"))
+                    ),
+                    "system",
+                    null,
+                    null
+                ),
+                new UnifiedCompletionRequest.Message(
+                    new UnifiedCompletionRequest.ContentObjects(
+                        List.of(new UnifiedCompletionRequest.ContentObject("instruction text2", "text"))
+                    ),
+                    "system",
+                    null,
+                    null
+                ),
+                new UnifiedCompletionRequest.Message(
+                    new UnifiedCompletionRequest.ContentObjects(List.of(new UnifiedCompletionRequest.ContentObject("some text", "text"))),
+                    "user",
+                    null,
+                    null
+                )
+            ),
+            "gemini-2.0",
+            null,
+            null,
+            null,
+            new UnifiedCompletionRequest.ToolChoiceString("auto"),
+            null,
+            null
+        );
+
+        UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, true);
+        GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput);
+
+        XContentBuilder builder = JsonXContent.contentBuilder();
+        entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
+
+        String jsonString = Strings.toString(builder);
+        assertJsonEquals(jsonString, requestJson);
+    }
+    /**
+     * Serializes one {@code "system"} message followed by one {@code "user"} message.
+     * The expected JSON holds a {@code systemInstruction} with a single part.
+     * NOTE(review): the method name looks truncated; the body covers the single-system-message case.
+     */
+    public void testBuildSystemMessageMul() throws IOException {
+        String requestJson = """
+            {
+                "systemInstruction": {
+                    "parts": [
+                        { "text": "instruction text" }
+                    ]
+                },
+                "contents": [
+                    {
+                        "role": "user",
+                        "parts": [
+                            { "text": "some text" }
+                        ]
+                    }
+                ]
+            }
+            """;
+
+        var request = new UnifiedCompletionRequest(
+            List.of(
+                new UnifiedCompletionRequest.Message(
+                    new UnifiedCompletionRequest.ContentObjects(
+                        List.of(new UnifiedCompletionRequest.ContentObject("instruction text", "text"))
+                    ),
+                    "system",
+                    null,
+                    null
+                ),
+                new UnifiedCompletionRequest.Message(
+                    new UnifiedCompletionRequest.ContentObjects(List.of(new UnifiedCompletionRequest.ContentObject("some text", "text"))),
+                    "user",
+                    null,
+                    null
+                )
+            ),
+            "gemini-2.0",
+            null,
+            null,
+            null,
+            new UnifiedCompletionRequest.ToolChoiceString("auto"),
+            null,
+            null
+        );
+
+        UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, true);
+        GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput);
+
+        XContentBuilder builder = JsonXContent.contentBuilder();
+        entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
+
+        String jsonString = Strings.toString(builder);
+        assertJsonEquals(jsonString, requestJson);
+    }
+    /**
+     * Builds a request whose tool choice string is {@code "unsupported"} and expects serialization to throw an
+     * {@link ElasticsearchStatusException} whose string form contains the rejection message asserted below.
+     */
+    public void testParseToolChoiceInvalid_throwElasticSearchStatusException() throws IOException {
+        var request = new UnifiedCompletionRequest(
+            List.of(
+                new UnifiedCompletionRequest.Message(
+                    new UnifiedCompletionRequest.ContentObjects(List.of(new UnifiedCompletionRequest.ContentObject("some text", "text"))),
+                    "user",
+                    null,
+                    null
+                )
+            ),
+            "gemini-2.0",
+            null,
+            null,
+            null,
+            new UnifiedCompletionRequest.ToolChoiceString("unsupported"),
+            null,
+            null
+        );
+
+        UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, true);
+        GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput);
+
+        XContentBuilder builder = JsonXContent.contentBuilder();
+        var statusException = expectThrows(ElasticsearchStatusException.class, () -> entity.toXContent(builder, ToXContent.EMPTY_PARAMS));
+
+        assertThat(
+            statusException.toString(),
+            containsString("Tool choice value [unsupported] not supported by Google VertexAI ChatCompletion.")
+        );
+
+    }
+    /**
+     * Serializes a user message together with two {@code "function"} tools.
+     * The expected JSON holds one {@code tools} entry whose {@code functionDeclarations} array carries the name,
+     * description and parameters of both functions.
+     */
+    public void testParseMultipleTools() throws IOException {
+        String requestJson = """
+            {
+                "contents": [
+                    {
+                        "role": "user",
+                        "parts": [
+                            { "text": "some text" }
+                        ]
+                    }
+                ],
+                "tools": [
+                    {
+                        "functionDeclarations": [
+                            {
+                                "name": "get_current_weather",
+                                "description": "Get the current weather in a given location",
+                                "parameters": {
+                                    "type": "object"
+                                }
+                            },
+                            {
+                                "name": "get_current_temperature",
+                                "description": "Get the current temperature in a location",
+                                "parameters": {
+                                    "type": "object"
+                                }
+                            }
+                        ]
+                    }
+                ]
+            }
+            """;
+
+        var request = new UnifiedCompletionRequest(
+            List.of(
+                new UnifiedCompletionRequest.Message(
+                    new UnifiedCompletionRequest.ContentObjects(List.of(new UnifiedCompletionRequest.ContentObject("some text", "text"))),
+                    "user",
+                    null,
+                    null
+                )
+            ),
+            "gemini-2.0",
+            null,
+            null,
+            null,
+            null,
+            List.of(
+                new UnifiedCompletionRequest.Tool(
+                    "function",
+                    new UnifiedCompletionRequest.Tool.FunctionField(
+                        "Get the current weather in a given location",
+                        "get_current_weather",
+                        Map.of("type", "object"),
+                        null
+                    )
+                ),
+                new UnifiedCompletionRequest.Tool(
+                    "function",
+                    new UnifiedCompletionRequest.Tool.FunctionField(
+                        "Get the current temperature in a location",
+                        "get_current_temperature",
+                        Map.of("type", "object"),
+                        null
+                    )
+                )
+            ),
+            null
+        );
+
+        UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, true);
+        GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput);
+
+        XContentBuilder builder = JsonXContent.contentBuilder();
+        entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
+
+        String jsonString = Strings.toString(builder);
+        assertJsonEquals(jsonString, requestJson);
+    }
+}
diff --git
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestTests.java new file mode 100644 index 0000000000000..b4064816f69d1 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestTests.java @@ -0,0 +1,110 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googlevertexai.request; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel; +import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModelTests; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.net.URI; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; + +public class GoogleVertexAiUnifiedChatCompletionRequestTests extends ESTestCase { + + 
private static final String AUTH_HEADER_VALUE = "Bearer foo"; + + public void testCreateRequest_Default() throws IOException { + var modelId = "gemini-pro"; + var projectId = "test-project"; + var location = "us-central1"; + + var messages = List.of("Hello Gemini!"); + + var request = createRequest(projectId, location, modelId, messages, null, null); + var httpRequest = request.createHttpRequest(); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var uri = URI.create( + Strings.format( + "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers" + + "/google/models/%s:streamGenerateContent?alt=sse", + location, + projectId, + modelId + ) + ); + + assertThat(httpPost.getURI(), equalTo(uri)); + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE)); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(1)); + assertThat( + requestMap, + equalTo(Map.of("contents", List.of(Map.of("role", "user", "parts", List.of(Map.of("text", messages.get(0))))))) + ); + + } + + public static GoogleVertexAiUnifiedChatCompletionRequest createRequest( + UnifiedChatInput input, + GoogleVertexAiChatCompletionModel model + ) { + return new GoogleVertexAiUnifiedChatCompletionWithoutAuthRequest(input, model); + } + + public static GoogleVertexAiUnifiedChatCompletionRequest createRequest( + String projectId, + String location, + String modelId, + List messages, + @Nullable String apiKey, + @Nullable RateLimitSettings rateLimitSettings + ) { + var model = GoogleVertexAiChatCompletionModelTests.createCompletionModel( + projectId, + location, + modelId, + Objects.requireNonNullElse(apiKey, "default-api-key"), + Objects.requireNonNullElse(rateLimitSettings, new RateLimitSettings(100)) + ); + var unifiedChatInput = new UnifiedChatInput(messages, "user", true); 
+ + return new GoogleVertexAiUnifiedChatCompletionWithoutAuthRequest(unifiedChatInput, model); + } + + /** + * We use this class to fake the auth implementation to avoid static mocking of {@link GoogleVertexAiRequest} + */ + private static class GoogleVertexAiUnifiedChatCompletionWithoutAuthRequest extends GoogleVertexAiUnifiedChatCompletionRequest { + GoogleVertexAiUnifiedChatCompletionWithoutAuthRequest(UnifiedChatInput unifiedChatInput, GoogleVertexAiChatCompletionModel model) { + super(unifiedChatInput, model); + } + + @Override + public void decorateWithAuth(HttpPost httpPost) { + httpPost.setHeader(HttpHeaders.AUTHORIZATION, AUTH_HEADER_VALUE); + } + } +}