From d429a3163ed56b7a8ed5c4b9d8bf9c8507a2058d Mon Sep 17 00:00:00 2001 From: lhoet Date: Tue, 29 Apr 2025 13:51:29 -0300 Subject: [PATCH 01/38] VertexAI chat completion response entity with tests --- ...eVertexAiChatCompletionResponseEntity.java | 214 ++++++++++++++++++ ...exAiChatCompletionResponseEntityTests.java | 151 ++++++++++++ 2 files changed, 365 insertions(+) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiChatCompletionResponseEntity.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiChatCompletionResponseEntityTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiChatCompletionResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiChatCompletionResponseEntity.java new file mode 100644 index 0000000000000..3022dde69c939 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiChatCompletionResponseEntity.java @@ -0,0 +1,214 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.googlevertexai.response; + +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.util.List; +import java.util.Optional; + +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; + +public class GoogleVertexAiChatCompletionResponseEntity { + + private static final ParseField CANDIDATES = new ParseField("candidates"); + private static final ParseField CONTENT = new ParseField("content"); + private static final ParseField PARTS = new ParseField("parts"); + private static final ParseField TEXT = new ParseField("text"); + + /** + * Parses the Google Vertex AI chat completion response. + * For a request like + * + *
+     * <pre><code>
+     *{
+     *   "contents": [
+     *     {
+     *       "role": "user",
+     *       "parts": [
+     *         {
+     *           "text": "Hello!"
+     *         }
+     *       ]
+     *     }
+     *   ],
+     *   "generationConfig": {
+     *     "responseModalities": [
+     *       "TEXT"
+     *     ],
+     *     "temperature": 1,
+     *     "maxOutputTokens": 8192,
+     *     "topP": 0.95
+     *   }
+     * }
+     *     </code>
+     * </pre>
+ * + * The response would look like: + * + *
+     * <pre><code>
+     *[{
+     *   "candidates": [
+     *     {
+     *       "content": {
+     *         "role": "model",
+     *         "parts": [
+     *           {
+     *             "text": "Hello there! How"
+     *           }
+     *         ]
+     *       }
+     *     }
+     *   ],
+     *   "usageMetadata": {
+     *     "trafficType": "ON_DEMAND"
+     *   },
+     *   "modelVersion": "gemini-2.0-flash-001",
+     *   "createTime": "2025-04-29T16:55:36.576032Z",
+     *   "responseId": "iAQRaKCUI_D7ld8Pq-aaaaa"
+     * }
+     * ,
+     * {
+     *   "candidates": [
+     *     {
+     *       "content": {
+     *         "role": "model",
+     *         "parts": [
+     *           {
+     *             "text": " can I help you today?\n"
+     *           }
+     *         ]
+     *       },
+     *       "finishReason": "STOP"
+     *     }
+     *   ],
+     *   "usageMetadata": {
+     *     "promptTokenCount": 2,
+     *     "candidatesTokenCount": 11,
+     *     "totalTokenCount": 13,
+     *     "trafficType": "ON_DEMAND",
+     *     "promptTokensDetails": [
+     *       {
+     *         "modality": "TEXT",
+     *         "tokenCount": 2
+     *       }
+     *     ],
+     *     "candidatesTokensDetails": [
+     *       {
+     *         "modality": "TEXT",
+     *         "tokenCount": 11
+     *       }
+     *     ]
+     *   },
+     *   "modelVersion": "gemini-2.0-flash-001",
+     *   "createTime": "2025-04-29T16:55:36.576032Z",
+     *   "responseId": "iAQRaKCUI_D7ld8Pq-aaaaa"
+     * }
+     * ]
+     *     </code>
+     * </pre>
+ */ + public static ChatCompletionResults fromResponse(Request request, HttpResult response) throws IOException { + try ( + XContentParser parser = XContentFactory.xContent(XContentType.JSON) + .createParser(XContentParserConfiguration.EMPTY, response.body()) + ) { + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.nextToken(), parser); + + StringBuilder fullText = new StringBuilder(); + + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + Chunk chunk = Chunk.PARSER.apply(parser, null); + chunk.extractText().ifPresent(fullText::append); + } + + return new ChatCompletionResults(List.of(new ChatCompletionResults.Result(fullText.toString()))); + } + } + + // --- Nested Records for Parsing --- + + public record Chunk(List candidates) { + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + Chunk.class.getSimpleName(), + true, // Ignore unknown fields in the chunk object + args -> new Chunk((List) args[0]) + ); + + static { + PARSER.declareObjectArray(constructorArg(), Candidate.PARSER::apply, CANDIDATES); + } + + public Optional extractText() { + return Optional.ofNullable(candidates) + .filter(list -> list.isEmpty() == false) + .map(List::getFirst) + .flatMap(Candidate::extractText); + } + } + + public record Candidate(Content content) { + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + Candidate.class.getSimpleName(), + true, + args -> new Candidate((Content) args[0]) + ); + + static { + PARSER.declareObject(constructorArg(), Content.PARSER::apply, CONTENT); + } + + public Optional extractText() { + return Optional.ofNullable(content).flatMap(Content::extractText); + } + } + + public record Content(List parts) { + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + 
Content.class.getSimpleName(), + true, + args -> new Content((List) args[0]) + ); + + static { + PARSER.declareObjectArray(constructorArg(), Part.PARSER::apply, PARTS); + } + + public Optional extractText() { + return Optional.ofNullable(parts).filter(list -> list.isEmpty() == false).map(List::getFirst).map(Part::text); + } + } + + public record Part(String text) { + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + Part.class.getSimpleName(), + true, + args -> new Part((String) args[0]) + ); + + static { + PARSER.declareString(optionalConstructorArg(), TEXT); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiChatCompletionResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiChatCompletionResponseEntityTests.java new file mode 100644 index 0000000000000..36a365170eb3a --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiChatCompletionResponseEntityTests.java @@ -0,0 +1,151 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.googlevertexai.response; + +import org.apache.http.HttpResponse; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class GoogleVertexAiChatCompletionResponseEntityTests extends ESTestCase { + public void testFromResponse_CreatesResultsForMultipleChunks() throws IOException { + String responseJson = """ + [ + { + "candidates": [ + { + "content": { + "role": "model", + "parts": [ { "text": "Hello " } ] + } + } + ] + }, + { + "candidates": [ + { + "content": { + "role": "model", + "parts": [ { "text": "World" } ] + }, + "finishReason": "STOP" + } + ], + "usageMetadata": { "promptTokenCount": 5, "candidatesTokenCount": 2, "totalTokenCount": 7 }, + "modelVersion": "gemini-2.0-flash-001", + "createTime": "2025-04-29T14:32:55.843480Z", + "responseId": "F-MQaNi9M7OKqsMPmo2aaaa" + } + ] + """; + + ChatCompletionResults chatCompletionResults = GoogleVertexAiChatCompletionResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat(chatCompletionResults.getResults().size(), is(1)); + assertThat(chatCompletionResults.getResults().getFirst().content(), is("Hello World")); + } + + public void testFromResponse_HandlesPartWithMissingText() throws IOException { + // Since text is optionalConstructorArg, missing text results in null, which is skipped by extractText + String responseJson = """ + [ + { + "candidates": [ + { + "content": { + "parts": [ { "not_text": "hello" } ] + } + } + ] + } + ] + """; + + 
ChatCompletionResults chatCompletionResults = GoogleVertexAiChatCompletionResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat(chatCompletionResults.getResults().size(), is(1)); + assertThat(chatCompletionResults.getResults().getFirst().content(), is("")); + } + + public void testFromResponse_FailsWhenChunkMissingCandidates() { + // Parser ignores unknown fields, but expects 'candidates' for the constructor + String responseJson = """ + [ + { + "not_candidates": [] + } + ] + """; + + var thrownException = expectThrows( + IllegalArgumentException.class, // ConstructingObjectParser throws this when required args are missing + () -> GoogleVertexAiChatCompletionResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + assertThat(thrownException.getMessage(), is("Required [candidates]")); + } + + public void testFromResponse_FailsWhenCandidateMissingContent() { + String responseJson = """ + [ + { + "candidates": [ + { "not_content": {} } + ] + } + ] + """; + + var thrownException = expectThrows( + IllegalArgumentException.class, + () -> GoogleVertexAiChatCompletionResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + assertThat(thrownException.getMessage(), containsString("[Chunk] failed to parse field [candidates]")); + } + + public void testFromResponse_FailsWhenContentMissingParts() { + String responseJson = """ + [ + { + "candidates": [ + { "content": { "not_parts": [] } } + ] + } + ] + """; + + var thrownException = expectThrows( + IllegalArgumentException.class, + () -> GoogleVertexAiChatCompletionResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + 
assertThat(thrownException.getMessage(), containsString("[Chunk] failed to parse field [candidates]")); + } + +} From 00bfdb09d9eabc3ab0232481191bc5bd5a2d2928 Mon Sep 17 00:00:00 2001 From: lhoet Date: Tue, 29 Apr 2025 14:41:03 -0300 Subject: [PATCH 02/38] Modified build gradle to include google vertexai sdk --- gradle/verification-metadata.xml | 10 ++++++++++ x-pack/plugin/inference/build.gradle | 3 +++ 2 files changed, 13 insertions(+) diff --git a/gradle/verification-metadata.xml b/gradle/verification-metadata.xml index d546e80d1a8a4..cb6d2f779f3ef 100644 --- a/gradle/verification-metadata.xml +++ b/gradle/verification-metadata.xml @@ -481,6 +481,11 @@ + + + + + @@ -561,6 +566,11 @@ + + + + + diff --git a/x-pack/plugin/inference/build.gradle b/x-pack/plugin/inference/build.gradle index b0657968f00fc..dc0066634713d 100644 --- a/x-pack/plugin/inference/build.gradle +++ b/x-pack/plugin/inference/build.gradle @@ -107,6 +107,9 @@ dependencies { /* SLF4J (via AWS SDKv2) */ api "org.slf4j:slf4j-api:${versions.slf4j}" runtimeOnly "org.slf4j:slf4j-nop:${versions.slf4j}" + /* Google aiplatform SDK */ + implementation 'com.google.cloud:google-cloud-aiplatform:3.61.0' + api "com.google.api:gax:2.64.2" } tasks.named("dependencyLicenses").configure { From 23782704b90c838440c3f92a76b5e37817235cae Mon Sep 17 00:00:00 2001 From: lhoet Date: Tue, 29 Apr 2025 15:12:29 -0300 Subject: [PATCH 03/38] Google vertex ai chat completion model with tests --- .../GoogleVertexAiChatCompletionModel.java | 141 ++++++++++++++++++ ...oogleVertexAiChatCompletionModelTests.java | 131 ++++++++++++++++ 2 files changed, 272 insertions(+) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModelTests.java diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java new file mode 100644 index 0000000000000..dc436fedca35d --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java @@ -0,0 +1,141 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googlevertexai.completion; + +import org.apache.http.client.utils.URIBuilder; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiModel; +import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings; +import org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionVisitor; +import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUtils; +import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleDiscoveryEngineRateLimitServiceSettings; + +import java.net.URISyntaxException; +import 
java.util.Map; +import java.net.URI; +import java.util.Objects; + +import static org.elasticsearch.core.Strings.format; + +public class GoogleVertexAiChatCompletionModel extends GoogleVertexAiModel { + public GoogleVertexAiChatCompletionModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettings, + Map taskSettings, + Map secrets, + ConfigurationParseContext context + ) { + this( + inferenceEntityId, + taskType, + service, + GoogleVertexAiChatCompletionServiceSettings.fromMap(serviceSettings, context), + GoogleVertexAiChatCompletionTaskSettings.fromMap(taskSettings), + GoogleVertexAiSecretSettings.fromMap(secrets) + ); + } + + GoogleVertexAiChatCompletionModel( + String inferenceEntityId, + TaskType taskType, + String service, + GoogleVertexAiChatCompletionServiceSettings serviceSettings, + GoogleVertexAiChatCompletionTaskSettings taskSettings, + @Nullable GoogleVertexAiSecretSettings secrets + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), + new ModelSecrets(secrets), + serviceSettings + ); + try { + this.uri = buildUri(serviceSettings.location(), serviceSettings.projectId(), serviceSettings.modelId()); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + public static GoogleVertexAiChatCompletionModel of(GoogleVertexAiChatCompletionModel model, UnifiedCompletionRequest request) { + var originalModelServiceSettings = model.getServiceSettings(); + + var newServiceSettings = new GoogleVertexAiChatCompletionServiceSettings( + originalModelServiceSettings.projectId(), + originalModelServiceSettings.location(), + Objects.requireNonNullElse(request.model(), originalModelServiceSettings.modelId()), + originalModelServiceSettings.rateLimitSettings() + ); + + return new GoogleVertexAiChatCompletionModel( + model.getInferenceEntityId(), + model.getTaskType(), + model.getConfigurations().getService(), + newServiceSettings, + model.getTaskSettings(), + 
model.getSecretSettings() + ); + } + + public GoogleVertexAiChatCompletionModel( + ModelConfigurations configurations, + ModelSecrets secrets, + GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings + ) { + super(configurations, secrets, rateLimitServiceSettings); + } + + @Override + public ExecutableAction accept(GoogleVertexAiActionVisitor visitor, Map taskSettings) { + return visitor.create(this, taskSettings); + } + + @Override + public GoogleDiscoveryEngineRateLimitServiceSettings rateLimitServiceSettings() { + return (GoogleDiscoveryEngineRateLimitServiceSettings) super.rateLimitServiceSettings(); + } + + @Override + public GoogleVertexAiChatCompletionServiceSettings getServiceSettings() { + return (GoogleVertexAiChatCompletionServiceSettings) super.getServiceSettings(); + } + + @Override + public GoogleVertexAiChatCompletionTaskSettings getTaskSettings() { + return (GoogleVertexAiChatCompletionTaskSettings) super.getTaskSettings(); + } + + @Override + public GoogleVertexAiSecretSettings getSecretSettings() { + return (GoogleVertexAiSecretSettings) super.getSecretSettings(); + } + + public static URI buildUri(String location, String projectId, String model) throws URISyntaxException { + return new URIBuilder().setScheme("https") + .setHost(format("%s%s", location, GoogleVertexAiUtils.GOOGLE_VERTEX_AI_HOST_SUFFIX)) + .setPathSegments( + GoogleVertexAiUtils.V1, + GoogleVertexAiUtils.PROJECTS, + projectId, + GoogleVertexAiUtils.LOCATIONS, + GoogleVertexAiUtils.GLOBAL, + GoogleVertexAiUtils.PUBLISHERS, + GoogleVertexAiUtils.PUBLISHER_GOOGLE, + GoogleVertexAiUtils.MODELS, + format("%s:%s", model, GoogleVertexAiUtils.STREAM_GENERATE_CONTENT) + ) + .build(); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModelTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModelTests.java new file mode 100644 index 0000000000000..a4950911a804e --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModelTests.java @@ -0,0 +1,131 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googlevertexai.completion; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.sameInstance; + +public class GoogleVertexAiChatCompletionModelTests extends ESTestCase { + + private static final String DEFAULT_PROJECT_ID = "test-project"; + private static final String DEFAULT_LOCATION = "us-central1"; + private static final String DEFAULT_MODEL_ID = "gemini-pro"; + private static final String DEFAULT_API_KEY = "test-api-key"; + private static final RateLimitSettings DEFAULT_RATE_LIMIT = new RateLimitSettings(100); + + public void testOverrideWith_UnifiedCompletionRequest_OverridesModelId() { + var model = createCompletionModel(DEFAULT_PROJECT_ID, DEFAULT_LOCATION, DEFAULT_MODEL_ID, 
DEFAULT_API_KEY, DEFAULT_RATE_LIMIT); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)), + "gemini-flash", + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = GoogleVertexAiChatCompletionModel.of(model, request); + + // Check that the model ID is overridden + assertThat(overriddenModel.getServiceSettings().modelId(), is("gemini-flash")); + + // Check that other settings remain the same + assertThat(overriddenModel, not(sameInstance(model))); + assertThat(overriddenModel.getServiceSettings().projectId(), is(DEFAULT_PROJECT_ID)); + assertThat(overriddenModel.getServiceSettings().location(), is(DEFAULT_LOCATION)); + assertThat(overriddenModel.getServiceSettings().rateLimitSettings(), is(DEFAULT_RATE_LIMIT)); + assertThat(overriddenModel.getSecretSettings().serviceAccountJson(), equalTo(new SecureString(DEFAULT_API_KEY.toCharArray()))); + assertThat(overriddenModel.getTaskSettings(), is(model.getTaskSettings())); + } + + public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenRequestDoesNotOverride() { + var model = createCompletionModel(DEFAULT_PROJECT_ID, DEFAULT_LOCATION, DEFAULT_MODEL_ID, DEFAULT_API_KEY, DEFAULT_RATE_LIMIT); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)), + null, // Not overriding model + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = GoogleVertexAiChatCompletionModel.of(model, request); + + // Check that the model ID is NOT overridden + assertThat(overriddenModel.getServiceSettings().modelId(), is(DEFAULT_MODEL_ID)); + // Check that other settings remain the same + assertThat(overriddenModel.getServiceSettings().projectId(), is(DEFAULT_PROJECT_ID)); + assertThat(overriddenModel.getServiceSettings().location(), is(DEFAULT_LOCATION)); + 
assertThat(overriddenModel.getServiceSettings().rateLimitSettings(), is(DEFAULT_RATE_LIMIT)); + assertThat(overriddenModel.getSecretSettings().serviceAccountJson(), equalTo(new SecureString(DEFAULT_API_KEY.toCharArray()))); + assertThat(overriddenModel.getTaskSettings(), is(model.getTaskSettings())); // Task settings shouldn't change + // Since nothing changed in service settings, it *could* return the same instance, + // but the current implementation always creates a new one. Let's assert it's not the same. + assertThat(overriddenModel, not(sameInstance(model))); + } + + public void testBuildUri() throws URISyntaxException { + String location = "us-east1"; + String projectId = "my-gcp-project"; + String model = "gemini-1.5-flash-001"; + URI expectedUri = new URI( + "https://us-east1-aiplatform.googleapis.com/v1/projects/my-gcp-project/locations/global/publishers/google/models/gemini-1.5-flash-001:streamGenerateContent" + ); + URI actualUri = GoogleVertexAiChatCompletionModel.buildUri(location, projectId, model); + assertThat(actualUri, is(expectedUri)); + } + + public void testBuildUri_WithDifferentValues() throws URISyntaxException { + String location = "europe-west2"; + String projectId = "another-project-123"; + String model = "gemini-pro"; + URI expectedUri = new URI( + "https://europe-west2-aiplatform.googleapis.com/v1/projects/another-project-123/locations/global/publishers/google/models/gemini-pro:streamGenerateContent" + ); + URI actualUri = GoogleVertexAiChatCompletionModel.buildUri(location, projectId, model); + assertThat(actualUri, is(expectedUri)); + } + + public static GoogleVertexAiChatCompletionModel createCompletionModel( + String projectId, + String location, + String modelId, + String apiKey, + RateLimitSettings rateLimitSettings + ) { + return new GoogleVertexAiChatCompletionModel( + "google-vertex-ai-chat-test-id", + TaskType.CHAT_COMPLETION, + "google_vertex_ai", + new GoogleVertexAiChatCompletionServiceSettings(projectId, location, modelId, 
rateLimitSettings), + new GoogleVertexAiChatCompletionTaskSettings(), + new GoogleVertexAiSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static URI buildDefaultUri() throws URISyntaxException { + return GoogleVertexAiChatCompletionModel.buildUri(DEFAULT_LOCATION, DEFAULT_PROJECT_ID, DEFAULT_MODEL_ID); + } +} From 1f0097466e879cb4cf3e625821833798df3c4315 Mon Sep 17 00:00:00 2001 From: lhoet Date: Wed, 30 Apr 2025 11:20:52 -0300 Subject: [PATCH 04/38] Google vertex ai chat completion request with tests --- ...eVertexAiUnifiedChatCompletionRequest.java | 74 +++++++++++ ...xAiUnifiedChatCompletionRequestEntity.java | 121 ++++++++++++++++++ ...gleVertexAiChatCompletionRequestTests.java | 103 +++++++++++++++ 3 files changed, 298 insertions(+) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequest.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiChatCompletionRequestTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequest.java new file mode 100644 index 0000000000000..48b2a5a4e0683 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequest.java @@ -0,0 +1,74 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googlevertexai.request; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.Objects; + +public class GoogleVertexAiUnifiedChatCompletionRequest implements GoogleVertexAiRequest { + + private final GoogleVertexAiChatCompletionModel model; + private final UnifiedChatInput unifiedChatInput; + + public GoogleVertexAiUnifiedChatCompletionRequest(UnifiedChatInput unifiedChatInput, GoogleVertexAiChatCompletionModel model) { + this.model = Objects.requireNonNull(model); + this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput); + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(model.uri()); + + var requestEntity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + ByteArrayEntity byteEntity = new ByteArrayEntity(Strings.toString(requestEntity).getBytes(StandardCharsets.UTF_8)); + httpPost.setEntity(byteEntity); + + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); + + decorateWithAuth(httpPost); + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + public void decorateWithAuth(HttpPost httpPost) { + GoogleVertexAiRequest.decorateWithBearerToken(httpPost, model.getSecretSettings()); + } 
+ + @Override + public URI getURI() { + return model.uri(); + } + + @Override + public Request truncate() { + // No truncation for Google VertexAI Chat completions + return this; + } + + @Override + public boolean[] getTruncationInfo() { + // No truncation for Google VertexAI Chat completions + return null; + } + + @Override + public String getInferenceEntityId() { + return model.getInferenceEntityId(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java new file mode 100644 index 0000000000000..ddaa8ca91b4c4 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java @@ -0,0 +1,121 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.googlevertexai.request; + +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.core.Strings.format; + +public class GoogleVertexAiUnifiedChatCompletionRequestEntity implements ToXContentObject { + // Field names matching the Google Vertex AI API structure + private static final String CONTENTS = "contents"; + private static final String ROLE = "role"; + private static final String PARTS = "parts"; + private static final String TEXT = "text"; + private static final String GENERATION_CONFIG = "generationConfig"; + private static final String TEMPERATURE = "temperature"; + private static final String MAX_OUTPUT_TOKENS = "maxOutputTokens"; + private static final String TOP_P = "topP"; + // TODO: Add other generationConfig fields if needed (e.g., stopSequences, topK) + + private final UnifiedChatInput unifiedChatInput; + private final GoogleVertexAiChatCompletionModel model; + + private static final String USER_ROLE = "user"; + private static final String MODEL_ROLE = "model"; + private static final String STOP_SEQUENCES = "stopSequences"; + + public GoogleVertexAiUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, GoogleVertexAiChatCompletionModel model) { + this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput); + this.model = Objects.requireNonNull(model); // Keep the model reference + } + + private String messageRoleToGoogleVertexAiSupportedRole(String messageRole) throws IOException { + var messageRoleLowered = messageRole.toLowerCase(); + + if (messageRoleLowered.equals(USER_ROLE) || 
messageRoleLowered.equals(MODEL_ROLE)) { + return messageRoleLowered; + } + + // TODO: Is it OK to throw an IOException here? + throw new IOException( + format( + "Role %s not supported by Google VertexAI ChatCompletion. Supported roles: '%s', '%s'", + messageRole, + USER_ROLE, + MODEL_ROLE + ) + ); + + } + + private void buildContents(XContentBuilder builder) throws IOException { + var messages = unifiedChatInput.getRequest().messages(); + + builder.startArray(CONTENTS); + for (UnifiedCompletionRequest.Message message : messages) { + builder.startObject(); + builder.field(ROLE, messageRoleToGoogleVertexAiSupportedRole(message.role())); + builder.startArray(PARTS); + builder.startObject(); + builder.field(TEXT, message.content().toString()); + builder.endObject(); + builder.endArray(); + builder.endObject(); + } + builder.endArray(); + } + + private void buildGenerationConfig(XContentBuilder builder) throws IOException { + var request = unifiedChatInput.getRequest(); + + boolean hasAnyConfig = request.stop() != null + || request.temperature() != null + || request.maxCompletionTokens() != null + || request.topP() != null; + + if (hasAnyConfig == false) { + return; + } + + builder.startObject(GENERATION_CONFIG); + + if (request.stop() != null) { + builder.stringListField(STOP_SEQUENCES, request.stop()); + } + if (request.temperature() != null) { + builder.field(TEMPERATURE, request.temperature()); + } + if (request.maxCompletionTokens() != null) { + builder.field(MAX_OUTPUT_TOKENS, request.maxCompletionTokens()); + } + if (request.topP() != null) { + builder.field(TOP_P, request.topP()); + } + + builder.endObject(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + buildContents(builder); + buildGenerationConfig(builder); + + builder.endObject(); + return builder; + } +} diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiChatCompletionRequestTests.java new file mode 100644 index 0000000000000..75a2eafa99b95 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiChatCompletionRequestTests.java @@ -0,0 +1,103 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googlevertexai.request; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel; +import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModelTests; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.net.URI; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; + +public class GoogleVertexAiChatCompletionRequestTests extends ESTestCase { + + private static final String AUTH_HEADER_VALUE = "Bearer foo"; + + // 
TODO: add more tests here to check the generation configuration, different role models, etc + + public void testCreateRequest_Default() throws IOException { + var modelId = "gemini-pro"; + var projectId = "test-project"; + var location = "us-central1"; + + var messages = List.of("Hello Gemini!"); + + var request = createRequest(projectId, location, modelId, messages, null, null); + var httpRequest = request.createHttpRequest(); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var uri = URI.create( + String.format( + "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:streamGenerateContent", + location, + projectId, + modelId + ) + ); + + assertThat(httpPost.getURI(), equalTo(uri)); + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE)); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(1)); + assertThat( + requestMap, + equalTo(Map.of("contents", List.of(Map.of("role", "user", "parts", List.of(Map.of("text", messages.getFirst())))))) + ); + + } + + private static GoogleVertexAiUnifiedChatCompletionRequest createRequest( + String projectId, + String location, + String modelId, + List messages, + @Nullable String apiKey, + @Nullable RateLimitSettings rateLimitSettings + ) { + var model = GoogleVertexAiChatCompletionModelTests.createCompletionModel( + projectId, + location, + modelId, + Objects.requireNonNullElse(apiKey, "default-api-key"), + Objects.requireNonNullElse(rateLimitSettings, new RateLimitSettings(100)) + ); + var unifiedChatInput = new UnifiedChatInput(messages, "user", true); + + return new GoogleVertexAiUnifiedChatCompletionWithoutAuthRequest(unifiedChatInput, model); + } + + /** + * We use this class to fake the auth implementation to avoid static mocking of {@link GoogleVertexAiRequest} 
+ */ + private static class GoogleVertexAiUnifiedChatCompletionWithoutAuthRequest extends GoogleVertexAiUnifiedChatCompletionRequest { + GoogleVertexAiUnifiedChatCompletionWithoutAuthRequest(UnifiedChatInput unifiedChatInput, GoogleVertexAiChatCompletionModel model) { + super(unifiedChatInput, model); + } + + @Override + public void decorateWithAuth(HttpPost httpPost) { + httpPost.setHeader(HttpHeaders.AUTHORIZATION, AUTH_HEADER_VALUE); + } + } +} From 970ab3c988a25698e8930e59595a771a0b7ee496 Mon Sep 17 00:00:00 2001 From: lhoet Date: Wed, 30 Apr 2025 16:50:17 -0300 Subject: [PATCH 05/38] TransportVersion --- server/src/main/java/org/elasticsearch/TransportVersions.java | 1 + 1 file changed, 1 insertion(+) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index eaf3549bc83b1..bd6cf692c64f4 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -228,6 +228,7 @@ static TransportVersion def(int id) { public static final TransportVersion DENSE_VECTOR_OFF_HEAP_STATS = def(9_062_00_0); public static final TransportVersion RANDOM_SAMPLER_QUERY_BUILDER = def(9_063_0_00); public static final TransportVersion SETTINGS_IN_DATA_STREAMS = def(9_064_0_00); + public static final TransportVersion ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED = def(9_065_0_00); /* * STOP! READ THIS FIRST! 
No, really, From 54280746343839cb905c4b761963310ff42d40ea Mon Sep 17 00:00:00 2001 From: lhoet Date: Wed, 30 Apr 2025 16:53:26 -0300 Subject: [PATCH 06/38] ChatCompletion TaskSettings & ServiceSettings --- ...VertexAiChatCompletionServiceSettings.java | 158 ++++++++++++++++++ ...gleVertexAiChatCompletionTaskSettings.java | 56 +++++++ 2 files changed, 214 insertions(+) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionServiceSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionTaskSettings.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionServiceSettings.java new file mode 100644 index 0000000000000..6f733ea71ef50 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionServiceSettings.java @@ -0,0 +1,158 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.googlevertexai.completion; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiService; +import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleDiscoveryEngineRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.PROJECT_ID; +import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.LOCATION; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; + +public class GoogleVertexAiChatCompletionServiceSettings extends FilteredXContentObject + implements + ServiceSettings, + GoogleDiscoveryEngineRateLimitServiceSettings { + + public static final String NAME = "google_vertex_ai_chatcompletion_service_settings"; + + // TODO: Other fields can be missing here. 
Mostly the ones that are described here + // https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.endpoints/streamGenerateContent?_gl=1*15nhtzk*_up*MQ..&gclid=CjwKCAjwwqfABhBcEiwAZJjC3uBQNP9KUMZX8AGXvFXP2rIEQSfCX9RLP5gjzx5r-4xz1daBSxM7GBoCY64QAvD_BwE&gclsrc=aw.ds + private final String location; + private final String modelId; + private final String projectId; + + private final RateLimitSettings rateLimitSettings; + + // https://cloud.google.com/vertex-ai/docs/quotas#eval-quotas + // TODO: this may be wrong. Double check before submitting the PR + private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(2000); + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.field(PROJECT_ID, projectId); + builder.field(LOCATION, location); + builder.field(MODEL_ID, modelId); + rateLimitSettings.toXContent(builder, params); + return builder; + } + + public static GoogleVertexAiChatCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + // Extract required fields + String projectId = ServiceUtils.extractRequiredString(map, PROJECT_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + String location = ServiceUtils.extractRequiredString(map, LOCATION, ModelConfigurations.SERVICE_SETTINGS, validationException); + String modelId = ServiceUtils.extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + + // Extract rate limit settings + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + GoogleVertexAiService.NAME, + context + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new 
GoogleVertexAiChatCompletionServiceSettings(projectId, location, modelId, rateLimitSettings); + } + + public GoogleVertexAiChatCompletionServiceSettings( + String projectId, + String location, + String modelId, + @Nullable RateLimitSettings rateLimitSettings + ) { + this.projectId = projectId; + this.location = location; + this.modelId = modelId; + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + public String location() { + return location; + } + + @Override + public String modelId() { + return modelId; + } + + @Override + public String projectId() { + return projectId; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(projectId); + out.writeString(location); + out.writeString(modelId); + rateLimitSettings.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + toXContentFragmentOfExposedFields(builder, params); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + GoogleVertexAiChatCompletionServiceSettings that = (GoogleVertexAiChatCompletionServiceSettings) o; + return Objects.equals(location, that.location) + && Objects.equals(modelId, that.modelId) + && Objects.equals(projectId, that.projectId) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(location, modelId, projectId, rateLimitSettings); + } + + @Override + public RateLimitSettings rateLimitSettings() { + return rateLimitSettings; + } +} diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionTaskSettings.java new file mode 100644 index 0000000000000..059b7d5028583 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionTaskSettings.java @@ -0,0 +1,56 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googlevertexai.completion; + +import org.apache.commons.lang3.NotImplementedException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Map; + +// TODO: This class may not be needed. 
Keeping this class here to keep the compiler happy, but if not needed we could replace it with `EmptyTaskSettings` +public class GoogleVertexAiChatCompletionTaskSettings implements TaskSettings { + public static final String NAME = "google_vertex_ai_chatcompletion_task_settings"; + + @Override + public boolean isEmpty() { + return false; + } + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + return null; + } + + @Override + public String getWriteableName() { + return ""; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return null; + } + + public static GoogleVertexAiChatCompletionTaskSettings fromMap(Map map) { + return new GoogleVertexAiChatCompletionTaskSettings(); + } +} From ee44f22487568ec1701932cd0873771190aa85c2 Mon Sep 17 00:00:00 2001 From: lhoet Date: Wed, 30 Apr 2025 16:54:05 -0300 Subject: [PATCH 07/38] ChatCompletionRequestManager & tests --- ...oogleVertexAiCompletionRequestManager.java | 73 +++++++++++++++++++ ...gleVertexAiChatCompletionRequestTests.java | 9 ++- 2 files changed, 81 insertions(+), 1 deletion(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiCompletionRequestManager.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiCompletionRequestManager.java new file mode 100644 index 0000000000000..d2fbaddd1db2b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiCompletionRequestManager.java @@ -0,0 +1,73 @@ +/* 
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googlevertexai; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel; +import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUnifiedChatCompletionRequest; +import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiChatCompletionResponseEntity; + +import java.util.Objects; +import java.util.function.Supplier; + +public class GoogleVertexAiCompletionRequestManager extends GoogleVertexAiRequestManager { + + private static final Logger logger = LogManager.getLogger(GoogleVertexAiCompletionRequestManager.class); + + private static final ResponseHandler HANDLER = createGoogleVertexAiResponseHandler(); + + private static ResponseHandler createGoogleVertexAiResponseHandler() { + return new GoogleVertexAiResponseHandler( + "Google Vertex AI chat completion", + GoogleVertexAiChatCompletionResponseEntity::fromResponse + ); + } + + private final GoogleVertexAiChatCompletionModel model; + + private 
GoogleVertexAiCompletionRequestManager(GoogleVertexAiChatCompletionModel model, ThreadPool threadPool) { + super(threadPool, model, RateLimitGrouping.of(model)); + this.model = model; + } + + record RateLimitGrouping(int projectIdHash) { + public static RateLimitGrouping of(GoogleVertexAiChatCompletionModel model) { + Objects.requireNonNull(model); + return new RateLimitGrouping(model.rateLimitServiceSettings().projectId().hashCode()); + } + } + + public static GoogleVertexAiCompletionRequestManager of(GoogleVertexAiChatCompletionModel model, ThreadPool threadPool) { + Objects.requireNonNull(model); + Objects.requireNonNull(threadPool); + + return new GoogleVertexAiCompletionRequestManager(model, threadPool); + } + + @Override + public void execute( + InferenceInputs inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + + var chatInputs = (UnifiedChatInput) inferenceInputs; + var request = new GoogleVertexAiUnifiedChatCompletionRequest(chatInputs, model); + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiChatCompletionRequestTests.java index 75a2eafa99b95..338b827a43920 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiChatCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiChatCompletionRequestTests.java @@ -67,7 +67,14 @@ public void testCreateRequest_Default() throws IOException { } - private static GoogleVertexAiUnifiedChatCompletionRequest createRequest( + 
public static GoogleVertexAiUnifiedChatCompletionRequest createRequest( + UnifiedChatInput input, + GoogleVertexAiChatCompletionModel model + ) { + return new GoogleVertexAiUnifiedChatCompletionWithoutAuthRequest(input, model); + } + + public static GoogleVertexAiUnifiedChatCompletionRequest createRequest( String projectId, String location, String modelId, From 8160c2bf8a90a1f9de14e70b793322c57a6ca213 Mon Sep 17 00:00:00 2001 From: lhoet Date: Wed, 30 Apr 2025 16:55:14 -0300 Subject: [PATCH 08/38] VertexAI Service and related classes. WIP & missing tests --- .../googlevertexai/GoogleVertexAiService.java | 45 +++++++++++++++++-- .../action/GoogleVertexAiActionCreator.java | 11 +++++ .../action/GoogleVertexAiActionVisitor.java | 3 ++ .../request/GoogleVertexAiUtils.java | 2 + 4 files changed, 58 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index e966ebc8d9e9b..a46880ce23b95 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -29,6 +29,8 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import 
org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; @@ -38,9 +40,11 @@ import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator; +import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel; +import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiChatCompletionResponseEntity; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import java.util.EnumSet; @@ -48,6 +52,7 @@ import java.util.List; import java.util.Map; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; @@ -55,17 +60,26 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.EMBEDDING_MAX_BATCH_SIZE; import 
static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.LOCATION; import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.PROJECT_ID; +import static org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator.COMPLETION_ERROR_PREFIX; public class GoogleVertexAiService extends SenderService { public static final String NAME = "googlevertexai"; private static final String SERVICE_NAME = "Google Vertex AI"; - private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK); + private static final EnumSet supportedTaskTypes = EnumSet.of( + TaskType.TEXT_EMBEDDING, + TaskType.RERANK, + TaskType.CHAT_COMPLETION + ); + + private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new GoogleVertexAiResponseHandler( + "google vertexai chat completion", + GoogleVertexAiChatCompletionResponseEntity::fromResponse + ); public static final EnumSet VALID_INPUT_TYPE_VALUES = EnumSet.of( InputType.INGEST, @@ -182,6 +196,7 @@ public EnumSet supportedTaskTypes() { return supportedTaskTypes; } + // TODO: Since we added a task type we need to change this? 
@Override public TransportVersion getMinimalSupportedVersion() { return TransportVersions.V_8_15_0; @@ -220,7 +235,20 @@ protected void doUnifiedCompletionInfer( TimeValue timeout, ActionListener listener ) { - throwUnsupportedUnifiedCompletionOperation(NAME); + if (model instanceof GoogleVertexAiChatCompletionModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + var chatCompletionModel = (GoogleVertexAiChatCompletionModel) model; + var updatedChatCompletionModel = GoogleVertexAiChatCompletionModel.of(chatCompletionModel, inputs.getRequest()); + + var manager = GoogleVertexAiCompletionRequestManager.of(updatedChatCompletionModel, getServiceComponents().threadPool()); + + var errorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX); + var action = new SenderExecutableAction(getSender(), manager, errorMessage); + action.execute(inputs, timeout, listener); + } @Override @@ -320,6 +348,17 @@ private static GoogleVertexAiModel createModel( secretSettings, context ); + + case CHAT_COMPLETION -> new GoogleVertexAiChatCompletionModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + taskSettings, + secretSettings, + context + ); + default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); }; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionCreator.java index 627580facee72..5f1a703496ced 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionCreator.java @@ -7,12 +7,15 @@ package 
org.elasticsearch.xpack.inference.services.googlevertexai.action; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiCompletionRequestManager; import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiEmbeddingsRequestManager; import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiRerankRequestManager; +import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel; import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel; @@ -23,6 +26,7 @@ public class GoogleVertexAiActionCreator implements GoogleVertexAiActionVisitor { + public static final String COMPLETION_ERROR_PREFIX = "Google VertexAI chat completion"; private final Sender sender; private final ServiceComponents serviceComponents; @@ -50,4 +54,11 @@ public ExecutableAction create(GoogleVertexAiRerankModel model, Map taskSettings) { + var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Google Vertex AI chat completion"); + var requestManager = GoogleVertexAiCompletionRequestManager.of(model, serviceComponents.threadPool()); + return new SenderExecutableAction(sender, requestManager, failedToSendRequestErrorMessage); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionVisitor.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionVisitor.java index 7ae0eaa9d8bfb..eaa71f2646efe 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionVisitor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionVisitor.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.services.googlevertexai.action; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel; import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel; @@ -18,4 +19,6 @@ public interface GoogleVertexAiActionVisitor { ExecutableAction create(GoogleVertexAiEmbeddingsModel model, Map taskSettings); ExecutableAction create(GoogleVertexAiRerankModel model, Map taskSettings); + + ExecutableAction create(GoogleVertexAiChatCompletionModel model, Map taskSettings); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUtils.java index 79335014007ac..8e635da0f3052 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUtils.java @@ -35,6 +35,8 @@ public final class GoogleVertexAiUtils { public static final String RANK = "rank"; + public static final String STREAM_GENERATE_CONTENT = 
"streamGenerateContent"; + private GoogleVertexAiUtils() {} } From ff68fbe9455f95b5b5268ce2245636d443896744 Mon Sep 17 00:00:00 2001 From: lhoet Date: Mon, 5 May 2025 10:31:17 -0300 Subject: [PATCH 09/38] VertexAi ChatCompletion task settings fix. --- ...oogleVertexAiChatCompletionTaskSettings.java | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionTaskSettings.java index 059b7d5028583..7f78d83198db7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionTaskSettings.java @@ -9,6 +9,7 @@ import org.apache.commons.lang3.NotImplementedException; import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.xcontent.XContentBuilder; @@ -22,32 +23,32 @@ public class GoogleVertexAiChatCompletionTaskSettings implements TaskSettings { @Override public boolean isEmpty() { - return false; + return true; } @Override public TaskSettings updatedTaskSettings(Map newSettings) { - return null; + return this; } @Override public String getWriteableName() { - return ""; + return NAME; } @Override public TransportVersion getMinimalSupportedVersion() { - return null; + return TransportVersions.ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED; } @Override - public void writeTo(StreamOutput out) throws IOException { - - } + public void writeTo(StreamOutput out) throws 
IOException {} @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - return null; + builder.startObject(); + builder.endObject(); + return builder; } public static GoogleVertexAiChatCompletionTaskSettings fromMap(Map map) { From 29c7093ccf8c3bb1f1dbb2383894215ccefb574b Mon Sep 17 00:00:00 2001 From: lhoet Date: Tue, 6 May 2025 16:06:37 -0300 Subject: [PATCH 10/38] JsonArrayParts event processor & parser --- .../streaming/JsonArrayPartsEventParser.java | 86 +++++ .../JsonArrayPartsEventProcessor.java | 40 +++ .../streaming/JsonPartsEventParserTests.java | 296 ++++++++++++++++++ 3 files changed, 422 insertions(+) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/JsonArrayPartsEventParser.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/JsonArrayPartsEventProcessor.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/JsonPartsEventParserTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/JsonArrayPartsEventParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/JsonArrayPartsEventParser.java new file mode 100644 index 0000000000000..f052a7591d1d7 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/JsonArrayPartsEventParser.java @@ -0,0 +1,86 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.response.streaming; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.ArrayDeque; +import java.util.Arrays; +import java.util.Deque; + +/** + * Parses a stream of bytes that form a JSON array, where each element of the array + * is a JSON object. This parser extracts each complete JSON object from the array + * and emits it as byte array. + * + * Example of an expected stream: + * Chunk 1: [{"key":"val1"} + * Chunk 2: ,{"key2":"val2"} + * Chunk 3: ,{"key3":"val3"}, {"some":"object"}] + * + * This parser would emit four byte arrays, with data: + * 1. {"key":"val1"} + * 2. {"key2":"val2"} + * 3. {"key3":"val3"} + * 4. {"some":"object"} + */ +public class JsonArrayPartsEventParser { + + // Buffer to hold bytes from the previous call if they formed an incomplete JSON object. + private final ByteArrayOutputStream incompletePart = new ByteArrayOutputStream(); + + public Deque parse(byte[] newBytes) { + if (newBytes == null || newBytes.length == 0) { + return new ArrayDeque<>(0); + } + + ByteArrayOutputStream currentStream = new ByteArrayOutputStream(); + try { + currentStream.write(incompletePart.toByteArray()); + currentStream.write(newBytes); + } catch (IOException e) { + throw new UncheckedIOException("Error handling byte array streams", e); + } + incompletePart.reset(); + + byte[] dataToProcess = currentStream.toByteArray(); + return parseInternal(dataToProcess); + } + + private Deque parseInternal(byte[] data) { + int localBraceLevel = 0; + int objectStartIndex = -1; + Deque completedObjects = new ArrayDeque<>(); + + for (int i = 0; i < data.length; i++) { + char c = (char) data[i]; + + if (c == '{') { + if (localBraceLevel == 0) { + objectStartIndex = i; + } + localBraceLevel++; + } else if (c == '}') { + if (localBraceLevel > 0) { + localBraceLevel--; + if (localBraceLevel == 0) { + byte[] jsonObject = Arrays.copyOfRange(data, 
objectStartIndex, i + 1); + completedObjects.offer(jsonObject); + objectStartIndex = -1; + } + } + } + } + + if (localBraceLevel > 0) { + incompletePart.write(data, objectStartIndex, data.length - objectStartIndex); + } + return completedObjects; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/JsonArrayPartsEventProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/JsonArrayPartsEventProcessor.java new file mode 100644 index 0000000000000..f66f415dbb266 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/JsonArrayPartsEventProcessor.java @@ -0,0 +1,40 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.response.streaming; + +import org.elasticsearch.xpack.inference.common.DelegatingProcessor; +import org.elasticsearch.xpack.inference.external.http.HttpResult; + +import java.util.Deque; + +public class JsonArrayPartsEventProcessor extends DelegatingProcessor> { + private final JsonArrayPartsEventParser jsonArrayPartsEventParser; + + // TODO: This class is missing unit testing + public JsonArrayPartsEventProcessor(JsonArrayPartsEventParser jsonArrayPartsEventParser) { + this.jsonArrayPartsEventParser = jsonArrayPartsEventParser; + } + + @Override + public void next(HttpResult item) { + if (item.isBodyEmpty()) { + // discard empty result and go to the next + upstream().request(1); + return; + } + + var response = jsonArrayPartsEventParser.parse(item.body()); + if (response.isEmpty()) { + // discard empty result and go to the next + upstream().request(1); + return; + } + + downstream().onNext(response); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/JsonPartsEventParserTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/JsonPartsEventParserTests.java new file mode 100644 index 0000000000000..43af0ca3c691f --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/JsonPartsEventParserTests.java @@ -0,0 +1,296 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.response.streaming; + +import org.elasticsearch.test.ESTestCase; + +import java.nio.charset.StandardCharsets; +import java.util.Deque; +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; + +public class JsonPartsEventParserTests extends ESTestCase { + + private void assertJsonParts(Deque actualParts, List expectedJsonStrings) { + assertThat("Number of parsed parts mismatch", actualParts.size(), equalTo(expectedJsonStrings.size())); + var expectedIter = expectedJsonStrings.iterator(); + actualParts.forEach(part -> { + String actualJsonString = new String(part, StandardCharsets.UTF_8); + assertThat(actualJsonString, equalTo(expectedIter.next())); + }); + } + + public void testParse_givenNullOrEmptyBytes_returnsEmptyDeque() { + JsonArrayPartsEventParser parser = new JsonArrayPartsEventParser(); + assertTrue(parser.parse(null).isEmpty()); + assertTrue(parser.parse(new byte[0]).isEmpty()); + + // Test with pre-existing incomplete part + parser.parse("{".getBytes(StandardCharsets.UTF_8)); // Create an incomplete part + assertTrue(parser.parse(null).isEmpty()); + assertTrue(parser.parse(new byte[0]).isEmpty()); + // Check that the incomplete part is still there + Deque parts = parser.parse("}".getBytes(StandardCharsets.UTF_8)); + assertJsonParts(parts, List.of("{}")); + } + + public void testParse_singleCompleteObject_returnsOnePart() { + JsonArrayPartsEventParser parser = new JsonArrayPartsEventParser(); + String json = "{\"key\":\"value\"}"; + byte[] input = json.getBytes(StandardCharsets.UTF_8); + Deque parts = parser.parse(input); + assertJsonParts(parts, List.of(json)); + } + + public void testParse_multipleCompleteObjectsInOneChunk_returnsMultipleParts() { + JsonArrayPartsEventParser parser = new JsonArrayPartsEventParser(); + String json1 = "{\"key1\":\"value1\"}"; + String json2 = "{\"key2\":\"value2\"}"; + // Simulating a JSON array structure, the parser extracts {} + byte[] input = 
("[" + json1 + "," + json2 + "]").getBytes(StandardCharsets.UTF_8); + Deque parts = parser.parse(input); + assertJsonParts(parts, List.of(json1, json2)); + } + + public void testParse_twoObjectsBackToBack_extractsBoth() { + JsonArrayPartsEventParser parser = new JsonArrayPartsEventParser(); + String json1 = "{\"a\":1}"; + String json2 = "{\"b\":2}"; + byte[] input = (json1 + json2).getBytes(StandardCharsets.UTF_8); + Deque parts = parser.parse(input); + assertJsonParts(parts, List.of(json1, json2)); + } + + public void testParse_objectSplitAcrossChunks_returnsOnePartAfterAllChunks() { + JsonArrayPartsEventParser parser = new JsonArrayPartsEventParser(); + String json = "{\"key\":\"very_long_value\"}"; + byte[] chunk1 = "{\"key\":\"very_long".getBytes(StandardCharsets.UTF_8); + byte[] chunk2 = "_value\"}".getBytes(StandardCharsets.UTF_8); + + Deque parts1 = parser.parse(chunk1); + assertTrue("Expected no parts from incomplete chunk", parts1.isEmpty()); + + Deque parts2 = parser.parse(chunk2); + assertJsonParts(parts2, List.of(json)); + } + + public void testParse_multipleObjectsSomeSplit_returnsPartsIncrementally() { + JsonArrayPartsEventParser parser = new JsonArrayPartsEventParser(); + String json1 = "{\"id\":1,\"name\":\"first\"}"; + String json2 = "{\"id\":2,\"name\":\"second_is_longer\"}"; + String json3 = "{\"id\":3,\"name\":\"third\"}"; + + // Chunk 1: [{"id":1,"name":"first"},{"id":2,"name":"sec + byte[] chunk1 = ("[" + json1 + ",{\"id\":2,\"name\":\"sec").getBytes(StandardCharsets.UTF_8); + Deque parts1 = parser.parse(chunk1); + assertJsonParts(parts1, List.of(json1)); + + // Chunk 2: ond_is_longer"},{"id":3,"name":"third"}] + byte[] chunk2 = ("ond_is_longer\"}," + json3 + "]").getBytes(StandardCharsets.UTF_8); + Deque parts2 = parser.parse(chunk2); + assertJsonParts(parts2, List.of(json2, json3)); + + assertTrue("Expected no more parts from empty call", parser.parse(new byte[0]).isEmpty()); + } + + public void 
testParse_withArrayBracketsAndCommas_extractsObjects() { + JsonArrayPartsEventParser parser = new JsonArrayPartsEventParser(); + String json1 = "{\"a\":1}"; + String json2 = "{\"b\":2}"; + byte[] input = (" [ " + json1 + " , " + json2 + " ] ").getBytes(StandardCharsets.UTF_8); + Deque parts = parser.parse(input); + assertJsonParts(parts, List.of(json1, json2)); + } + + public void testParse_nestedObjects_extractsTopLevelObject() { + JsonArrayPartsEventParser parser = new JsonArrayPartsEventParser(); + String json = "{\"outer_key\":{\"inner_key\":\"value\"},\"another_key\":\"val\"}"; + byte[] input = json.getBytes(StandardCharsets.UTF_8); + Deque parts = parser.parse(input); + assertJsonParts(parts, List.of(json)); + } + + public void testParse_nestedObjectSplit_extractsTopLevelObject() { + JsonArrayPartsEventParser parser = new JsonArrayPartsEventParser(); + String json = "{\"outer_key\":{\"inner_key\":\"value\"},\"another_key\":\"val\"}"; + byte[] chunk1 = "{\"outer_key\":{\"inner_key\":\"val".getBytes(StandardCharsets.UTF_8); + byte[] chunk2 = "ue\"},\"another_key\":\"val\"}".getBytes(StandardCharsets.UTF_8); + + Deque parts1 = parser.parse(chunk1); + assertTrue(parts1.isEmpty()); + + Deque parts2 = parser.parse(chunk2); + assertJsonParts(parts2, List.of(json)); + } + + public void testParse_endsWithIncompleteObject_buffersCorrectly() { + JsonArrayPartsEventParser parser = new JsonArrayPartsEventParser(); + String json1 = "{\"complete\":\"done\"}"; + String partialJsonStart = "{\"incomplete_start\":\""; + + byte[] input = (json1 + "," + partialJsonStart).getBytes(StandardCharsets.UTF_8); + Deque parts = parser.parse(input); + assertJsonParts(parts, List.of(json1)); // Only the complete one + + String partialJsonEnd = "continue\"}"; + String json2 = partialJsonStart + partialJsonEnd; + byte[] nextChunk = partialJsonEnd.getBytes(StandardCharsets.UTF_8); + parts = parser.parse(nextChunk); + assertJsonParts(parts, List.of(json2)); + } + + public void 
testParse_onlyOpenBrace_buffers() { + JsonArrayPartsEventParser parser = new JsonArrayPartsEventParser(); + byte[] input = "{".getBytes(StandardCharsets.UTF_8); + Deque parts = parser.parse(input); + assertTrue(parts.isEmpty()); + + byte[] nextInput = "\"key\":\"val\"}".getBytes(StandardCharsets.UTF_8); + parts = parser.parse(nextInput); + assertJsonParts(parts, List.of("{\"key\":\"val\"}")); + } + + public void testParse_onlyCloseBrace_ignored() { + JsonArrayPartsEventParser parser = new JsonArrayPartsEventParser(); + byte[] input = "}".getBytes(StandardCharsets.UTF_8); + Deque parts = parser.parse(input); + assertTrue(parts.isEmpty()); // Should be ignored as no open brace context + + // With preceding data + parts = parser.parse("some data }".getBytes(StandardCharsets.UTF_8)); + assertTrue(parts.isEmpty()); + } + + public void testParse_mismatchedBraces_handlesGracefully() { + JsonArrayPartsEventParser parser = new JsonArrayPartsEventParser(); + // Extra closing brace + byte[] input1 = "{\"key\":\"val\"}}".getBytes(StandardCharsets.UTF_8); + Deque parts1 = parser.parse(input1); + assertJsonParts(parts1, List.of("{\"key\":\"val\"}")); // First object is fine, extra '}' ignored + + // Extra opening brace at end + parser = new JsonArrayPartsEventParser(); // reset + byte[] input2 = "{\"key\":\"val\"}{".getBytes(StandardCharsets.UTF_8); + Deque parts2 = parser.parse(input2); + assertJsonParts(parts2, List.of("{\"key\":\"val\"}")); // First object + // The last '{' should be buffered + Deque parts3 = parser.parse("}".getBytes(StandardCharsets.UTF_8)); + assertJsonParts(parts3, List.of("{}")); // Completes the buffered '{' + } + + public void testParse_objectWithMultiByteChars_handlesCorrectly() { + JsonArrayPartsEventParser parser = new JsonArrayPartsEventParser(); + String json = "{\"key\":\"value_with_emoji_😊_and_résumé\"}"; + byte[] input = json.getBytes(StandardCharsets.UTF_8); + Deque parts = parser.parse(input); + assertJsonParts(parts, List.of(json)); + + // 
Split case + parser = new JsonArrayPartsEventParser(); // reset + String part1Str = "{\"key\":\"value_with_emoji_😊"; // Split within multi-byte char or after + String part2Str = "_and_résumé\"}"; + byte[] chunk1 = part1Str.getBytes(StandardCharsets.UTF_8); + byte[] chunk2 = part2Str.getBytes(StandardCharsets.UTF_8); + + Deque parts1 = parser.parse(chunk1); + assertTrue(parts1.isEmpty()); + + Deque parts2 = parser.parse(chunk2); + assertJsonParts(parts2, List.of(json)); + } + + public void testParse_javadocExampleStream() { + JsonArrayPartsEventParser parser = new JsonArrayPartsEventParser(); + String json1 = "{\"key\":\"val1\"}"; + String json2 = "{\"key2\":\"val2\"}"; + String json3 = "{\"key3\":\"val3\"}"; + String json4 = "{\"some\":\"object\"}"; + + // Chunk 1: [{"key":"val1"} + Deque parts1 = parser.parse(("[{\"key\":\"val1\"}").getBytes(StandardCharsets.UTF_8)); + assertJsonParts(parts1, List.of(json1)); + + // Chunk 2: ,{"key2":"val2"} + Deque parts2 = parser.parse((",{\"key2\":\"val2\"}").getBytes(StandardCharsets.UTF_8)); + assertJsonParts(parts2, List.of(json2)); + + // Chunk 3: ,{"key3":"val3"}, {"some":"object"}] + Deque parts3 = parser.parse((",{\"key3\":\"val3\"}, {\"some\":\"object\"}]").getBytes(StandardCharsets.UTF_8)); + assertJsonParts(parts3, List.of(json3, json4)); + } + + public void testParse_emptyObjects() { + JsonArrayPartsEventParser parser = new JsonArrayPartsEventParser(); + String json1 = "{}"; + String json2 = "{\"a\":{}}"; + byte[] input = (json1 + " " + json2).getBytes(StandardCharsets.UTF_8); + Deque parts = parser.parse(input); + assertJsonParts(parts, List.of(json1, json2)); + } + + public void testParse_dataBeforeFirstObjectAndAfterLastObject() { + JsonArrayPartsEventParser parser = new JsonArrayPartsEventParser(); + String json1 = "{\"key1\":\"value1\"}"; + String json2 = "{\"key2\":\"value2\"}"; + byte[] input = ("leading_garbage" + json1 + "middle_garbage" + json2 + "trailing_garbage").getBytes(StandardCharsets.UTF_8); + Deque 
parts = parser.parse(input); + assertJsonParts(parts, List.of(json1, json2)); + } + + public void testParse_incompleteObjectNeverCompleted() { + JsonArrayPartsEventParser parser = new JsonArrayPartsEventParser(); + byte[] chunk1 = "{\"key\":".getBytes(StandardCharsets.UTF_8); + Deque parts1 = parser.parse(chunk1); + assertTrue(parts1.isEmpty()); + + // Send another chunk that doesn't complete the first object but starts a new one + byte[] chunk2 = "{\"anotherKey\":\"value\"}".getBytes(StandardCharsets.UTF_8); + Deque parts2 = parser.parse(chunk2); + // The incomplete "{\"key\":" is overwritten by the new complete object "{\"anotherKey\":\"value\"}" + // because objectStartIndex will be reset to the start of the new object. + // The previous incompletePart is combined, but if a new '{' is found at brace level 0, + // objectStartIndex is updated. The old incomplete part is effectively discarded if not completed. + // Let's trace: + // After chunk1: incompletePart = "{\"key\":" + // parse(chunk2): dataToProcess = "{\"key\":{\"anotherKey\":\"value\"}" + // incompletePart.reset() + // Loop: + // '{' -> objectStartIndex=0, braceLevel=1 + // ... + // ':' -> + // '{' -> objectStartIndex=7 (THIS IS THE KEY: if braceLevel is >0, objectStartIndex is NOT reset) + // So the outer object is still being tracked. 
+ // '}' -> braceLevel becomes 1 (for inner) + // '}' -> braceLevel becomes 0 (for outer) -> emits "{\"key\":{\"anotherKey\":\"value\"}}" + // This means the test case needs to be: + // Chunk1: {"key": + // Chunk2: "value"} , {"next":1} + // Expected: {"key":"value"}, {"next":1} + + // Corrected test for incomplete object handling: + parser = new JsonArrayPartsEventParser(); // Reset + parts1 = parser.parse("{\"key\":".getBytes(StandardCharsets.UTF_8)); + assertTrue(parts1.isEmpty()); + + Deque partsAfterCompletion = parser.parse("\"value\"}".getBytes(StandardCharsets.UTF_8)); + assertJsonParts(partsAfterCompletion, List.of("{\"key\":\"value\"}")); + + // If an incomplete part is followed by non-JSON or unrelated data + parser = new JsonArrayPartsEventParser(); // Reset + parts1 = parser.parse("{\"key\":".getBytes(StandardCharsets.UTF_8)); + assertTrue(parts1.isEmpty()); + // Send some data that doesn't complete it and doesn't start a new valid object + Deque partsNoCompletion = parser.parse("some other data without braces".getBytes(StandardCharsets.UTF_8)); + assertTrue(partsNoCompletion.isEmpty()); + // The incomplete part should still be "{\"key\":some other data without braces" + // Now complete it + Deque finalParts = parser.parse("}".getBytes(StandardCharsets.UTF_8)); + assertJsonParts(finalParts, List.of("{\"key\":some other data without braces}")); + } +} From bfd75b051970f81d8b07a77756254fb9617644cb Mon Sep 17 00:00:00 2001 From: lhoet Date: Tue, 6 May 2025 16:12:15 -0300 Subject: [PATCH 11/38] AI Service and service tests --- .../googlevertexai/GoogleVertexAiService.java | 14 ++--- .../GoogleVertexAiServiceTests.java | 58 ++++++++++++++++++- 2 files changed, 63 insertions(+), 9 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 
a46880ce23b95..8f67a615d2c0a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -30,7 +30,6 @@ import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; -import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; @@ -44,13 +43,13 @@ import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel; -import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiChatCompletionResponseEntity; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; @@ -76,10 +75,6 @@ public class GoogleVertexAiService extends SenderService { TaskType.CHAT_COMPLETION ); - private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new GoogleVertexAiResponseHandler( - "google vertexai chat completion", - 
GoogleVertexAiChatCompletionResponseEntity::fromResponse - ); public static final EnumSet VALID_INPUT_TYPE_VALUES = EnumSet.of( InputType.INGEST, @@ -90,6 +85,11 @@ public class GoogleVertexAiService extends SenderService { InputType.INTERNAL_SEARCH ); + @Override + public Set supportedStreamingTasks() { + return EnumSet.of(TaskType.CHAT_COMPLETION); + } + public GoogleVertexAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { super(factory, serviceComponents); } @@ -239,7 +239,6 @@ protected void doUnifiedCompletionInfer( listener.onFailure(createInvalidModelException(model)); return; } - var chatCompletionModel = (GoogleVertexAiChatCompletionModel) model; var updatedChatCompletionModel = GoogleVertexAiChatCompletionModel.of(chatCompletionModel, inputs.getRequest()); @@ -248,7 +247,6 @@ protected void doUnifiedCompletionInfer( var errorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX); var action = new SenderExecutableAction(getSender(), manager, errorMessage); action.execute(inputs, timeout, listener); - } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java index 932dfc21e9396..ca093b5cd27fb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java @@ -6,7 +6,6 @@ */ package org.elasticsearch.xpack.inference.services.googlevertexai; - import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.bytes.BytesArray; @@ -14,6 +13,7 @@ import org.elasticsearch.common.settings.Settings; import 
org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InputType; @@ -26,6 +26,7 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; @@ -41,10 +42,13 @@ import org.hamcrest.Matchers; import org.junit.After; import org.junit.Before; +import org.elasticsearch.action.support.PlainActionFuture; import java.io.IOException; import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.concurrent.TimeUnit; import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; @@ -63,6 +67,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase { private ThreadPool threadPool; private HttpClientManager clientManager; + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); @Before public void init() throws Exception { @@ -78,6 +83,57 @@ public void shutdown() throws IOException { webServer.close(); } + public void testUnifiedCompletionInfer_HappyPath() throws Exception { + // 1. Mock response (array of chunks) + String responseJson = """ + [ + { + "candidates": [ + { "content": { "role": "model", "parts": [ { "text": "This is " } ] } } + ] + }, + { + "candidates": [ + { "content": { "role": "model", "parts": [ { "text": "a test response." 
} ] } } + ], + "usageMetadata": { "promptTokenCount": 5, "candidatesTokenCount": 4, "totalTokenCount": 9 } + } + ] + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + // 2. Setup service and sender factory + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new GoogleVertexAiService(senderFactory, createWithEmptySettings(threadPool))) { + + // 3. Create model and input + // Use a dummy JSON string for serviceAccountJson in tests + String dummyServiceAccountJson = + "{\"type\":\"service_account\", \"client_id\": \"1\", \"client_email\": \"test@demo.com\", \"private_key\":\"-----BEGIN PRIVATE KEY-----\\nprivate__key\\n-----END PRIVATE KEY-----\\n\n\", \"private_key_id\":\"1\"}"; + + var model = GoogleVertexAiChatCompletionModelTests.createCompletionModel( + "test-project", + "us-central1", + "gemini-pro", + dummyServiceAccountJson, // Pass dummy JSON + null // No specific rate limit settings needed for this test + ); + + var input = UnifiedCompletionRequest.of( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("Say test"), "user", null, null)) + ); + // 4. Call method + PlainActionFuture listener = new PlainActionFuture<>(); + service.unifiedCompletionInfer(model, input, InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + // 5. Assertions + var result = listener.actionGet(TIMEOUT); + InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors(); + // TODO: Rename and complete this test. Right now it does not work bc the request requires a valid PCK#8 key. Mock it? Inject + // it? 
+ } + } + public void testParseRequestConfig_CreatesGoogleVertexAiEmbeddingsModel() throws IOException { var projectId = "project"; var location = "location"; From 2ebfac96ba1f55cfd68a8945e70ee6b995f3bafe Mon Sep 17 00:00:00 2001 From: lhoet Date: Tue, 6 May 2025 16:20:14 -0300 Subject: [PATCH 12/38] Unified chat completion response and request handlers. Also working with streaming --- ...oogleVertexAiCompletionRequestManager.java | 2 +- .../GoogleVertexAiResponseHandler.java | 12 + ...iUnifiedChatCompletionResponseHandler.java | 183 ++++++++ ...ogleVertexAiUnifiedStreamingProcessor.java | 402 ++++++++++++++++++ ...eVertexAiUnifiedChatCompletionRequest.java | 5 + ...xAiUnifiedChatCompletionRequestEntity.java | 48 ++- 6 files changed, 643 insertions(+), 9 deletions(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessor.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiCompletionRequestManager.java index d2fbaddd1db2b..172f5f0f43e5d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiCompletionRequestManager.java @@ -31,7 +31,7 @@ public class GoogleVertexAiCompletionRequestManager extends GoogleVertexAiReques private static final ResponseHandler HANDLER = createGoogleVertexAiResponseHandler(); private static ResponseHandler createGoogleVertexAiResponseHandler() { - 
return new GoogleVertexAiResponseHandler( + return new GoogleVertexAiUnifiedChatCompletionResponseHandler( "Google Vertex AI chat completion", GoogleVertexAiChatCompletionResponseEntity::fromResponse ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiResponseHandler.java index 409814feef05c..423d23639bdad 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiResponseHandler.java @@ -9,11 +9,14 @@ import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; import org.elasticsearch.xpack.inference.external.http.retry.RetryException; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiErrorResponseEntity; +import java.util.function.Function; + import static org.elasticsearch.core.Strings.format; public class GoogleVertexAiResponseHandler extends BaseResponseHandler { @@ -24,6 +27,15 @@ public GoogleVertexAiResponseHandler(String requestType, ResponseParser parseFun super(requestType, parseFunction, GoogleVertexAiErrorResponseEntity::fromResponse); } + public GoogleVertexAiResponseHandler( + String requestType, + ResponseParser parseFunction, + Function errorParseFunction, + boolean canHandleStreamingResponses + ) { + super(requestType, parseFunction, errorParseFunction, canHandleStreamingResponses); + } + @Override 
protected void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException { if (result.isSuccessfulResponse()) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java new file mode 100644 index 0000000000000..a519cf2c0f211 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java @@ -0,0 +1,183 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googlevertexai; + +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; 
+import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.response.streaming.JsonArrayPartsEventParser; +import org.elasticsearch.xpack.inference.external.response.streaming.JsonArrayPartsEventProcessor; + +import java.util.Locale; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.Flow; + +import static org.elasticsearch.core.Strings.format; + +public class GoogleVertexAiUnifiedChatCompletionResponseHandler extends GoogleVertexAiResponseHandler { + + private static final String ERROR_FIELD = "error"; + private static final String ERROR_CODE_FIELD = "code"; + private static final String ERROR_MESSAGE_FIELD = "message"; + private static final String ERROR_STATUS_FIELD = "status"; + + public GoogleVertexAiUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, GoogleVertexAiErrorResponse::fromResponse, true); + } + + @Override + public InferenceServiceResults parseResult(Request request, Flow.Publisher flow) { + assert request.isStreaming() : "GoogleVertexAiUnifiedChatCompletionResponseHandler only supports streaming requests"; + + var serverSentEventProcessor = new JsonArrayPartsEventProcessor(new JsonArrayPartsEventParser()); + var googleVertexAiProcessor = new GoogleVertexAiUnifiedStreamingProcessor((m, e) -> buildMidStreamError(request, m, e)); + + flow.subscribe(serverSentEventProcessor); + serverSentEventProcessor.subscribe(googleVertexAiProcessor); + return new StreamingUnifiedChatCompletionResults(googleVertexAiProcessor); + } + + @Override + protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { + assert request.isStreaming() : "Only streaming requests support this format"; + var responseStatusCode = result.response().getStatusLine().getStatusCode(); + var errorMessage = errorMessage(message, request, result, errorResponse, 
responseStatusCode); + var restStatus = toRestStatus(responseStatusCode); + + return errorResponse instanceof GoogleVertexAiErrorResponse vertexAIErrorResponse + ? new UnifiedChatCompletionException( + restStatus, + errorMessage, + vertexAIErrorResponse.status(), + String.valueOf(vertexAIErrorResponse.code()), + null + ) + : new UnifiedChatCompletionException( + restStatus, + errorMessage, + errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown", + restStatus.name().toLowerCase(Locale.ROOT) + ); + } + + // TODO: This method was auto generated. Check that it's working properly + private static Exception buildMidStreamError(Request request, String message, Exception e) { + var errorResponse = GoogleVertexAiErrorResponse.fromString(message); + if (errorResponse instanceof GoogleVertexAiErrorResponse gver) { + return new UnifiedChatCompletionException( + RestStatus.INTERNAL_SERVER_ERROR, + format( + "%s for request from inference entity id [%s]. Error message: [%s]", + SERVER_ERROR_OBJECT, + request.getInferenceEntityId(), + errorResponse.getErrorMessage() + ), + gver.status(), + String.valueOf(gver.code()), + null + ); + } else if (e != null) { + return UnifiedChatCompletionException.fromThrowable(e); + } else { + return new UnifiedChatCompletionException( + RestStatus.INTERNAL_SERVER_ERROR, + format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()), + errorResponse != null ? 
errorResponse.getClass().getSimpleName() : "unknown", + "stream_error" + ); + } + } + + private static class GoogleVertexAiErrorResponse extends ErrorResponse { + private static final Logger logger = LogManager.getLogger(GoogleVertexAiErrorResponse.class); + private static final ConstructingObjectParser, Void> ERROR_PARSER = new ConstructingObjectParser<>( + "google_vertex_ai_error_wrapper", + true, + args -> Optional.ofNullable((GoogleVertexAiErrorResponse) args[0]) + ); + + private static final ConstructingObjectParser ERROR_BODY_PARSER = new ConstructingObjectParser<>( + "google_vertex_ai_error_body", + true, + args -> new GoogleVertexAiErrorResponse((Integer) args[0], (String) args[1], (String) args[2]) + ); + + static { + ERROR_BODY_PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ERROR_CODE_FIELD)); + ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(ERROR_MESSAGE_FIELD)); + ERROR_BODY_PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ERROR_STATUS_FIELD)); + + ERROR_PARSER.declareObjectOrNull( + ConstructingObjectParser.optionalConstructorArg(), + ERROR_BODY_PARSER, + null, + new ParseField(ERROR_FIELD) + ); + } + + static ErrorResponse fromResponse(HttpResult response) { + try ( + XContentParser parser = XContentFactory.xContent(XContentType.JSON) + .createParser(XContentParserConfiguration.EMPTY, response.body()) + ) { + return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR); + } catch (Exception e) { + // Log? 
For now, swallow and return undefined + logger.warn("Failed to parse Google Vertex AI error response body", e); + } + return ErrorResponse.UNDEFINED_ERROR; + } + + static ErrorResponse fromString(String response) { + try ( + XContentParser parser = XContentFactory.xContent(XContentType.JSON) + .createParser(XContentParserConfiguration.EMPTY, response) + ) { + return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR); + } catch (Exception e) { + // Log? For now, swallow and return undefined + logger.warn("Failed to parse Google Vertex AI error string", e); + } + return ErrorResponse.UNDEFINED_ERROR; + } + + private final int code; + @Nullable + private final String status; + + GoogleVertexAiErrorResponse(Integer code, String errorMessage, @Nullable String status) { + super(Objects.requireNonNull(errorMessage)); + this.code = code == null ? 0 : code; + this.status = status; + } + + public int code() { + return code; + } + + @Nullable + public String status() { + return status != null ? status : "google_vertex_ai_error"; + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessor.java new file mode 100644 index 0000000000000..1471592d2f612 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessor.java @@ -0,0 +1,402 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.googlevertexai; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; +import org.elasticsearch.xpack.inference.common.DelegatingProcessor; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.Collections; +import java.util.Deque; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.function.BiFunction; + +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; + +public class GoogleVertexAiUnifiedStreamingProcessor extends DelegatingProcessor< + Deque, + StreamingUnifiedChatCompletionResults.Results> { + + private static final Logger logger = LogManager.getLogger(GoogleVertexAiUnifiedStreamingProcessor.class); + + // Response Fields + private static final String CANDIDATES_FIELD = "candidates"; + private static final String CONTENT_FIELD = "content"; + private static final String ROLE_FIELD = "role"; + private static final String PARTS_FIELD = "parts"; + private static final String TEXT_FIELD = "text"; + private static final String FINISH_REASON_FIELD = "finishReason"; + private static final String INDEX_FIELD = "index"; + private static final String USAGE_METADATA_FIELD = 
"usageMetadata"; + private static final String PROMPT_TOKEN_COUNT_FIELD = "promptTokenCount"; + private static final String CANDIDATES_TOKEN_COUNT_FIELD = "candidatesTokenCount"; + private static final String TOTAL_TOKEN_COUNT_FIELD = "totalTokenCount"; + private static final String ERROR_FIELD = "error"; + private static final String ERROR_CODE_FIELD = "code"; + private static final String ERROR_MESSAGE_FIELD = "message"; + private static final String ERROR_STATUS_FIELD = "status"; + + // Internal representation fields mapping to StreamingUnifiedChatCompletionResults + // Note: Google Vertex AI doesn't provide chunk ID, model, or object per chunk like OpenAI. + // We will construct the Choice.Delta based on the Candidate's content. + + private final BiFunction errorParser; + private final Deque buffer = new LinkedBlockingDeque<>(); + + public GoogleVertexAiUnifiedStreamingProcessor(BiFunction errorParser) { + this.errorParser = errorParser; + } + + @Override + protected void upstreamRequest(long n) { + if (buffer.isEmpty()) { + super.upstreamRequest(n); + } else { + // Drain buffer first + downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(singleItem(buffer.poll()))); + } + } + + @Override + protected void next(Deque item) throws Exception { + + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + var results = new ArrayDeque(item.size()); + + for (var event : item) { + var completionChunk = parse(parserConfig, event); + completionChunk.forEachRemaining(results::offer); + } + + if (results.isEmpty()) { + // Request more if we didn't produce anything + upstream().request(1); + } else if (results.size() == 1) { + // Common case: one event produced one chunk + downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(results)); + } else { + // Unlikely for Vertex AI, but handle buffering just in case + logger.warn("Received multiple chunks ({}) from a single SSE batch, 
buffering.", results.size()); + var firstItem = singleItem(results.poll()); + while (results.isEmpty() == false) { + buffer.offer(results.poll()); + } + downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(firstItem)); + // If buffer has items, the next upstreamRequest will handle sending them. + } + } + + // TODO: This method is already called with valid Json in event. Maybe we dont need the validation logic, just parse the event + // Leaving this for now but highly guaranteed that this will be removed + private Iterator parse( + XContentParserConfiguration parserConfig, + byte[] event + ) throws IOException { + // Google Vertex AI doesn't have a specific "[DONE]" message like OpenAI. + // The stream ends when the connection closes or a chunk with a final finishReason arrives. + + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event)) { + moveToFirstToken(jsonParser); + ensureExpectedToken(XContentParser.Token.START_OBJECT, jsonParser.currentToken(), jsonParser); + + // Check for top-level error first + XContentParser.Token token; + String currentFieldName = null; + while ((token = jsonParser.nextToken()) != XContentParser.Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + currentFieldName = jsonParser.currentName(); + } else if (token == XContentParser.Token.START_OBJECT) { + if (ERROR_FIELD.equals(currentFieldName)) { + VertexAiError error = VertexAiErrorParser.parse(jsonParser); + // Map Google's error to ElasticsearchStatusException + // Status mapping might need refinement based on Google's codes + RestStatus status = RestStatus.fromCode(error.code() != 0 ? error.code() : 500); + throw new ElasticsearchStatusException( + "Error from Google Vertex AI: [{}] {}", + status, + error.message() != null ? 
error.message() : "Unknown error", + status + ); + } else { + // If it's not the error field, parse as a regular chunk + // We need to reset the parser or re-parse, as we consumed tokens. + // Easiest is to re-parse the original data. + try (XContentParser chunkParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event)) { + moveToFirstToken(chunkParser); + StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = GoogleVertexAiChatCompletionChunkParser.parse( + chunkParser + ); + // If parsing succeeds but yields no candidates (e.g., empty response), return empty. + if (chunk.choices() == null || chunk.choices().isEmpty()) { + return Collections.emptyIterator(); + } + return Collections.singleton(chunk).iterator(); + } + } + } else { + // Ignore other top-level fields if any, besides "error" and the main structure + jsonParser.skipChildren(); + } + } + // If we reach here, it means the object was parsed but didn't match the error structure + // and didn't trigger the re-parse logic (e.g., empty object {}). Re-parse to be sure. + try (XContentParser chunkParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event)) { + moveToFirstToken(chunkParser); + StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = GoogleVertexAiChatCompletionChunkParser.parse( + chunkParser + ); + // If parsing succeeds but yields no candidates (e.g., empty response), return empty. 
+ if (chunk.choices() == null || chunk.choices().isEmpty()) { + return Collections.emptyIterator(); + } + return Collections.singleton(chunk).iterator(); + } + } + } + + // Helper class to represent Google Vertex AI error structure + private record VertexAiError(int code, String message, String status) {} + + private static class VertexAiErrorParser { + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + ERROR_FIELD, + true, + args -> new VertexAiError((int) args[0], (String) args[1], (String) args[2]) + ); + + static { + PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ERROR_CODE_FIELD)); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ERROR_MESSAGE_FIELD)); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ERROR_STATUS_FIELD)); + // Ignore unknown fields + } + + public static VertexAiError parse(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + } + + // Main parser for the chunk structure + private static class GoogleVertexAiChatCompletionChunkParser { + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "google_vertexai_chat_completion_chunk", + true, + // Args: candidates, usageMetadata + args -> { + List candidates = (List) args[0]; + UsageMetadata usage = (UsageMetadata) args[1]; + + if (candidates == null || candidates.isEmpty()) { + // If there are no candidates, but usage info exists, create a chunk just for usage. 
+ if (usage != null) { + return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk( + null, // No ID from Vertex AI + Collections.emptyList(), + null, // No model per chunk + null, // No object per chunk + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage( + usage.candidatesTokenCount(), + usage.promptTokenCount(), + usage.totalTokenCount() + ) + ); + } + // Return a mostly empty chunk if no candidates and no usage + return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk( + null, + Collections.emptyList(), + null, + null, + null + ); + } + + // Map candidates to choices + List choices = candidates.stream().map(candidate -> { + String contentText = null; + String role = null; + if (candidate.content() != null + && candidate.content().parts() != null + && candidate.content().parts().isEmpty() == false) { + // Assuming only one part with text for now + contentText = candidate.content().parts().get(0).text(); + role = candidate.content().role(); + } + + var delta = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta( + contentText, + null, // No refusal field in Vertex AI + role, + null // TODO: Handle tool/function calls if they appear in streaming + ); + + return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice( + delta, + candidate.finishReason(), + candidate.index() + ); + }).toList(); + + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage usageResult = null; + if (usage != null) { + usageResult = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage( + usage.candidatesTokenCount(), + usage.promptTokenCount(), + usage.totalTokenCount() + ); + } + + return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk( + null, // No ID from Vertex AI + choices, + null, // No model per chunk + null, // No object per chunk + usageResult + ); + } + ); + + static { + PARSER.declareObjectArray( + ConstructingObjectParser.optionalConstructorArg(), // Candidates 
might be absent + (p, c) -> CandidateParser.parse(p), + new ParseField(CANDIDATES_FIELD) + ); + PARSER.declareObject( + ConstructingObjectParser.optionalConstructorArg(), // Usage might be absent until the end + (p, c) -> UsageMetadataParser.parse(p), + new ParseField(USAGE_METADATA_FIELD) + ); + // Ignore other top-level fields like safetyRatings, citationMetadata etc. + } + + public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk parse(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + } + + // --- Nested Parsers for Google Vertex AI structure --- + + private record Candidate(Content content, String finishReason, int index) {} + + private static class CandidateParser { + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "candidate", + true, + args -> new Candidate((Content) args[0], (String) args[1], args[2] == null ? 0 : (int) args[2]) // index might be null + ); + + static { + PARSER.declareObject( + ConstructingObjectParser.optionalConstructorArg(), + (p, c) -> ContentParser.parse(p), + new ParseField(CONTENT_FIELD) + ); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(FINISH_REASON_FIELD)); + PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField(INDEX_FIELD)); + // Ignore safetyRatings, citationMetadata, etc. 
+ } + + public static Candidate parse(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + } + + private record Content(String role, List parts) {} + + private static class ContentParser { + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + CONTENT_FIELD, + true, + args -> new Content((String) args[0], (List) args[1]) + ); + + static { + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ROLE_FIELD)); + PARSER.declareObjectArray( + ConstructingObjectParser.optionalConstructorArg(), + (p, c) -> PartParser.parse(p), + new ParseField(PARTS_FIELD) + ); + } + + public static Content parse(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + } + + private record Part(String text) {} // Assuming only text parts for now + + private static class PartParser { + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "part", + true, + args -> new Part((String) args[0]) + ); + + static { + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(TEXT_FIELD)); + // Ignore other part types like functionCall, functionResponse, fileData, etc. for now + } + + public static Part parse(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + } + + private record UsageMetadata(int promptTokenCount, int candidatesTokenCount, int totalTokenCount) {} + + private static class UsageMetadataParser { + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + USAGE_METADATA_FIELD, + true, + args -> new UsageMetadata( + args[0] == null ? 0 : (int) args[0], + args[1] == null ? 0 : (int) args[1], + args[2] == null ? 
0 : (int) args[2] + ) + ); + + static { + PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField(PROMPT_TOKEN_COUNT_FIELD)); + PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField(CANDIDATES_TOKEN_COUNT_FIELD)); + PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField(TOTAL_TOKEN_COUNT_FIELD)); + } + + public static UsageMetadata parse(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + } + + // Helper to wrap a single chunk in a Deque + private Deque singleItem( + StreamingUnifiedChatCompletionResults.ChatCompletionChunk result + ) { + var deque = new ArrayDeque(1); + deque.offer(result); + return deque; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequest.java index 48b2a5a4e0683..6541569887c1a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequest.java @@ -67,6 +67,11 @@ public boolean[] getTruncationInfo() { return null; } + @Override + public boolean isStreaming() { + return unifiedChatInput.stream(); + } + @Override public String getInferenceEntityId() { return model.getInferenceEntityId(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java 
index ddaa8ca91b4c4..edb733396a9f1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java @@ -7,7 +7,9 @@ package org.elasticsearch.xpack.inference.services.googlevertexai.request; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; @@ -31,7 +33,7 @@ public class GoogleVertexAiUnifiedChatCompletionRequestEntity implements ToXCont // TODO: Add other generationConfig fields if needed (e.g., stopSequences, topK) private final UnifiedChatInput unifiedChatInput; - private final GoogleVertexAiChatCompletionModel model; + private final GoogleVertexAiChatCompletionModel model; // TODO: This is not being used? private static final String USER_ROLE = "user"; private static final String MODEL_ROLE = "model"; @@ -49,15 +51,33 @@ private String messageRoleToGoogleVertexAiSupportedRole(String messageRole) thro return messageRoleLowered; } - // TODO: Here is OK to throw an IOException? - throw new IOException( + var errorMessage = format( "Role %s not supported by Google VertexAI ChatCompletion. 
Supported roles: '%s', '%s'", messageRole, USER_ROLE, MODEL_ROLE - ) - ); + ); + throw new ElasticsearchStatusException(errorMessage, RestStatus.BAD_REQUEST); + } + + private void validateAndAddContentObjectsToBuilder(XContentBuilder builder, UnifiedCompletionRequest.ContentObjects contentObjects) + throws IOException { + + for (var contentObject : contentObjects.contentObjects()) { + if (contentObject.type().equals(TEXT) == false) { + var errorMessage = format( + "Type %s not supported by Google VertexAI ChatCompletion. Supported types: 'text'", + contentObject.type() + ); + throw new ElasticsearchStatusException(errorMessage, RestStatus.BAD_REQUEST); + } + // We are only supporting Text messages but VertexAI supports more types: + // https://cloud.google.com/vertex-ai/docs/reference/rest/v1/Content?_gl=1*q4uxnh*_up*MQ..&gclid=CjwKCAjwwqfABhBcEiwAZJjC3uBQNP9KUMZX8AGXvFXP2rIEQSfCX9RLP5gjzx5r-4xz1daBSxM7GBoCY64QAvD_BwE&gclsrc=aw.ds#Part + builder.startObject(); + builder.field(TEXT, contentObject.text()); + builder.endObject(); + } } @@ -69,9 +89,21 @@ private void buildContents(XContentBuilder builder) throws IOException { builder.startObject(); builder.field(ROLE, messageRoleToGoogleVertexAiSupportedRole(message.role())); builder.startArray(PARTS); - builder.startObject(); - builder.field(TEXT, message.content().toString()); - builder.endObject(); + switch (message.content()) { + case UnifiedCompletionRequest.ContentString contentString -> { + builder.startObject(); + builder.field(TEXT, contentString.content()); + builder.endObject(); + } + case UnifiedCompletionRequest.ContentObjects contentObjects -> validateAndAddContentObjectsToBuilder( + builder, + contentObjects + ); + case null -> { + var errorMessage = "Google VertexAI API requires at least one text message but none were provided"; + throw new ElasticsearchStatusException(errorMessage, RestStatus.BAD_REQUEST); + } + } builder.endArray(); builder.endObject(); } From 679ea8078989e49b7a0310b07a0c13b7de7bad64 
Mon Sep 17 00:00:00 2001 From: lhoet Date: Thu, 8 May 2025 14:54:04 -0300 Subject: [PATCH 13/38] StreamingProcessor now support tools. Added more tests --- ...ogleVertexAiUnifiedStreamingProcessor.java | 312 ++++++++---------- ...ertexAiUnifiedStreamingProcessorTests.java | 88 +++++ 2 files changed, 232 insertions(+), 168 deletions(-) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessorTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessor.java index 1471592d2f612..b2c33cb294a08 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessor.java @@ -9,9 +9,10 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; -import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentFactory; @@ -23,10 +24,13 @@ import java.io.IOException; import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.Deque; import java.util.Iterator; import java.util.List; +import java.util.Map; import 
java.util.concurrent.LinkedBlockingDeque; import java.util.function.BiFunction; @@ -39,7 +43,6 @@ public class GoogleVertexAiUnifiedStreamingProcessor extends DelegatingProcessor private static final Logger logger = LogManager.getLogger(GoogleVertexAiUnifiedStreamingProcessor.class); - // Response Fields private static final String CANDIDATES_FIELD = "candidates"; private static final String CONTENT_FIELD = "content"; private static final String ROLE_FIELD = "role"; @@ -51,14 +54,13 @@ public class GoogleVertexAiUnifiedStreamingProcessor extends DelegatingProcessor private static final String PROMPT_TOKEN_COUNT_FIELD = "promptTokenCount"; private static final String CANDIDATES_TOKEN_COUNT_FIELD = "candidatesTokenCount"; private static final String TOTAL_TOKEN_COUNT_FIELD = "totalTokenCount"; - private static final String ERROR_FIELD = "error"; - private static final String ERROR_CODE_FIELD = "code"; - private static final String ERROR_MESSAGE_FIELD = "message"; - private static final String ERROR_STATUS_FIELD = "status"; + private static final String MODEL_VERSION_FIELD = "modelVersion"; + private static final String RESPONSE_ID_FIELD = "responseId"; + private static final String FUNCTION_CALL_FIELD = "functionCall"; + private static final String FUNCTION_NAME_FIELD = "name"; + private static final String FUNCTION_ARGS_FIELD = "args"; - // Internal representation fields mapping to StreamingUnifiedChatCompletionResults - // Note: Google Vertex AI doesn't provide chunk ID, model, or object per chunk like OpenAI. - // We will construct the Choice.Delta based on the Candidate's content. 
+ private static final String FUNCTION_TYPE = "function"; private final BiFunction errorParser; private final Deque buffer = new LinkedBlockingDeque<>(); @@ -72,224 +74,157 @@ protected void upstreamRequest(long n) { if (buffer.isEmpty()) { super.upstreamRequest(n); } else { - // Drain buffer first downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(singleItem(buffer.poll()))); } } @Override - protected void next(Deque item) throws Exception { + protected void next(Deque events) throws Exception { var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); - var results = new ArrayDeque(item.size()); - - for (var event : item) { - var completionChunk = parse(parserConfig, event); - completionChunk.forEachRemaining(results::offer); + var results = new ArrayDeque(events.size()); + + for (var event : events) { + try { + var completionChunk = parse(parserConfig, event); + completionChunk.forEachRemaining(results::offer); + } catch (Exception e) { + var eventString = Arrays.toString(event); + logger.warn("Failed to parse event from Google Vertex AI provider: {}", eventString); + throw errorParser.apply(eventString, e); + } } if (results.isEmpty()) { - // Request more if we didn't produce anything upstream().request(1); } else if (results.size() == 1) { - // Common case: one event produced one chunk downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(results)); } else { - // Unlikely for Vertex AI, but handle buffering just in case + // Vertex AI doesn't specify how many events per chunk, so handle buffering just in case logger.warn("Received multiple chunks ({}) from a single SSE batch, buffering.", results.size()); var firstItem = singleItem(results.poll()); while (results.isEmpty() == false) { buffer.offer(results.poll()); } downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(firstItem)); - // If buffer has items, the next upstreamRequest will handle sending 
them. } } - // TODO: This method is already called with valid Json in event. Maybe we dont need the validation logic, just parse the event - // Leaving this for now but highly guaranteed that this will be removed private Iterator parse( XContentParserConfiguration parserConfig, byte[] event ) throws IOException { - // Google Vertex AI doesn't have a specific "[DONE]" message like OpenAI. - // The stream ends when the connection closes or a chunk with a final finishReason arrives. - try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event)) { moveToFirstToken(jsonParser); ensureExpectedToken(XContentParser.Token.START_OBJECT, jsonParser.currentToken(), jsonParser); - // Check for top-level error first - XContentParser.Token token; - String currentFieldName = null; - while ((token = jsonParser.nextToken()) != XContentParser.Token.END_OBJECT) { - if (token == XContentParser.Token.FIELD_NAME) { - currentFieldName = jsonParser.currentName(); - } else if (token == XContentParser.Token.START_OBJECT) { - if (ERROR_FIELD.equals(currentFieldName)) { - VertexAiError error = VertexAiErrorParser.parse(jsonParser); - // Map Google's error to ElasticsearchStatusException - // Status mapping might need refinement based on Google's codes - RestStatus status = RestStatus.fromCode(error.code() != 0 ? error.code() : 500); - throw new ElasticsearchStatusException( - "Error from Google Vertex AI: [{}] {}", - status, - error.message() != null ? 
error.message() : "Unknown error", - status + StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = GoogleVertexAiChatCompletionChunkParser.parse(jsonParser); + return Collections.singleton(chunk).iterator(); + } + } + + public static class GoogleVertexAiChatCompletionChunkParser { + private static @Nullable StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage usageMetadataToChunk( + @Nullable UsageMetadata usage + ) { + if (usage == null) { + return null; + } + return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage( + usage.candidatesTokenCount(), + usage.promptTokenCount(), + usage.totalTokenCount() + ); + } + + private static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice candidateToChoice(Candidate candidate) { + StringBuilder contentTextBuilder = new StringBuilder(); + List toolCalls = new ArrayList<>(); + + String role = null; + + var contentAndPartsAreNotEmpty = candidate.content() != null + && candidate.content().parts() != null + && candidate.content().parts().isEmpty() == false; + + if (contentAndPartsAreNotEmpty) { + role = candidate.content().role(); // Role is at the content level + for (Part part : candidate.content().parts()) { + if (part.text() != null) { + contentTextBuilder.append(part.text()); + } + if (part.functionCall() != null) { + FunctionCall fc = part.functionCall(); + var function = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function( + fc.args(), + fc.name() + ); + toolCalls.add( + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall( + 0, // No explicit ID from VertexAI so we use 0 + function.name(), // VertexAI does not provide an id for the function call so we use the name + function, + FUNCTION_TYPE + ) ); - } else { - // If it's not the error field, parse as a regular chunk - // We need to reset the parser or re-parse, as we consumed tokens. - // Easiest is to re-parse the original data. 
- try (XContentParser chunkParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event)) { - moveToFirstToken(chunkParser); - StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = GoogleVertexAiChatCompletionChunkParser.parse( - chunkParser - ); - // If parsing succeeds but yields no candidates (e.g., empty response), return empty. - if (chunk.choices() == null || chunk.choices().isEmpty()) { - return Collections.emptyIterator(); - } - return Collections.singleton(chunk).iterator(); - } } - } else { - // Ignore other top-level fields if any, besides "error" and the main structure - jsonParser.skipChildren(); - } - } - // If we reach here, it means the object was parsed but didn't match the error structure - // and didn't trigger the re-parse logic (e.g., empty object {}). Re-parse to be sure. - try (XContentParser chunkParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event)) { - moveToFirstToken(chunkParser); - StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = GoogleVertexAiChatCompletionChunkParser.parse( - chunkParser - ); - // If parsing succeeds but yields no candidates (e.g., empty response), return empty. - if (chunk.choices() == null || chunk.choices().isEmpty()) { - return Collections.emptyIterator(); } - return Collections.singleton(chunk).iterator(); } - } - } - // Helper class to represent Google Vertex AI error structure - private record VertexAiError(int code, String message, String status) {} + List finalToolCalls = toolCalls.isEmpty() + ? 
null + : toolCalls; - private static class VertexAiErrorParser { - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - ERROR_FIELD, - true, - args -> new VertexAiError((int) args[0], (String) args[1], (String) args[2]) - ); - - static { - PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ERROR_CODE_FIELD)); - PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ERROR_MESSAGE_FIELD)); - PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ERROR_STATUS_FIELD)); - // Ignore unknown fields - } + var delta = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta( + contentTextBuilder.isEmpty() ? null : contentTextBuilder.toString(), + null, + role, + finalToolCalls + ); - public static VertexAiError parse(XContentParser parser) throws IOException { - return PARSER.parse(parser, null); + return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(delta, candidate.finishReason(), candidate.index()); } - } - - // Main parser for the chunk structure - private static class GoogleVertexAiChatCompletionChunkParser { @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( "google_vertexai_chat_completion_chunk", true, - // Args: candidates, usageMetadata args -> { List candidates = (List) args[0]; UsageMetadata usage = (UsageMetadata) args[1]; + String modelversion = (String) args[2]; + String responseId = (String) args[3]; - if (candidates == null || candidates.isEmpty()) { - // If there are no candidates, but usage info exists, create a chunk just for usage. 
- if (usage != null) { - return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk( - null, // No ID from Vertex AI - Collections.emptyList(), - null, // No model per chunk - null, // No object per chunk - new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage( - usage.candidatesTokenCount(), - usage.promptTokenCount(), - usage.totalTokenCount() - ) - ); - } - // Return a mostly empty chunk if no candidates and no usage - return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk( - null, - Collections.emptyList(), - null, - null, - null - ); - } - // Map candidates to choices - List choices = candidates.stream().map(candidate -> { - String contentText = null; - String role = null; - if (candidate.content() != null - && candidate.content().parts() != null - && candidate.content().parts().isEmpty() == false) { - // Assuming only one part with text for now - contentText = candidate.content().parts().get(0).text(); - role = candidate.content().role(); - } - - var delta = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta( - contentText, - null, // No refusal field in Vertex AI - role, - null // TODO: Handle tool/function calls if they appear in streaming - ); + boolean candidatesIsEmpty = candidates == null || candidates.isEmpty(); + List choices = candidatesIsEmpty + ? 
Collections.emptyList() + : candidates.stream().map(GoogleVertexAiChatCompletionChunkParser::candidateToChoice).toList(); - return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice( - delta, - candidate.finishReason(), - candidate.index() - ); - }).toList(); - - StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage usageResult = null; - if (usage != null) { - usageResult = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage( - usage.candidatesTokenCount(), - usage.promptTokenCount(), - usage.totalTokenCount() - ); - } return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk( - null, // No ID from Vertex AI + responseId, choices, - null, // No model per chunk - null, // No object per chunk - usageResult + modelversion, + null, + usageMetadataToChunk(usage) ); } ); static { PARSER.declareObjectArray( - ConstructingObjectParser.optionalConstructorArg(), // Candidates might be absent + ConstructingObjectParser.optionalConstructorArg(), (p, c) -> CandidateParser.parse(p), new ParseField(CANDIDATES_FIELD) ); PARSER.declareObject( - ConstructingObjectParser.optionalConstructorArg(), // Usage might be absent until the end + ConstructingObjectParser.optionalConstructorArg(), (p, c) -> UsageMetadataParser.parse(p), new ParseField(USAGE_METADATA_FIELD) ); - // Ignore other top-level fields like safetyRatings, citationMetadata etc. + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(MODEL_VERSION_FIELD)); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(RESPONSE_ID_FIELD)); } public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk parse(XContentParser parser) throws IOException { @@ -305,7 +240,12 @@ private static class CandidateParser { private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( "candidate", true, - args -> new Candidate((Content) args[0], (String) args[1], args[2] == null ? 
0 : (int) args[2]) // index might be null + args -> { + var content = (Content) args[0]; + var finishReason = (String) args[1]; + var index = args[2] == null ? 0 : (int) args[2]; + return new Candidate(content, finishReason, index); + } ); static { @@ -316,7 +256,6 @@ private static class CandidateParser { ); PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(FINISH_REASON_FIELD)); PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField(INDEX_FIELD)); - // Ignore safetyRatings, citationMetadata, etc. } public static Candidate parse(XContentParser parser) throws IOException { @@ -348,18 +287,22 @@ public static Content parse(XContentParser parser) throws IOException { } } - private record Part(String text) {} // Assuming only text parts for now + private record Part(@Nullable String text, @Nullable FunctionCall functionCall) {} // Modified private static class PartParser { private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( "part", true, - args -> new Part((String) args[0]) + args -> new Part((String) args[0], (FunctionCall) args[1]) ); static { PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(TEXT_FIELD)); - // Ignore other part types like functionCall, functionResponse, fileData, etc. 
for now + PARSER.declareObject( + ConstructingObjectParser.optionalConstructorArg(), + (p, c) -> FunctionCallParser.parse(p), + new ParseField(FUNCTION_CALL_FIELD) + ); } public static Part parse(XContentParser parser) throws IOException { @@ -367,6 +310,40 @@ public static Part parse(XContentParser parser) throws IOException { } } + private record FunctionCall(String name, String args) {} + + private static class FunctionCallParser { + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + FUNCTION_CALL_FIELD, + true, + args -> { + var name = (String) args[0]; + + @SuppressWarnings("unchecked") + var argsMap = (Map) args[1]; + if (argsMap == null) { + return new FunctionCall(name, null); + } + try { + var builder = XContentFactory.jsonBuilder().map(argsMap); + var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON); + return new FunctionCall(name, json); + } catch (IOException e) { + logger.warn("Failed to parse and convert VertexAI function args to json", e); + return new FunctionCall(name, null); + } + } + ); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(FUNCTION_NAME_FIELD)); + PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), (p, c) -> p.map(), new ParseField(FUNCTION_ARGS_FIELD)); + } + + public static FunctionCall parse(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + } private record UsageMetadata(int promptTokenCount, int candidatesTokenCount, int totalTokenCount) {} private static class UsageMetadataParser { @@ -391,7 +368,6 @@ public static UsageMetadata parse(XContentParser parser) throws IOException { } } - // Helper to wrap a single chunk in a Deque private Deque singleItem( StreamingUnifiedChatCompletionResults.ChatCompletionChunk result ) { diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessorTests.java new file mode 100644 index 0000000000000..069f96b74c4da --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessorTests.java @@ -0,0 +1,88 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googlevertexai; + +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; + +import java.io.IOException; + +public class GoogleVertexAiUnifiedStreamingProcessorTests extends ESTestCase { + + public void testJsonLiteral() { + String json = """ + { + "candidates" : [ { + "content" : { + "role" : "model", + "parts" : [ + { "text" : "Elastic" }, + { + "functionCall": { + "name": "getWeatherData", + "args": { "unit": "celsius", "location": "buenos aires, argentina" } + } + } + ] + }, + "finishReason": "MAXTOKENS" + } ], + "usageMetadata" : { + "trafficType" : "ON_DEMAND" + }, + "modelVersion" : "gemini-2.0-flash-lite", + "createTime" : "2025-05-07T14:36:16.122336Z", + "responseId" : "responseId" + } + """; + + XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler( + LoggingDeprecationHandler.INSTANCE + ); + + try 
(XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, json)) { + var chunk = GoogleVertexAiUnifiedStreamingProcessor.GoogleVertexAiChatCompletionChunkParser.parse(parser); + + assertEquals("responseId", chunk.id()); + assertEquals(1, chunk.choices().size()); + + var choice = chunk.choices().getFirst(); + assertEquals("Elastic", choice.delta().content()); + assertEquals("model", choice.delta().role()); + assertEquals("gemini-2.0-flash-lite", chunk.model()); + assertEquals("MAXTOKENS", choice.finishReason()); + assertEquals(0, choice.index()); // VertexAI response does not have Index. Use 0 as default + + assertEquals(1, choice.delta().toolCalls().size()); + var toolCall = choice.delta().toolCalls().getFirst(); + assertEquals("getWeatherData", toolCall.function().name()); + assertEquals("{\"unit\":\"celsius\",\"location\":\"buenos aires, argentina\"}", toolCall.function().arguments()); + + } catch (IOException e) { + fail(); + } + } + + public void testJsonError() { + String json = String.format(""" + { + "error": { + "code": 400, + "message": "Invalid JSON payload received. 
Expected an object key or }.\\n# Changed tool name \\n^", + "status": "INVALID_ARGUMENT" + } + } + """); + fail("Test not implemented"); + + } +} From e611cc3fd7464c4f436a2e994cfe9129943374ec Mon Sep 17 00:00:00 2001 From: lhoet Date: Thu, 8 May 2025 15:33:23 -0300 Subject: [PATCH 14/38] More tests for streaming processor --- ...ertexAiUnifiedStreamingProcessorTests.java | 134 ++++++++++++++++-- 1 file changed, 124 insertions(+), 10 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessorTests.java index 069f96b74c4da..7bf5c5f2df30c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessorTests.java @@ -37,6 +37,9 @@ public void testJsonLiteral() { "finishReason": "MAXTOKENS" } ], "usageMetadata" : { + "promptTokenCount": 10, + "candidatesTokenCount": 20, + "totalTokenCount": 30, "trafficType" : "ON_DEMAND" }, "modelVersion" : "gemini-2.0-flash-lite", @@ -67,22 +70,133 @@ public void testJsonLiteral() { assertEquals("getWeatherData", toolCall.function().name()); assertEquals("{\"unit\":\"celsius\",\"location\":\"buenos aires, argentina\"}", toolCall.function().arguments()); + assertNotNull(chunk.usage()); + assertEquals(20, chunk.usage().completionTokens()); + assertEquals(10, chunk.usage().promptTokens()); + assertEquals(30, chunk.usage().totalTokens()); + + } catch (IOException e) { + fail("IOException during test: " + e.getMessage()); + } + } + + public void testJsonLiteral_optionalTopLevelFieldsMissing() { + String json = """ + { + "candidates" : [ { + "content" : { + "role" 
: "model", + "parts" : [ { "text" : "Hello" } ] + }, + "finishReason": "STOP" + } ] + } + """; + + XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler( + LoggingDeprecationHandler.INSTANCE + ); + + try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, json)) { + var chunk = GoogleVertexAiUnifiedStreamingProcessor.GoogleVertexAiChatCompletionChunkParser.parse(parser); + + assertNull(chunk.id()); + assertEquals(1, chunk.choices().size()); + var choice = chunk.choices().getFirst(); + assertEquals("Hello", choice.delta().content()); + assertEquals("model", choice.delta().role()); + assertNull(chunk.model()); + assertEquals("STOP", choice.finishReason()); + assertEquals(0, choice.index()); + assertNull(choice.delta().toolCalls()); + assertNull(chunk.usage()); + + } catch (IOException e) { + fail("IOException during test: " + e.getMessage()); + } + } + + public void testJsonLiteral_functionCallArgsMissing() { + String json = """ + { + "candidates" : [ { + "content" : { + "role" : "model", + "parts" : [ + { + "functionCall": { + "name": "getLocation" + } + } + ] + } + } ], + "responseId" : "resId789" + } + """; + XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler( + LoggingDeprecationHandler.INSTANCE + ); + + try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, json)) { + var chunk = GoogleVertexAiUnifiedStreamingProcessor.GoogleVertexAiChatCompletionChunkParser.parse(parser); + + assertEquals("resId789", chunk.id()); + assertEquals(1, chunk.choices().size()); + var choice = chunk.choices().getFirst(); + assertEquals("model", choice.delta().role()); + assertNull(choice.delta().content()); + + assertNotNull(choice.delta().toolCalls()); + assertEquals(1, choice.delta().toolCalls().size()); + var toolCall = choice.delta().toolCalls().getFirst(); + assertEquals("getLocation", 
toolCall.function().name()); + assertNull(toolCall.function().arguments()); + } catch (IOException e) { - fail(); + fail("IOException during test: " + e.getMessage()); } } - public void testJsonError() { - String json = String.format(""" + public void testJsonLiteral_multipleTextParts() { + String json = """ { - "error": { - "code": 400, - "message": "Invalid JSON payload received. Expected an object key or }.\\n# Changed tool name \\n^", - "status": "INVALID_ARGUMENT" - } + "candidates" : [ { + "content" : { + "role" : "model", + "parts" : [ + { "text" : "This is the first part. " }, + { "text" : "This is the second part." } + ] + }, + "finishReason": "STOP" + } ], + "responseId" : "multiTextId" } - """); - fail("Test not implemented"); + """; + + XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler( + LoggingDeprecationHandler.INSTANCE + ); + try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, json)) { + var chunk = GoogleVertexAiUnifiedStreamingProcessor.GoogleVertexAiChatCompletionChunkParser.parse(parser); + + assertEquals("multiTextId", chunk.id()); + assertEquals(1, chunk.choices().size()); + + var choice = chunk.choices().getFirst(); + assertEquals("model", choice.delta().role()); + // Verify that the text from multiple parts is concatenated + assertEquals("This is the first part. 
This is the second part.", choice.delta().content()); + assertEquals("STOP", choice.finishReason()); + assertEquals(0, choice.index()); + assertNull(choice.delta().toolCalls()); // No function calls in this test case + assertNull(chunk.model()); + assertNull(chunk.usage()); + + } catch (IOException e) { + fail("IOException during test: " + e.getMessage()); + } } } From 87e428a00cdd3c0dc71c8bd50c0d3c1e3189daa9 Mon Sep 17 00:00:00 2001 From: lhoet Date: Mon, 12 May 2025 11:42:55 -0300 Subject: [PATCH 15/38] Request entity tests --- ...xAiUnifiedChatCompletionRequestEntity.java | 4 +- ...ifiedChatCompletionRequestEntityTests.java | 354 ++++++++++++++++++ 2 files changed, 356 insertions(+), 2 deletions(-) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java index edb733396a9f1..dc260dcb87258 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java @@ -53,7 +53,7 @@ private String messageRoleToGoogleVertexAiSupportedRole(String messageRole) thro var errorMessage = format( - "Role %s not supported by Google VertexAI ChatCompletion. Supported roles: '%s', '%s'", + "Role [%s] not supported by Google VertexAI ChatCompletion. 
Supported roles: [%s, %s]", messageRole, USER_ROLE, MODEL_ROLE @@ -67,7 +67,7 @@ private void validateAndAddContentObjectsToBuilder(XContentBuilder builder, Unif for (var contentObject : contentObjects.contentObjects()) { if (contentObject.type().equals(TEXT) == false) { var errorMessage = format( - "Type %s not supported by Google VertexAI ChatCompletion. Supported types: 'text'", + "Type [%s] not supported by Google VertexAI ChatCompletion. Supported types: [text]", contentObject.type() ); throw new ElasticsearchStatusException(errorMessage, RestStatus.BAD_REQUEST); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java new file mode 100644 index 0000000000000..3ee48bac9ed78 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java @@ -0,0 +1,354 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.googlevertexai.request; + +import org.apache.commons.lang3.NotImplementedException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel; +import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModelTests; +import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionServiceSettings; +import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionTaskSettings; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import static org.elasticsearch.xpack.inference.Utils.assertJsonEquals; +import static org.hamcrest.Matchers.containsString; + +public class GoogleVertexAiUnifiedChatCompletionRequestEntityTests extends ESTestCase { + + private static final String USER_ROLE = "user"; + private static final String MODEL_ROLE = "model"; + + private GoogleVertexAiChatCompletionModel createModel() { + // The actual values here don't matter for serialization logic, + // as the model isn't directly used for generating the request body fields in this entity. 
+ return GoogleVertexAiChatCompletionModelTests.createCompletionModel("projectID", "location", "modelId", "modelName", null); + } + + public void testBasicSerialization_SingleMessage() throws IOException { + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, Vertex AI!"), + USER_ROLE, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + + var unifiedRequest = UnifiedCompletionRequest.of(messageList); + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); // stream doesn't affect VertexAI request body + var model = createModel(); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + model + ); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "contents": [ + { + "role": "user", + "parts": [ + { + "text": "Hello, Vertex AI!" 
+ } + ] + } + ] + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + public void testSerialization_MultipleMessages() throws IOException { + var messages = List.of( + new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Previous user message."), + USER_ROLE, + null, + null + ), + new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Previous model response."), + MODEL_ROLE, + null, + null + ), + new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("Current user query."), USER_ROLE, null, null) + ); + + var unifiedRequest = UnifiedCompletionRequest.of(messages); + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, false); + GoogleVertexAiChatCompletionModel model = createModel(); + + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + model + ); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "contents": [ + { + "role": "user", + "parts": [ { "text": "Previous user message." } ] + }, + { + "role": "model", + "parts": [ { "text": "Previous model response." } ] + }, + { + "role": "user", + "parts": [ { "text": "Current user query." 
} ] + } + ] + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + public void testSerialization_WithAllGenerationConfig() throws IOException { + List messages = List.of( + new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("Hello Gemini!"), USER_ROLE, null, null) + ); + var completionRequestWithGenerationConfig = new UnifiedCompletionRequest( + messages, + "modelId", + 100L, + List.of("stop1", "stop2"), + 0.5f, + null, + null, + 0.9F + ); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(completionRequestWithGenerationConfig, true); + GoogleVertexAiChatCompletionModel model = createModel(); + + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + model + ); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "contents": [ + { + "role": "user", + "parts": [ { "text": "Hello Gemini!" 
} ] + } + ], + "generationConfig": { + "stopSequences": ["stop1", "stop2"], + "temperature": 0.5, + "maxOutputTokens": 100, + "topP": 0.9 + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + public void testSerialization_WithSomeGenerationConfig() throws IOException { + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Partial config."), + USER_ROLE, + null, + null + ); + var completionRequestWithGenerationConfig = new UnifiedCompletionRequest( + List.of(message), + "modelId", + 50L, + null, + 0.7f, + null, + null, + null + ); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(completionRequestWithGenerationConfig, true); + GoogleVertexAiChatCompletionModel model = createModel(); + + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + model + ); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "contents": [ + { + "role": "user", + "parts": [ { "text": "Partial config." 
} ] + } + ], + "generationConfig": { + "temperature": 0.7, + "maxOutputTokens": 50 + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + public void testSerialization_NoGenerationConfig() throws IOException { + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("No extra config."), + USER_ROLE, + null, + null + ); + // No generation config fields set on unifiedRequest + var unifiedRequest = UnifiedCompletionRequest.of(List.of(message)); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + GoogleVertexAiChatCompletionModel model = createModel(); + + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + model + ); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "contents": [ + { + "role": "user", + "parts": [ { "text": "No extra config." } ] + } + ] + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + public void testSerialization_WithContentObjects() throws IOException { + var contentObjects = List.of( + new UnifiedCompletionRequest.ContentObject("First part. 
", "text"), + new UnifiedCompletionRequest.ContentObject("Second part.", "text") + ); + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentObjects(contentObjects), + USER_ROLE, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + + var unifiedRequest = UnifiedCompletionRequest.of(messageList); + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + GoogleVertexAiChatCompletionModel model = createModel(); + + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + model + ); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "contents": [ + { + "role": "user", + "parts": [ + { "text": "First part. " }, + { "text": "Second part." } + ] + } + ] + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + public void testError_UnsupportedRole() throws IOException { + var unsupportedRole = "system"; + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Test"), + unsupportedRole, + null, + null + ); + var unifiedRequest = UnifiedCompletionRequest.of(List.of(message)); + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, false); + GoogleVertexAiChatCompletionModel model = createModel(); + + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + model + ); + + XContentBuilder builder = JsonXContent.contentBuilder(); + var statusException = assertThrows(ElasticsearchStatusException.class, () -> entity.toXContent(builder, ToXContent.EMPTY_PARAMS)); + + assertEquals(RestStatus.BAD_REQUEST, statusException.status()); + 
assertThat(statusException.toString(), containsString("Role [system] not supported by Google VertexAI ChatCompletion")); + } + + public void testError_UnsupportedContentObjectType() throws IOException { + var contentObjects = List.of(new UnifiedCompletionRequest.ContentObject("http://example.com/image.png", "image_url")); + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentObjects(contentObjects), + USER_ROLE, + null, + null + ); + var unifiedRequest = UnifiedCompletionRequest.of(List.of(message)); + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, false); + GoogleVertexAiChatCompletionModel model = createModel(); + + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + model + ); + + XContentBuilder builder = JsonXContent.contentBuilder(); + var statusException = assertThrows(ElasticsearchStatusException.class, () -> entity.toXContent(builder, ToXContent.EMPTY_PARAMS)); + + assertEquals(RestStatus.BAD_REQUEST, statusException.status()); + assertThat(statusException.toString(), containsString("Type [image_url] not supported by Google VertexAI ChatCompletion")); + + } + +} From 193d06d88ed2317f461e22d0f3bfbb50d725ad38 Mon Sep 17 00:00:00 2001 From: lhoet Date: Mon, 12 May 2025 13:54:01 -0300 Subject: [PATCH 16/38] Google vertexai unified chat completion entity now accepting tools and tools choice with tests --- ...xAiUnifiedChatCompletionRequestEntity.java | 95 ++++++- ...ifiedChatCompletionRequestEntityTests.java | 235 +++++++++++++++++- 2 files changed, 324 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java index dc260dcb87258..eeb7517773958 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.services.googlevertexai.request; import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.Strings; import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.ToXContentObject; @@ -30,7 +31,18 @@ public class GoogleVertexAiUnifiedChatCompletionRequestEntity implements ToXCont private static final String TEMPERATURE = "temperature"; private static final String MAX_OUTPUT_TOKENS = "maxOutputTokens"; private static final String TOP_P = "topP"; - // TODO: Add other generationConfig fields if needed (e.g., stopSequences, topK) + + private static final String TOOLS = "tools"; + private static final String FUNCTION_DECLARATIONS = "functionDeclarations"; + private static final String FUNCTION_NAME = "name"; + private static final String FUNCTION_DESCRIPTION = "description"; + private static final String FUNCTION_PARAMETERS = "parameters"; + private static final String FUNCTION_TYPE = "function"; + private static final String TOOL_CONFIG = "toolConfig"; + private static final String FUNCTION_CALLING_CONFIG = "functionCallingConfig"; + private static final String TOOL_MODE = "mode"; + private static final String TOOL_MODE_ANY = "ANY"; + private static final String ALLOWED_FUNCTION_NAMES = "allowedFunctionNames"; private final UnifiedChatInput unifiedChatInput; private final 
GoogleVertexAiChatCompletionModel model; // TODO: This is not being used? @@ -110,6 +122,85 @@ private void buildContents(XContentBuilder builder) throws IOException { builder.endArray(); } + private void buildTools(XContentBuilder builder) throws IOException { + var request = unifiedChatInput.getRequest(); + + var tools = request.tools(); + if (tools == null || tools.isEmpty()) { + return; + } + + builder.startArray(TOOLS); + for (var tool : tools) { + if (FUNCTION_TYPE.equals(tool.type()) == false) { + var errorMessage = format( + "Tool type [%s] not supported by Google VertexAI ChatCompletion. Supported types: [%s]", + tool.type(), + FUNCTION_TYPE + ); + throw new ElasticsearchStatusException(errorMessage, RestStatus.BAD_REQUEST); + } + var function = tool.function(); + if (function == null) { + var errorMessage = format("Tool of type [%s] must have a function definition", tool.type()); + throw new ElasticsearchStatusException(errorMessage, RestStatus.BAD_REQUEST); + } + + builder.startObject(); + builder.startArray(FUNCTION_DECLARATIONS); + builder.startObject(); + + builder.field(FUNCTION_NAME, function.name()); + if (Strings.hasText(function.description())) { + builder.field(FUNCTION_DESCRIPTION, function.description()); + } + + if (function.parameters() != null && function.parameters().isEmpty() == false) { + builder.field(FUNCTION_PARAMETERS, function.parameters()); + } + + builder.endObject(); + builder.endArray(); + builder.endObject(); + } + builder.endArray(); + } + + private void buildToolConfig(XContentBuilder builder) throws IOException { + // Build the "tool_config" object (function calling config) + var request = unifiedChatInput.getRequest(); + UnifiedCompletionRequest.ToolChoiceObject toolChoice = (UnifiedCompletionRequest.ToolChoiceObject) request.toolChoice(); + if (toolChoice == null) { + return; + } + if (FUNCTION_TYPE.equals(toolChoice.type()) == false) { + var errorMessage = format( + "Tool choice type [%s] not supported by Google VertexAI 
ChatCompletion. Supported types: [%s]", + toolChoice.type(), + FUNCTION_TYPE + ); + throw new ElasticsearchStatusException(errorMessage, RestStatus.BAD_REQUEST); + } + + builder.startObject(TOOL_CONFIG); + builder.startObject(FUNCTION_CALLING_CONFIG); + + var chosenFunction = toolChoice.function(); + if (chosenFunction != null) { + // If we are using toolChoice we set the API to use the 'ANY', meaning that the model will call this tool + // We do that since it's the only supported way right now to make compatible the OpenAi spec with VertexAI spec + builder.field(TOOL_MODE, TOOL_MODE_ANY); + if (Strings.hasText(chosenFunction.name())) { + builder.startArray(ALLOWED_FUNCTION_NAMES); + builder.value(chosenFunction.name()); + builder.endArray(); + } + + builder.endObject(); + builder.endObject(); + } + } + private void buildGenerationConfig(XContentBuilder builder) throws IOException { var request = unifiedChatInput.getRequest(); @@ -146,6 +237,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws buildContents(builder); buildGenerationConfig(builder); + buildTools(builder); + buildToolConfig(builder); builder.endObject(); return builder; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java index 3ee48bac9ed78..827ed294b81e4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java @@ -7,10 +7,8 @@ package org.elasticsearch.xpack.inference.services.googlevertexai.request; 
-import org.apache.commons.lang3.NotImplementedException; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.Strings; -import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; @@ -20,15 +18,15 @@ import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel; import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModelTests; -import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionServiceSettings; -import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionTaskSettings; import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.Map; import static org.elasticsearch.xpack.inference.Utils.assertJsonEquals; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; public class GoogleVertexAiUnifiedChatCompletionRequestEntityTests extends ESTestCase { @@ -131,6 +129,149 @@ public void testSerialization_MultipleMessages() throws IOException { assertJsonEquals(jsonString, expectedJson); } + public void testSerialization_Tools() throws IOException { + var request = new UnifiedCompletionRequest( + List.of( + new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentObjects(List.of(new UnifiedCompletionRequest.ContentObject("some text", "text"))), + "user", + null, + null + ) + ), + "gemini-2.0", + null, + null, + null, + null, + List.of( + new UnifiedCompletionRequest.Tool( + "function", + new UnifiedCompletionRequest.Tool.FunctionField( + "Get the current weather in a given location", + "get_current_weather", + Map.of("type", "object", "description", "a 
description"), + null + ) + ) + ), + null + ); + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, false); + GoogleVertexAiChatCompletionModel model = createModel(); + + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + model + ); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "contents": [ + { + "role": "user", + "parts": [ { "text": "some text" } ] + } + ], + "tools": [ + { + "functionDeclarations": [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "description": "a description" + } + } + ] + } + ] + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + public void testSerialization_ToolsChoice() throws IOException { + var request = new UnifiedCompletionRequest( + List.of( + new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentObjects(List.of(new UnifiedCompletionRequest.ContentObject("some text", "text"))), + "user", + null, + null + ) + ), + "gemini-2.0", + null, + null, + null, + new UnifiedCompletionRequest.ToolChoiceObject( + "function", + new UnifiedCompletionRequest.ToolChoiceObject.FunctionField("some function") + ), + List.of( + new UnifiedCompletionRequest.Tool( + "function", + new UnifiedCompletionRequest.Tool.FunctionField( + "Get the current weather in a given location", + "get_current_weather", + Map.of("type", "object", "description", "a description"), + null + ) + ) + ), + null + ); + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, false); + GoogleVertexAiChatCompletionModel model = createModel(); + + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + model + ); + + 
XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "contents": [ + { + "role": "user", + "parts": [ { "text": "some text" } ] + } + ], + "tools": [ + { + "functionDeclarations": [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "description": "a description" + } + } + ] + } + ], + "toolConfig": { + "functionCallingConfig" : { + "mode": "ANY", + "allowedFunctionNames": [ "some function" ] + } + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + public void testSerialization_WithAllGenerationConfig() throws IOException { List messages = List.of( new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("Hello Gemini!"), USER_ROLE, null, null) @@ -348,7 +489,91 @@ public void testError_UnsupportedContentObjectType() throws IOException { assertEquals(RestStatus.BAD_REQUEST, statusException.status()); assertThat(statusException.toString(), containsString("Type [image_url] not supported by Google VertexAI ChatCompletion")); - } + public void testParseAllFields() throws IOException { + String requestJson = """ + { + "contents": [ + { + "role": "user", + "parts": [ + { + "text": "some text" + } + ] + } + ], + "generationConfig": { + "stopSequences": ["stop"], + "temperature": 0.1, + "maxOutputTokens": 100, + "topP": 0.2 + }, + "tools": [ + { + "functionDeclarations": [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object" + } + } + ] + } + ] + } + """; + + var request = new UnifiedCompletionRequest( + List.of( + new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentObjects(List.of(new UnifiedCompletionRequest.ContentObject("some text", "text"))), + "user", + "100", + List.of( + new 
UnifiedCompletionRequest.ToolCall( + "call_62136354", + new UnifiedCompletionRequest.ToolCall.FunctionField("{'order_id': 'order_12345'}", "get_delivery_date"), + "function" + ) + ) + ) + ), + "gemini-2.0", + 100L, + List.of("stop"), + 0.1F, + new UnifiedCompletionRequest.ToolChoiceObject( + "function", + new UnifiedCompletionRequest.ToolChoiceObject.FunctionField("some function") + ), + List.of( + new UnifiedCompletionRequest.Tool( + "function", + new UnifiedCompletionRequest.Tool.FunctionField( + "Get the current weather in a given location", + "get_current_weather", + Map.of("type", "object"), + null + ) + ) + ), + 0.2F + ); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, true); + var model = createModel(); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + model + ); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + assertJsonEquals(jsonString, requestJson); + } } From 813a2e85bd4bc344d6eafe5b1099ebebf7eb5950 Mon Sep 17 00:00:00 2001 From: lhoet Date: Mon, 12 May 2025 14:47:35 -0300 Subject: [PATCH 17/38] Serializing function call message --- ...xAiUnifiedChatCompletionRequestEntity.java | 41 ++++++++++++++++++- .../GoogleVertexAiServiceTests.java | 6 +++ ...ifiedChatCompletionRequestEntityTests.java | 19 +++++++-- 3 files changed, 60 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java index eeb7517773958..3d89b49bba87b 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java @@ -9,20 +9,25 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel; import java.io.IOException; +import java.util.Map; import java.util.Objects; import static org.elasticsearch.core.Strings.format; public class GoogleVertexAiUnifiedChatCompletionRequestEntity implements ToXContentObject { - // Field names matching the Google Vertex AI API structure private static final String CONTENTS = "contents"; private static final String ROLE = "role"; private static final String PARTS = "parts"; @@ -44,6 +49,10 @@ public class GoogleVertexAiUnifiedChatCompletionRequestEntity implements ToXCont private static final String TOOL_MODE_ANY = "ANY"; private static final String ALLOWED_FUNCTION_NAMES = "allowedFunctionNames"; + private static final String FUNCTION_CALL = "functionCall"; + private static final String FUNCTION_CALL_NAME = "name"; + private static final String FUNCTION_CALL_ARGS = "args"; + private final UnifiedChatInput 
unifiedChatInput; private final GoogleVertexAiChatCompletionModel model; // TODO: This is not being used? @@ -93,6 +102,23 @@ private void validateAndAddContentObjectsToBuilder(XContentBuilder builder, Unif } + private static Map jsonStringToMap(String jsonString) throws IOException { + if (jsonString == null || jsonString.isEmpty()) { + return null; + } + XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler( + LoggingDeprecationHandler.INSTANCE + ); + + try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, jsonString)) { + XContentParser.Token token = parser.nextToken(); + if (token != XContentParser.Token.START_OBJECT) { + throw new IOException("Expected JSON object to start with '{', but found " + token); + } + return parser.mapStrings(); + } + } + private void buildContents(XContentBuilder builder) throws IOException { var messages = unifiedChatInput.getRequest().messages(); @@ -116,6 +142,18 @@ private void buildContents(XContentBuilder builder) throws IOException { throw new ElasticsearchStatusException(errorMessage, RestStatus.BAD_REQUEST); } } + + if (message.toolCalls() != null && message.toolCalls().isEmpty() == false) { + var toolCalls = message.toolCalls(); + for (var toolCall : toolCalls) { + builder.startObject(); + builder.startObject(FUNCTION_CALL); + builder.field(FUNCTION_CALL_NAME, toolCall.function().name()); + builder.field(FUNCTION_CALL_ARGS, jsonStringToMap(toolCall.function().arguments())); + builder.endObject(); + builder.endObject(); + } + } builder.endArray(); builder.endObject(); } @@ -167,7 +205,6 @@ private void buildTools(XContentBuilder builder) throws IOException { } private void buildToolConfig(XContentBuilder builder) throws IOException { - // Build the "tool_config" object (function calling config) var request = unifiedChatInput.getRequest(); UnifiedCompletionRequest.ToolChoiceObject toolChoice = 
(UnifiedCompletionRequest.ToolChoiceObject) request.toolChoice(); if (toolChoice == null) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java index ca093b5cd27fb..30cc730f8f935 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java @@ -16,12 +16,15 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.ToXContent; @@ -29,8 +32,11 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; import org.elasticsearch.xpack.inference.services.ServiceFields; 
+import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModelTests; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModelTests; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java index 827ed294b81e4..b417abcf40601 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java @@ -498,8 +498,13 @@ public void testParseAllFields() throws IOException { { "role": "user", "parts": [ - { - "text": "some text" + { "text": "some text" }, + { "functionCall" : { + "name": "get_delivery_date", + "args": { + "order_id" : "order_12345" + } + } } ] } @@ -522,7 +527,13 @@ public void testParseAllFields() throws IOException { } ] } - ] + ], + "toolConfig": { + "functionCallingConfig" : { + "mode": "ANY", + "allowedFunctionNames": [ "some function" ] + } + } } """; @@ -535,7 +546,7 @@ public void testParseAllFields() throws IOException { List.of( new UnifiedCompletionRequest.ToolCall( "call_62136354", - new UnifiedCompletionRequest.ToolCall.FunctionField("{'order_id': 'order_12345'}", "get_delivery_date"), + new UnifiedCompletionRequest.ToolCall.FunctionField("{\"order_id\": \"order_12345\"}", 
"get_delivery_date"), "function" ) ) From f1ab8ccf66be8b0a5f7758d6fcc8e9dd58f3c6b8 Mon Sep 17 00:00:00 2001 From: lhoet Date: Mon, 12 May 2025 15:45:51 -0300 Subject: [PATCH 18/38] Response handler with tests --- ...iUnifiedChatCompletionResponseHandler.java | 3 - ...iedChatCompletionResponseHandlerTests.java | 140 ++++++++++++++++++ 2 files changed, 140 insertions(+), 3 deletions(-) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandlerTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java index a519cf2c0f211..3e549b30287e7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java @@ -80,7 +80,6 @@ protected Exception buildError(String message, Request request, HttpResult resul ); } - // TODO: This method was auto generated. Check that it's working properly private static Exception buildMidStreamError(Request request, String message, Exception e) { var errorResponse = GoogleVertexAiErrorResponse.fromString(message); if (errorResponse instanceof GoogleVertexAiErrorResponse gver) { @@ -142,7 +141,6 @@ static ErrorResponse fromResponse(HttpResult response) { ) { return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR); } catch (Exception e) { - // Log? 
For now, swallow and return undefined logger.warn("Failed to parse Google Vertex AI error response body", e); } return ErrorResponse.UNDEFINED_ERROR; @@ -155,7 +153,6 @@ static ErrorResponse fromString(String response) { ) { return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR); } catch (Exception e) { - // Log? For now, swallow and return undefined logger.warn("Failed to parse Google Vertex AI error string", e); } return ErrorResponse.UNDEFINED_ERROR; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandlerTests.java new file mode 100644 index 0000000000000..96195986574f1 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandlerTests.java @@ -0,0 +1,140 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.googlevertexai; + +import org.apache.http.HttpResponse; +import org.apache.http.StatusLine; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Strings; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.concurrent.Flow; + +import static org.elasticsearch.ExceptionsHelper.unwrapCause; +import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.isA; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class GoogleVertexAiUnifiedChatCompletionResponseHandlerTests extends ESTestCase { + private static final String INFERENCE_ID = "vertexAiInference"; + + private final GoogleVertexAiUnifiedChatCompletionResponseHandler responseHandler = + new GoogleVertexAiUnifiedChatCompletionResponseHandler( + "chat_completion", + (parser, xContentRegistry) -> mock() // Dummy parse function, not used in these error tests + ); + + public void testFailValidationWithAllErrorFields() throws IOException { + var 
responseJson = """ + { + "error": { + "code": 400, + "message": "Invalid JSON payload received.", + "status": "INVALID_ARGUMENT" + } + } + """; + + var errorJson = invalidResponseJson(responseJson); + + assertThat(errorJson, is(Strings.format(""" + {"error":{"code":"400","message":"Received a server error status code for request from inference entity id [%s] \ + status [500]. Error message: [Invalid JSON payload received.]","type":"INVALID_ARGUMENT"}}\ + """, INFERENCE_ID))); + } + + public void testFailValidationWithAllErrorFieldsAndDetails() throws IOException { + var responseJson = """ + { + "error": { + "code": 400, + "message": "Invalid JSON payload received.", + "status": "INVALID_ARGUMENT", + "details": [ + { "some":"value" } + ] + } + } + """; + + var errorJson = invalidResponseJson(responseJson); + + assertThat(errorJson, is(Strings.format(""" + {"error":{"code":"400","message":"Received a server error status code for request from inference entity id [%s] \ + status [500]. Error message: [Invalid JSON payload received.]","type":"INVALID_ARGUMENT"}}\ + """, INFERENCE_ID))); + } + + private static Request mockRequest() { + var request = mock(Request.class); + when(request.getInferenceEntityId()).thenReturn(INFERENCE_ID); + when(request.isStreaming()).thenReturn(true); + return request; + } + + private static HttpResponse mockHttpResponse(int statusCode) { + var statusLine = mock(StatusLine.class); + when(statusLine.getStatusCode()).thenReturn(statusCode); + + var response = mock(HttpResponse.class); + when(response.getStatusLine()).thenReturn(statusLine); + + return response; + } + + private String invalidResponseJson(String responseJson) throws IOException { + var exception = invalidResponse(responseJson); + assertThat(exception, isA(RetryException.class)); + assertThat(unwrapCause(exception), isA(UnifiedChatCompletionException.class)); + return toJson((UnifiedChatCompletionException) unwrapCause(exception)); + } + + private String 
toJson(UnifiedChatCompletionException e) throws IOException { + try (var builder = XContentFactory.jsonBuilder()) { + e.toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, EMPTY_PARAMS); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }); + return XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType()); + } + } + + private Exception invalidResponse(String responseJson) { + return expectThrows( + RetryException.class, + () -> responseHandler.validateResponse( + mock(), + mock(), + mockRequest(), + new HttpResult(mockHttpResponse(500), responseJson.getBytes(StandardCharsets.UTF_8)), + true + ) + ); + } + +} From 23c7d924d76b92fc6d9900cdebeddf388cdb9cc8 Mon Sep 17 00:00:00 2001 From: lhoet Date: Tue, 13 May 2025 11:34:13 -0300 Subject: [PATCH 19/38] VertexAI chat completion req entity bugfixes --- ...xAiUnifiedChatCompletionRequestEntity.java | 39 ++- ...ifiedChatCompletionRequestEntityTests.java | 224 ++++++++++++++++++ 2 files changed, 253 insertions(+), 10 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java index 3d89b49bba87b..6899af8ef8018 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java @@ -47,6 +47,7 @@ public class GoogleVertexAiUnifiedChatCompletionRequestEntity implements ToXCont private static final String FUNCTION_CALLING_CONFIG = "functionCallingConfig"; private static final String 
TOOL_MODE = "mode"; private static final String TOOL_MODE_ANY = "ANY"; + private static final String TOOL_MODE_AUTO = "auto"; private static final String ALLOWED_FUNCTION_NAMES = "allowedFunctionNames"; private static final String FUNCTION_CALL = "functionCall"; @@ -138,8 +139,7 @@ private void buildContents(XContentBuilder builder) throws IOException { contentObjects ); case null -> { - var errorMessage = "Google VertexAI API requires at least one text message but none were provided"; - throw new ElasticsearchStatusException(errorMessage, RestStatus.BAD_REQUEST); + // Content can be null and that's fine. If this case is not present, Null pointer exception will be thrown } } @@ -169,6 +169,8 @@ private void buildTools(XContentBuilder builder) throws IOException { } builder.startArray(TOOLS); + builder.startObject(); + builder.startArray(FUNCTION_DECLARATIONS); for (var tool : tools) { if (FUNCTION_TYPE.equals(tool.type()) == false) { var errorMessage = format( @@ -184,10 +186,8 @@ private void buildTools(XContentBuilder builder) throws IOException { throw new ElasticsearchStatusException(errorMessage, RestStatus.BAD_REQUEST); } - builder.startObject(); - builder.startArray(FUNCTION_DECLARATIONS); - builder.startObject(); + builder.startObject(); builder.field(FUNCTION_NAME, function.name()); if (Strings.hasText(function.description())) { builder.field(FUNCTION_DESCRIPTION, function.description()); @@ -198,17 +198,36 @@ private void buildTools(XContentBuilder builder) throws IOException { } builder.endObject(); - builder.endArray(); - builder.endObject(); } builder.endArray(); + builder.endObject(); + builder.endArray(); } private void buildToolConfig(XContentBuilder builder) throws IOException { var request = unifiedChatInput.getRequest(); - UnifiedCompletionRequest.ToolChoiceObject toolChoice = (UnifiedCompletionRequest.ToolChoiceObject) request.toolChoice(); - if (toolChoice == null) { - return; + + UnifiedCompletionRequest.ToolChoiceObject toolChoice; + switch 
(request.toolChoice()) { + case UnifiedCompletionRequest.ToolChoiceObject toolChoiceObject -> { + toolChoice = toolChoiceObject; + } + case UnifiedCompletionRequest.ToolChoiceString toolChoiceString -> { + if (toolChoiceString.value().equals(TOOL_MODE_AUTO)) { + return; + } + throw new ElasticsearchStatusException( + format( + "Tool choice value [%s] not supported by Google VertexAI ChatCompletion. Supported values: [%s]", + toolChoiceString.value(), + TOOL_MODE_AUTO + ), + RestStatus.BAD_REQUEST + ); + } + case null -> { + return; + } } if (FUNCTION_TYPE.equals(toolChoice.type()) == false) { var errorMessage = format( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java index b417abcf40601..bab77ee1500a2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java @@ -587,4 +587,228 @@ public void testParseAllFields() throws IOException { String jsonString = Strings.toString(builder); assertJsonEquals(jsonString, requestJson); } + + public void testParseFunctionCallNoContent() throws IOException { + String requestJson = """ + { + "contents": [ + { + "role": "model", + "parts": [ + { "functionCall" : { + "name": "get_delivery_date", + "args": { + "order_id" : "order_12345" + } + } + } + ] + } + ] + } + """; + + var request = new UnifiedCompletionRequest( + List.of( + new UnifiedCompletionRequest.Message( + null, + "model", + "100", + List.of( + new UnifiedCompletionRequest.ToolCall( + "call_62136354", + 
new UnifiedCompletionRequest.ToolCall.FunctionField("{\"order_id\": \"order_12345\"}", "get_delivery_date"), + "function" + ) + ) + ) + ), + "gemini-2.0", + null, + null, + null, + null, + null, + null + ); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, true); + var model = createModel(); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + model + ); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + assertJsonEquals(jsonString, requestJson); + } + + public void testParseToolChoiceString() throws IOException { + String requestJson = """ + { + "contents": [ + { + "role": "user", + "parts": [ + { "text": "some text" } + ] + } + ] + } + """; + + var request = new UnifiedCompletionRequest( + List.of( + new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentObjects(List.of(new UnifiedCompletionRequest.ContentObject("some text", "text"))), + "user", + null, + null + ) + ), + "gemini-2.0", + null, + null, + null, + new UnifiedCompletionRequest.ToolChoiceString("auto"), + null, + null + ); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, true); + var model = createModel(); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + model + ); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + assertJsonEquals(jsonString, requestJson); + } + + public void testParseToolChoiceInvalid_throwElasticSearchStatusException() throws IOException { + var request = new UnifiedCompletionRequest( + List.of( + new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentObjects(List.of(new 
UnifiedCompletionRequest.ContentObject("some text", "text"))), + "user", + null, + null + ) + ), + "gemini-2.0", + null, + null, + null, + new UnifiedCompletionRequest.ToolChoiceString("unsupported"), + null, + null + ); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, true); + var model = createModel(); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + model + ); + + XContentBuilder builder = JsonXContent.contentBuilder(); + var statusException = expectThrows(ElasticsearchStatusException.class, () -> entity.toXContent(builder, ToXContent.EMPTY_PARAMS)); + assertThat( + statusException.toString(), + containsString("Tool choice value [unsupported] not supported by Google VertexAI ChatCompletion.") + ); + + } + + public void testParseMultipleTools() throws IOException { + String requestJson = """ + { + "contents": [ + { + "role": "user", + "parts": [ + { "text": "some text" } + ] + } + ], + "tools": [ + { + "functionDeclarations": [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object" + } + }, + { + "name": "get_current_temperature", + "description": "Get the current temperature in a location", + "parameters": { + "type": "object" + } + } + ] + } + ] + } + """; + + var request = new UnifiedCompletionRequest( + List.of( + new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentObjects(List.of(new UnifiedCompletionRequest.ContentObject("some text", "text"))), + "user", + null, + null + ) + ), + "gemini-2.0", + null, + null, + null, + null, + List.of( + new UnifiedCompletionRequest.Tool( + "function", + new UnifiedCompletionRequest.Tool.FunctionField( + "Get the current weather in a given location", + "get_current_weather", + Map.of("type", "object"), + null + ) + ), + new UnifiedCompletionRequest.Tool( + "function", + new 
UnifiedCompletionRequest.Tool.FunctionField( + "Get the current temperature in a location", + "get_current_temperature", + Map.of("type", "object"), + null + ) + ) + ), + null + ); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, true); + var model = createModel(); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + model + ); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + assertJsonEquals(jsonString, requestJson); + } } From c45d23f5e46d6c93a8729f65e9bd1fe1e620f122 Mon Sep 17 00:00:00 2001 From: lhoet Date: Tue, 13 May 2025 13:32:08 -0300 Subject: [PATCH 20/38] Bugfix in vertex ai unified chat completion req entity --- ...xAiUnifiedChatCompletionRequestEntity.java | 11 ++- ...ifiedChatCompletionRequestEntityTests.java | 87 +++++++++++++++++++ 2 files changed, 97 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java index 6899af8ef8018..2d64c25c2aef1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java @@ -63,7 +63,7 @@ public class GoogleVertexAiUnifiedChatCompletionRequestEntity implements ToXCont public GoogleVertexAiUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, GoogleVertexAiChatCompletionModel model) { this.unifiedChatInput = 
Objects.requireNonNull(unifiedChatInput); - this.model = Objects.requireNonNull(model); // Keep the model reference + this.model = Objects.requireNonNull(model); } private String messageRoleToGoogleVertexAiSupportedRole(String messageRole) throws IOException { @@ -94,6 +94,11 @@ private void validateAndAddContentObjectsToBuilder(XContentBuilder builder, Unif ); throw new ElasticsearchStatusException(errorMessage, RestStatus.BAD_REQUEST); } + + if (contentObject.text().isEmpty()) { + return; // VertexAI API does not support empty text parts + } + // We are only supporting Text messages but VertexAI supports more types: // https://cloud.google.com/vertex-ai/docs/reference/rest/v1/Content?_gl=1*q4uxnh*_up*MQ..&gclid=CjwKCAjwwqfABhBcEiwAZJjC3uBQNP9KUMZX8AGXvFXP2rIEQSfCX9RLP5gjzx5r-4xz1daBSxM7GBoCY64QAvD_BwE&gclsrc=aw.ds#Part builder.startObject(); @@ -130,6 +135,9 @@ private void buildContents(XContentBuilder builder) throws IOException { builder.startArray(PARTS); switch (message.content()) { case UnifiedCompletionRequest.ContentString contentString -> { + if (contentString.content().isEmpty()) { + break; // VertexAI API does not support empty text parts + } builder.startObject(); builder.field(TEXT, contentString.content()); builder.endObject(); @@ -297,6 +305,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws buildToolConfig(builder); builder.endObject(); + var s = Strings.toString(builder); return builder; } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java index bab77ee1500a2..e8305ca4a44bd 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java @@ -26,6 +26,7 @@ import static org.elasticsearch.xpack.inference.Utils.assertJsonEquals; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.endsWith; import static org.hamcrest.Matchers.is; public class GoogleVertexAiUnifiedChatCompletionRequestEntityTests extends ESTestCase { @@ -646,6 +647,92 @@ public void testParseFunctionCallNoContent() throws IOException { assertJsonEquals(jsonString, requestJson); } + public void testParseFunctionCallWithEmptyStringContent() throws IOException { + String requestJson = """ + { + "contents": [ + { + "role": "model", + "parts": [ + { "functionCall" : { + "name": "get_delivery_date", + "args": { + "order_id" : "order_12345" + } + } + } + ] + } + ] + } + """; + + var requestContentObject = new UnifiedCompletionRequest( + List.of( + new UnifiedCompletionRequest.Message( + // new UnifiedCompletionRequest.ContentObject("", "text"), + new UnifiedCompletionRequest.ContentObjects(List.of(new UnifiedCompletionRequest.ContentObject("", "text"))), + "model", + null, + List.of( + new UnifiedCompletionRequest.ToolCall( + "call_62136354", + new UnifiedCompletionRequest.ToolCall.FunctionField("{\"order_id\": \"order_12345\"}", "get_delivery_date"), + "function" + ) + ) + ) + ), + "gemini-2.0", + null, + null, + null, + null, + null, + null + ); + + var requestContentString = new UnifiedCompletionRequest( + List.of( + new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString(""), + "model", + null, + List.of( + new UnifiedCompletionRequest.ToolCall( + "call_62136354", + new UnifiedCompletionRequest.ToolCall.FunctionField("{\"order_id\": 
\"order_12345\"}", "get_delivery_date"), + "function" + ) + ) + ) + ), + "gemini-2.0", + null, + null, + null, + null, + null, + null + ); + var requests = List.of(requestContentObject, requestContentString); + + for (var request : requests) { + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, true); + var model = createModel(); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + model + ); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + assertJsonEquals(jsonString, requestJson); + } + } + public void testParseToolChoiceString() throws IOException { String requestJson = """ { From a820d83fce1a8f1de9852286ffb0730d04974d52 Mon Sep 17 00:00:00 2001 From: lhoet Date: Tue, 13 May 2025 13:43:48 -0300 Subject: [PATCH 21/38] Bugfix in vertex ai unified streaming processor --- .../GoogleVertexAiUnifiedStreamingProcessor.java | 3 ++- .../GoogleVertexAiUnifiedStreamingProcessorTests.java | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessor.java index b2c33cb294a08..1c824e1126fa7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessor.java @@ -60,6 +60,7 @@ public class GoogleVertexAiUnifiedStreamingProcessor extends DelegatingProcessor private static final String FUNCTION_NAME_FIELD = "name"; private static final String 
FUNCTION_ARGS_FIELD = "args"; + private static final String CHAT_COMPLETION_CHUNK = "chat.completion.chunk"; private static final String FUNCTION_TYPE = "function"; private final BiFunction errorParser; @@ -206,7 +207,7 @@ private static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice responseId, choices, modelversion, - null, + CHAT_COMPLETION_CHUNK, usageMetadataToChunk(usage) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessorTests.java index 7bf5c5f2df30c..b5833823569f6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessorTests.java @@ -57,13 +57,14 @@ public void testJsonLiteral() { assertEquals("responseId", chunk.id()); assertEquals(1, chunk.choices().size()); + assertEquals("chat.completion.chunk", chunk.object()); var choice = chunk.choices().getFirst(); assertEquals("Elastic", choice.delta().content()); assertEquals("model", choice.delta().role()); assertEquals("gemini-2.0-flash-lite", chunk.model()); - assertEquals("MAXTOKENS", choice.finishReason()); assertEquals(0, choice.index()); // VertexAI response does not have Index. 
Use 0 as default + assertEquals("MAXTOKENS", choice.finishReason()); assertEquals(1, choice.delta().toolCalls().size()); var toolCall = choice.delta().toolCalls().getFirst(); From d2f09cfd49515fb4749f34f835c7cb42d6e026ad Mon Sep 17 00:00:00 2001 From: lhoet Date: Tue, 13 May 2025 13:46:49 -0300 Subject: [PATCH 22/38] Removed google aiplatform sdk --- gradle/verification-metadata.xml | 10 ---------- x-pack/plugin/inference/build.gradle | 3 --- 2 files changed, 13 deletions(-) diff --git a/gradle/verification-metadata.xml b/gradle/verification-metadata.xml index cb6d2f779f3ef..d546e80d1a8a4 100644 --- a/gradle/verification-metadata.xml +++ b/gradle/verification-metadata.xml @@ -481,11 +481,6 @@ - - - - - @@ -566,11 +561,6 @@ - - - - - diff --git a/x-pack/plugin/inference/build.gradle b/x-pack/plugin/inference/build.gradle index dc0066634713d..b0657968f00fc 100644 --- a/x-pack/plugin/inference/build.gradle +++ b/x-pack/plugin/inference/build.gradle @@ -107,9 +107,6 @@ dependencies { /* SLF4J (via AWS SDKv2) */ api "org.slf4j:slf4j-api:${versions.slf4j}" runtimeOnly "org.slf4j:slf4j-nop:${versions.slf4j}" - /* Google aiplatform SDK */ - implementation 'com.google.cloud:google-cloud-aiplatform:3.61.0' - api "com.google.api:gax:2.64.2" } tasks.named("dependencyLicenses").configure { From bda94de6bf10cbeb0e97e958848eeb1bceb9db1c Mon Sep 17 00:00:00 2001 From: lhoet Date: Tue, 13 May 2025 14:07:16 -0300 Subject: [PATCH 23/38] Renamed file to match class name for JsonArrayPartsEventParser --- .../JsonArrayPartsEventProcessor.java | 3 - ...va => JsonArrayPartsEventParserTests.java} | 96 ++++--------------- 2 files changed, 18 insertions(+), 81 deletions(-) rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/{JsonPartsEventParserTests.java => JsonArrayPartsEventParserTests.java} (68%) diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/JsonArrayPartsEventProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/JsonArrayPartsEventProcessor.java index f66f415dbb266..6210f1b91b3d9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/JsonArrayPartsEventProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/JsonArrayPartsEventProcessor.java @@ -15,7 +15,6 @@ public class JsonArrayPartsEventProcessor extends DelegatingProcessor> { private final JsonArrayPartsEventParser jsonArrayPartsEventParser; - // TODO: This class is missing unit testing public JsonArrayPartsEventProcessor(JsonArrayPartsEventParser jsonArrayPartsEventParser) { this.jsonArrayPartsEventParser = jsonArrayPartsEventParser; } @@ -23,14 +22,12 @@ public JsonArrayPartsEventProcessor(JsonArrayPartsEventParser jsonArrayPartsEven @Override public void next(HttpResult item) { if (item.isBodyEmpty()) { - // discard empty result and go to the next upstream().request(1); return; } var response = jsonArrayPartsEventParser.parse(item.body()); if (response.isEmpty()) { - // discard empty result and go to the next upstream().request(1); return; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/JsonPartsEventParserTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/JsonArrayPartsEventParserTests.java similarity index 68% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/JsonPartsEventParserTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/JsonArrayPartsEventParserTests.java index 43af0ca3c691f..4e0b505473a6b 
100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/JsonPartsEventParserTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/JsonArrayPartsEventParserTests.java @@ -15,7 +15,7 @@ import static org.hamcrest.Matchers.equalTo; -public class JsonPartsEventParserTests extends ESTestCase { +public class JsonArrayPartsEventParserTests extends ESTestCase { private void assertJsonParts(Deque actualParts, List expectedJsonStrings) { assertThat("Number of parsed parts mismatch", actualParts.size(), equalTo(expectedJsonStrings.size())); @@ -31,12 +31,13 @@ public void testParse_givenNullOrEmptyBytes_returnsEmptyDeque() { assertTrue(parser.parse(null).isEmpty()); assertTrue(parser.parse(new byte[0]).isEmpty()); - // Test with pre-existing incomplete part - parser.parse("{".getBytes(StandardCharsets.UTF_8)); // Create an incomplete part + var incompletePart = "{".getBytes(StandardCharsets.UTF_8); + parser.parse(incompletePart); assertTrue(parser.parse(null).isEmpty()); assertTrue(parser.parse(new byte[0]).isEmpty()); - // Check that the incomplete part is still there - Deque parts = parser.parse("}".getBytes(StandardCharsets.UTF_8)); + + var missingPart = "}".getBytes(StandardCharsets.UTF_8); + Deque parts = parser.parse(missingPart); assertJsonParts(parts, List.of("{}")); } @@ -52,9 +53,10 @@ public void testParse_multipleCompleteObjectsInOneChunk_returnsMultipleParts() { JsonArrayPartsEventParser parser = new JsonArrayPartsEventParser(); String json1 = "{\"key1\":\"value1\"}"; String json2 = "{\"key2\":\"value2\"}"; - // Simulating a JSON array structure, the parser extracts {} + byte[] input = ("[" + json1 + "," + json2 + "]").getBytes(StandardCharsets.UTF_8); Deque parts = parser.parse(input); + assertJsonParts(parts, List.of(json1, json2)); } @@ -62,8 +64,10 @@ public void testParse_twoObjectsBackToBack_extractsBoth() { JsonArrayPartsEventParser 
parser = new JsonArrayPartsEventParser(); String json1 = "{\"a\":1}"; String json2 = "{\"b\":2}"; + byte[] input = (json1 + json2).getBytes(StandardCharsets.UTF_8); Deque parts = parser.parse(input); + assertJsonParts(parts, List.of(json1, json2)); } @@ -99,14 +103,6 @@ public void testParse_multipleObjectsSomeSplit_returnsPartsIncrementally() { assertTrue("Expected no more parts from empty call", parser.parse(new byte[0]).isEmpty()); } - public void testParse_withArrayBracketsAndCommas_extractsObjects() { - JsonArrayPartsEventParser parser = new JsonArrayPartsEventParser(); - String json1 = "{\"a\":1}"; - String json2 = "{\"b\":2}"; - byte[] input = (" [ " + json1 + " , " + json2 + " ] ").getBytes(StandardCharsets.UTF_8); - Deque parts = parser.parse(input); - assertJsonParts(parts, List.of(json1, json2)); - } public void testParse_nestedObjects_extractsTopLevelObject() { JsonArrayPartsEventParser parser = new JsonArrayPartsEventParser(); @@ -136,7 +132,7 @@ public void testParse_endsWithIncompleteObject_buffersCorrectly() { byte[] input = (json1 + "," + partialJsonStart).getBytes(StandardCharsets.UTF_8); Deque parts = parser.parse(input); - assertJsonParts(parts, List.of(json1)); // Only the complete one + assertJsonParts(parts, List.of(json1)); String partialJsonEnd = "continue\"}"; String json2 = partialJsonStart + partialJsonEnd; @@ -160,28 +156,27 @@ public void testParse_onlyCloseBrace_ignored() { JsonArrayPartsEventParser parser = new JsonArrayPartsEventParser(); byte[] input = "}".getBytes(StandardCharsets.UTF_8); Deque parts = parser.parse(input); - assertTrue(parts.isEmpty()); // Should be ignored as no open brace context + assertTrue(parts.isEmpty()); - // With preceding data parts = parser.parse("some data }".getBytes(StandardCharsets.UTF_8)); assertTrue(parts.isEmpty()); } public void testParse_mismatchedBraces_handlesGracefully() { JsonArrayPartsEventParser parser = new JsonArrayPartsEventParser(); - // Extra closing brace + byte[] input1 = 
"{\"key\":\"val\"}}".getBytes(StandardCharsets.UTF_8); Deque parts1 = parser.parse(input1); assertJsonParts(parts1, List.of("{\"key\":\"val\"}")); // First object is fine, extra '}' ignored - // Extra opening brace at end - parser = new JsonArrayPartsEventParser(); // reset + parser = new JsonArrayPartsEventParser(); byte[] input2 = "{\"key\":\"val\"}{".getBytes(StandardCharsets.UTF_8); Deque parts2 = parser.parse(input2); assertJsonParts(parts2, List.of("{\"key\":\"val\"}")); // First object + // The last '{' should be buffered Deque parts3 = parser.parse("}".getBytes(StandardCharsets.UTF_8)); - assertJsonParts(parts3, List.of("{}")); // Completes the buffered '{' + assertJsonParts(parts3, List.of("{}")); } public void testParse_objectWithMultiByteChars_handlesCorrectly() { @@ -191,9 +186,8 @@ public void testParse_objectWithMultiByteChars_handlesCorrectly() { Deque parts = parser.parse(input); assertJsonParts(parts, List.of(json)); - // Split case - parser = new JsonArrayPartsEventParser(); // reset - String part1Str = "{\"key\":\"value_with_emoji_😊"; // Split within multi-byte char or after + parser = new JsonArrayPartsEventParser(); + String part1Str = "{\"key\":\"value_with_emoji_😊"; String part2Str = "_and_résumé\"}"; byte[] chunk1 = part1Str.getBytes(StandardCharsets.UTF_8); byte[] chunk2 = part2Str.getBytes(StandardCharsets.UTF_8); @@ -212,15 +206,12 @@ public void testParse_javadocExampleStream() { String json3 = "{\"key3\":\"val3\"}"; String json4 = "{\"some\":\"object\"}"; - // Chunk 1: [{"key":"val1"} Deque parts1 = parser.parse(("[{\"key\":\"val1\"}").getBytes(StandardCharsets.UTF_8)); assertJsonParts(parts1, List.of(json1)); - // Chunk 2: ,{"key2":"val2"} Deque parts2 = parser.parse((",{\"key2\":\"val2\"}").getBytes(StandardCharsets.UTF_8)); assertJsonParts(parts2, List.of(json2)); - // Chunk 3: ,{"key3":"val3"}, {"some":"object"}] Deque parts3 = parser.parse((",{\"key3\":\"val3\"}, {\"some\":\"object\"}]").getBytes(StandardCharsets.UTF_8)); 
assertJsonParts(parts3, List.of(json3, json4)); } @@ -242,55 +233,4 @@ public void testParse_dataBeforeFirstObjectAndAfterLastObject() { Deque parts = parser.parse(input); assertJsonParts(parts, List.of(json1, json2)); } - - public void testParse_incompleteObjectNeverCompleted() { - JsonArrayPartsEventParser parser = new JsonArrayPartsEventParser(); - byte[] chunk1 = "{\"key\":".getBytes(StandardCharsets.UTF_8); - Deque parts1 = parser.parse(chunk1); - assertTrue(parts1.isEmpty()); - - // Send another chunk that doesn't complete the first object but starts a new one - byte[] chunk2 = "{\"anotherKey\":\"value\"}".getBytes(StandardCharsets.UTF_8); - Deque parts2 = parser.parse(chunk2); - // The incomplete "{\"key\":" is overwritten by the new complete object "{\"anotherKey\":\"value\"}" - // because objectStartIndex will be reset to the start of the new object. - // The previous incompletePart is combined, but if a new '{' is found at brace level 0, - // objectStartIndex is updated. The old incomplete part is effectively discarded if not completed. - // Let's trace: - // After chunk1: incompletePart = "{\"key\":" - // parse(chunk2): dataToProcess = "{\"key\":{\"anotherKey\":\"value\"}" - // incompletePart.reset() - // Loop: - // '{' -> objectStartIndex=0, braceLevel=1 - // ... - // ':' -> - // '{' -> objectStartIndex=7 (THIS IS THE KEY: if braceLevel is >0, objectStartIndex is NOT reset) - // So the outer object is still being tracked. 
- // '}' -> braceLevel becomes 1 (for inner) - // '}' -> braceLevel becomes 0 (for outer) -> emits "{\"key\":{\"anotherKey\":\"value\"}}" - // This means the test case needs to be: - // Chunk1: {"key": - // Chunk2: "value"} , {"next":1} - // Expected: {"key":"value"}, {"next":1} - - // Corrected test for incomplete object handling: - parser = new JsonArrayPartsEventParser(); // Reset - parts1 = parser.parse("{\"key\":".getBytes(StandardCharsets.UTF_8)); - assertTrue(parts1.isEmpty()); - - Deque partsAfterCompletion = parser.parse("\"value\"}".getBytes(StandardCharsets.UTF_8)); - assertJsonParts(partsAfterCompletion, List.of("{\"key\":\"value\"}")); - - // If an incomplete part is followed by non-JSON or unrelated data - parser = new JsonArrayPartsEventParser(); // Reset - parts1 = parser.parse("{\"key\":".getBytes(StandardCharsets.UTF_8)); - assertTrue(parts1.isEmpty()); - // Send some data that doesn't complete it and doesn't start a new valid object - Deque partsNoCompletion = parser.parse("some other data without braces".getBytes(StandardCharsets.UTF_8)); - assertTrue(partsNoCompletion.isEmpty()); - // The incomplete part should still be "{\"key\":some other data without braces" - // Now complete it - Deque finalParts = parser.parse("}".getBytes(StandardCharsets.UTF_8)); - assertJsonParts(finalParts, List.of("{\"key\":some other data without braces}")); - } } From 5dee072dfd41ceb52b497dd45c82bc9d96059c2c Mon Sep 17 00:00:00 2001 From: lhoet Date: Tue, 13 May 2025 14:14:27 -0300 Subject: [PATCH 24/38] Updated rate limit settings for vertex ai --- .../GoogleVertexAiChatCompletionServiceSettings.java | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionServiceSettings.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionServiceSettings.java index 6f733ea71ef50..a59954d6a51a7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionServiceSettings.java @@ -38,8 +38,6 @@ public class GoogleVertexAiChatCompletionServiceSettings extends FilteredXConten public static final String NAME = "google_vertex_ai_chatcompletion_service_settings"; - // TODO: Other fields can be missing here. Mostly the ones the ones that are described here - // https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.endpoints/streamGenerateContent?_gl=1*15nhtzk*_up*MQ..&gclid=CjwKCAjwwqfABhBcEiwAZJjC3uBQNP9KUMZX8AGXvFXP2rIEQSfCX9RLP5gjzx5r-4xz1daBSxM7GBoCY64QAvD_BwE&gclsrc=aw.ds private final String location; private final String modelId; private final String projectId; @@ -47,8 +45,7 @@ public class GoogleVertexAiChatCompletionServiceSettings extends FilteredXConten private final RateLimitSettings rateLimitSettings; // https://cloud.google.com/vertex-ai/docs/quotas#eval-quotas - // TODO: this may be wrong. 
Double check before submitting the PR) - private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(2000); + private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(1000); @Override protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, ToXContent.Params params) throws IOException { From 2f757885410f509d15e857b5bfb0bf28affcecbc Mon Sep 17 00:00:00 2001 From: lhoet Date: Tue, 13 May 2025 15:31:42 -0300 Subject: [PATCH 25/38] Deleted GoogleVertexAiChatCompletionTaskSettings --- .../GoogleVertexAiChatCompletionModel.java | 9 +-- ...gleVertexAiChatCompletionTaskSettings.java | 57 ------------------- ...oogleVertexAiChatCompletionModelTests.java | 21 ++++--- 3 files changed, 15 insertions(+), 72 deletions(-) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionTaskSettings.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java index dc436fedca35d..3574f9a051acb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java @@ -9,6 +9,7 @@ import org.apache.http.client.utils.URIBuilder; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.TaskType; @@ -44,7 +45,7 @@ public 
GoogleVertexAiChatCompletionModel( taskType, service, GoogleVertexAiChatCompletionServiceSettings.fromMap(serviceSettings, context), - GoogleVertexAiChatCompletionTaskSettings.fromMap(taskSettings), + new EmptyTaskSettings(), GoogleVertexAiSecretSettings.fromMap(secrets) ); } @@ -54,7 +55,7 @@ public GoogleVertexAiChatCompletionModel( TaskType taskType, String service, GoogleVertexAiChatCompletionServiceSettings serviceSettings, - GoogleVertexAiChatCompletionTaskSettings taskSettings, + EmptyTaskSettings taskSettings, @Nullable GoogleVertexAiSecretSettings secrets ) { super( @@ -113,8 +114,8 @@ public GoogleVertexAiChatCompletionServiceSettings getServiceSettings() { } @Override - public GoogleVertexAiChatCompletionTaskSettings getTaskSettings() { - return (GoogleVertexAiChatCompletionTaskSettings) super.getTaskSettings(); + public EmptyTaskSettings getTaskSettings() { + return (EmptyTaskSettings) super.getTaskSettings(); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionTaskSettings.java deleted file mode 100644 index 7f78d83198db7..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionTaskSettings.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.services.googlevertexai.completion; - -import org.apache.commons.lang3.NotImplementedException; -import org.elasticsearch.TransportVersion; -import org.elasticsearch.TransportVersions; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.inference.TaskSettings; -import org.elasticsearch.xcontent.XContentBuilder; - -import java.io.IOException; -import java.util.Map; - -// TODO: This class may no be needed. Keeping this class here to keep the compiler happy, but if not needed we could replace it with `EmptyTaskSettings` -public class GoogleVertexAiChatCompletionTaskSettings implements TaskSettings { - public static final String NAME = "google_vertex_ai_chatcompletion_task_settings"; - - @Override - public boolean isEmpty() { - return true; - } - - @Override - public TaskSettings updatedTaskSettings(Map newSettings) { - return this; - } - - @Override - public String getWriteableName() { - return NAME; - } - - @Override - public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED; - } - - @Override - public void writeTo(StreamOutput out) throws IOException {} - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.endObject(); - return builder; - } - - public static GoogleVertexAiChatCompletionTaskSettings fromMap(Map map) { - return new GoogleVertexAiChatCompletionTaskSettings(); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModelTests.java index a4950911a804e..916857952b988 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModelTests.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.services.googlevertexai.completion; import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.test.ESTestCase; @@ -46,10 +47,8 @@ public void testOverrideWith_UnifiedCompletionRequest_OverridesModelId() { var overriddenModel = GoogleVertexAiChatCompletionModel.of(model, request); - // Check that the model ID is overridden assertThat(overriddenModel.getServiceSettings().modelId(), is("gemini-flash")); - // Check that other settings remain the same assertThat(overriddenModel, not(sameInstance(model))); assertThat(overriddenModel.getServiceSettings().projectId(), is(DEFAULT_PROJECT_ID)); assertThat(overriddenModel.getServiceSettings().location(), is(DEFAULT_LOCATION)); @@ -62,7 +61,7 @@ public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenReques var model = createCompletionModel(DEFAULT_PROJECT_ID, DEFAULT_LOCATION, DEFAULT_MODEL_ID, DEFAULT_API_KEY, DEFAULT_RATE_LIMIT); var request = new UnifiedCompletionRequest( List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)), - null, // Not overriding model + null, null, null, null, @@ -73,16 +72,14 @@ public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenReques var overriddenModel = GoogleVertexAiChatCompletionModel.of(model, request); - // Check that the model ID is NOT overridden assertThat(overriddenModel.getServiceSettings().modelId(), is(DEFAULT_MODEL_ID)); - // Check that 
other settings remain the same + assertThat(overriddenModel.getServiceSettings().projectId(), is(DEFAULT_PROJECT_ID)); assertThat(overriddenModel.getServiceSettings().location(), is(DEFAULT_LOCATION)); assertThat(overriddenModel.getServiceSettings().rateLimitSettings(), is(DEFAULT_RATE_LIMIT)); assertThat(overriddenModel.getSecretSettings().serviceAccountJson(), equalTo(new SecureString(DEFAULT_API_KEY.toCharArray()))); - assertThat(overriddenModel.getTaskSettings(), is(model.getTaskSettings())); // Task settings shouldn't change - // Since nothing changed in service settings, it *could* return the same instance, - // but the current implementation always creates a new one. Let's assert it's not the same. + assertThat(overriddenModel.getTaskSettings(), is(model.getTaskSettings())); + assertThat(overriddenModel, not(sameInstance(model))); } @@ -91,7 +88,8 @@ public void testBuildUri() throws URISyntaxException { String projectId = "my-gcp-project"; String model = "gemini-1.5-flash-001"; URI expectedUri = new URI( - "https://us-east1-aiplatform.googleapis.com/v1/projects/my-gcp-project/locations/global/publishers/google/models/gemini-1.5-flash-001:streamGenerateContent" + "https://us-east1-aiplatform.googleapis.com/v1/projects/my-gcp-project" + + "/locations/global/publishers/google/models/gemini-1.5-flash-001:streamGenerateContent" ); URI actualUri = GoogleVertexAiChatCompletionModel.buildUri(location, projectId, model); assertThat(actualUri, is(expectedUri)); @@ -102,7 +100,8 @@ public void testBuildUri_WithDifferentValues() throws URISyntaxException { String projectId = "another-project-123"; String model = "gemini-pro"; URI expectedUri = new URI( - "https://europe-west2-aiplatform.googleapis.com/v1/projects/another-project-123/locations/global/publishers/google/models/gemini-pro:streamGenerateContent" + "https://europe-west2-aiplatform.googleapis.com/v1/projects/another-project-123/" + + "locations/global/publishers/google/models/gemini-pro:streamGenerateContent" 
); URI actualUri = GoogleVertexAiChatCompletionModel.buildUri(location, projectId, model); assertThat(actualUri, is(expectedUri)); @@ -120,7 +119,7 @@ public static GoogleVertexAiChatCompletionModel createCompletionModel( TaskType.CHAT_COMPLETION, "google_vertex_ai", new GoogleVertexAiChatCompletionServiceSettings(projectId, location, modelId, rateLimitSettings), - new GoogleVertexAiChatCompletionTaskSettings(), + new EmptyTaskSettings(), new GoogleVertexAiSecretSettings(new SecureString(apiKey.toCharArray())) ); } From b50c911cb4c51ffbc1f4b2e460621bc52c093682 Mon Sep 17 00:00:00 2001 From: lhoet Date: Wed, 14 May 2025 09:21:33 -0300 Subject: [PATCH 26/38] VertexAI Unified chat completion request tests --- ...exAiUnifiedChatCompletionRequestTests.java | 83 +++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestTests.java diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestTests.java new file mode 100644 index 0000000000000..bd080ab7fe9aa --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestTests.java @@ -0,0 +1,83 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.googlevertexai.request; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.core.Strings; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.util.List; + +import static org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModelTests.createCompletionModel; +import static org.hamcrest.Matchers.startsWith; +import static org.hamcrest.Matchers.is; + +public class GoogleVertexAiUnifiedChatCompletionRequestTests extends ESTestCase { + + private static final String AUTH_HEADER_VALUE = "Bearer foo"; + private static final String TEST_PROJECT_ID = "test-project"; + private static final String TEST_MODEL_ID = "chat-bison"; + private static final String TEST_LOCATION = "us-central1"; + private static final String TEST_API_KEY = "apikey"; + + public void testA() { + var model = createCompletionModel(TEST_PROJECT_ID, TEST_LOCATION, TEST_MODEL_ID, TEST_API_KEY, new RateLimitSettings(100)); + var input = buildUnifiedChatCompletionInput(List.of("Hello")); + + var request = createRequest(input, model); + var httpRequest = request.createHttpRequest(); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var expectedUrl = Strings.format("https://%s-aiplatform.googleapis.com", TEST_LOCATION); + assertThat(httpPost.getURI().toString(), startsWith(expectedUrl)); + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + 
assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE)); + + } + + private static GoogleVertexAiUnifiedChatCompletionRequest createRequest( + UnifiedChatInput input, + GoogleVertexAiChatCompletionModel model + ) { + return new GoogleVertexAiUnifiedChatCompletionRequestWithoutAuth(input, model); + } + + private static UnifiedChatInput buildUnifiedChatCompletionInput(List messages) { + var requestMessages = messages.stream() + .map( + (userStringMessage) -> new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString(userStringMessage), + "user", + null, + null + ) + ) + .toList(); + + var request = new UnifiedCompletionRequest(requestMessages, "gemini-2.0", null, null, null, null, null, null); + return new UnifiedChatInput(request, true); + } + + private static class GoogleVertexAiUnifiedChatCompletionRequestWithoutAuth extends GoogleVertexAiUnifiedChatCompletionRequest { + GoogleVertexAiUnifiedChatCompletionRequestWithoutAuth(UnifiedChatInput unifiedChatInput, GoogleVertexAiChatCompletionModel model) { + super(unifiedChatInput, model); + } + + @Override + public void decorateWithAuth(HttpPost httpPost) { + httpPost.setHeader(HttpHeaders.AUTHORIZATION, AUTH_HEADER_VALUE); + } + } +} From d6ae90f33eec4a08f4c2eace715dc1074b3e3b05 Mon Sep 17 00:00:00 2001 From: lhoet Date: Wed, 14 May 2025 10:00:44 -0300 Subject: [PATCH 27/38] Fixed some tests --- ...gleVertexAiChatCompletionRequestTests.java | 110 ------------------ ...exAiUnifiedChatCompletionRequestTests.java | 94 +++++++++------ ...exAiChatCompletionResponseEntityTests.java | 26 +---- 3 files changed, 61 insertions(+), 169 deletions(-) delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiChatCompletionRequestTests.java diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiChatCompletionRequestTests.java deleted file mode 100644 index 338b827a43920..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiChatCompletionRequestTests.java +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.googlevertexai.request; - -import org.apache.http.HttpHeaders; -import org.apache.http.client.methods.HttpPost; -import org.elasticsearch.core.Nullable; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; -import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel; -import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModelTests; -import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; - -import java.io.IOException; -import java.net.URI; -import java.util.List; -import java.util.Map; -import java.util.Objects; - -import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; -import static org.hamcrest.Matchers.aMapWithSize; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.is; - -public class GoogleVertexAiChatCompletionRequestTests extends ESTestCase { - - private static final String AUTH_HEADER_VALUE = "Bearer foo"; - - // 
TODO: add more test here to check the generation configuration, different role models, etc - - public void testCreateRequest_Default() throws IOException { - var modelId = "gemini-pro"; - var projectId = "test-project"; - var location = "us-central1"; - - var messages = List.of("Hello Gemini!"); - - var request = createRequest(projectId, location, modelId, messages, null, null); - var httpRequest = request.createHttpRequest(); - var httpPost = (HttpPost) httpRequest.httpRequestBase(); - - var uri = URI.create( - String.format( - "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:streamGenerateContent", - location, - projectId, - modelId - ) - ); - - assertThat(httpPost.getURI(), equalTo(uri)); - assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); - assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE)); - - var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, aMapWithSize(1)); - assertThat( - requestMap, - equalTo(Map.of("contents", List.of(Map.of("role", "user", "parts", List.of(Map.of("text", messages.getFirst())))))) - ); - - } - - public static GoogleVertexAiUnifiedChatCompletionRequest createRequest( - UnifiedChatInput input, - GoogleVertexAiChatCompletionModel model - ) { - return new GoogleVertexAiUnifiedChatCompletionWithoutAuthRequest(input, model); - } - - public static GoogleVertexAiUnifiedChatCompletionRequest createRequest( - String projectId, - String location, - String modelId, - List messages, - @Nullable String apiKey, - @Nullable RateLimitSettings rateLimitSettings - ) { - var model = GoogleVertexAiChatCompletionModelTests.createCompletionModel( - projectId, - location, - modelId, - Objects.requireNonNullElse(apiKey, "default-api-key"), - Objects.requireNonNullElse(rateLimitSettings, new RateLimitSettings(100)) - ); - var unifiedChatInput = new 
UnifiedChatInput(messages, "user", true); - - return new GoogleVertexAiUnifiedChatCompletionWithoutAuthRequest(unifiedChatInput, model); - } - - /** - * We use this class to fake the auth implementation to avoid static mocking of {@link GoogleVertexAiRequest} - */ - private static class GoogleVertexAiUnifiedChatCompletionWithoutAuthRequest extends GoogleVertexAiUnifiedChatCompletionRequest { - GoogleVertexAiUnifiedChatCompletionWithoutAuthRequest(UnifiedChatInput unifiedChatInput, GoogleVertexAiChatCompletionModel model) { - super(unifiedChatInput, model); - } - - @Override - public void decorateWithAuth(HttpPost httpPost) { - httpPost.setHeader(HttpHeaders.AUTHORIZATION, AUTH_HEADER_VALUE); - } - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestTests.java index bd080ab7fe9aa..c0eee7fb885a0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestTests.java @@ -9,69 +9,95 @@ import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpPost; -import org.elasticsearch.core.Strings; -import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel; -import org.elasticsearch.test.ESTestCase; +import 
org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModelTests; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import java.io.IOException; +import java.net.URI; import java.util.List; +import java.util.Map; +import java.util.Objects; -import static org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModelTests.createCompletionModel; -import static org.hamcrest.Matchers.startsWith; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; public class GoogleVertexAiUnifiedChatCompletionRequestTests extends ESTestCase { private static final String AUTH_HEADER_VALUE = "Bearer foo"; - private static final String TEST_PROJECT_ID = "test-project"; - private static final String TEST_MODEL_ID = "chat-bison"; - private static final String TEST_LOCATION = "us-central1"; - private static final String TEST_API_KEY = "apikey"; - public void testA() { - var model = createCompletionModel(TEST_PROJECT_ID, TEST_LOCATION, TEST_MODEL_ID, TEST_API_KEY, new RateLimitSettings(100)); - var input = buildUnifiedChatCompletionInput(List.of("Hello")); - var request = createRequest(input, model); - var httpRequest = request.createHttpRequest(); + public void testCreateRequest_Default() throws IOException { + var modelId = "gemini-pro"; + var projectId = "test-project"; + var location = "us-central1"; + var messages = List.of("Hello Gemini!"); + + var request = createRequest(projectId, location, modelId, messages, null, null); + var httpRequest = request.createHttpRequest(); var httpPost = (HttpPost) httpRequest.httpRequestBase(); - var expectedUrl = Strings.format("https://%s-aiplatform.googleapis.com", TEST_LOCATION); - assertThat(httpPost.getURI().toString(), startsWith(expectedUrl)); + var uri = URI.create( + 
String.format( + "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:streamGenerateContent", + location, + projectId, + modelId + ) + ); + + assertThat(httpPost.getURI(), equalTo(uri)); assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE)); + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(1)); + assertThat( + requestMap, + equalTo(Map.of("contents", List.of(Map.of("role", "user", "parts", List.of(Map.of("text", messages.getFirst())))))) + ); + } - private static GoogleVertexAiUnifiedChatCompletionRequest createRequest( + public static GoogleVertexAiUnifiedChatCompletionRequest createRequest( UnifiedChatInput input, GoogleVertexAiChatCompletionModel model ) { - return new GoogleVertexAiUnifiedChatCompletionRequestWithoutAuth(input, model); + return new GoogleVertexAiUnifiedChatCompletionWithoutAuthRequest(input, model); } - private static UnifiedChatInput buildUnifiedChatCompletionInput(List messages) { - var requestMessages = messages.stream() - .map( - (userStringMessage) -> new UnifiedCompletionRequest.Message( - new UnifiedCompletionRequest.ContentString(userStringMessage), - "user", - null, - null - ) - ) - .toList(); - - var request = new UnifiedCompletionRequest(requestMessages, "gemini-2.0", null, null, null, null, null, null); - return new UnifiedChatInput(request, true); + public static GoogleVertexAiUnifiedChatCompletionRequest createRequest( + String projectId, + String location, + String modelId, + List messages, + @Nullable String apiKey, + @Nullable RateLimitSettings rateLimitSettings + ) { + var model = GoogleVertexAiChatCompletionModelTests.createCompletionModel( + projectId, + location, + modelId, + Objects.requireNonNullElse(apiKey, "default-api-key"), + Objects.requireNonNullElse(rateLimitSettings, 
new RateLimitSettings(100)) + ); + var unifiedChatInput = new UnifiedChatInput(messages, "user", true); + + return new GoogleVertexAiUnifiedChatCompletionWithoutAuthRequest(unifiedChatInput, model); } - private static class GoogleVertexAiUnifiedChatCompletionRequestWithoutAuth extends GoogleVertexAiUnifiedChatCompletionRequest { - GoogleVertexAiUnifiedChatCompletionRequestWithoutAuth(UnifiedChatInput unifiedChatInput, GoogleVertexAiChatCompletionModel model) { + /** + * We use this class to fake the auth implementation to avoid static mocking of {@link GoogleVertexAiRequest} + */ + private static class GoogleVertexAiUnifiedChatCompletionWithoutAuthRequest extends GoogleVertexAiUnifiedChatCompletionRequest { + GoogleVertexAiUnifiedChatCompletionWithoutAuthRequest(UnifiedChatInput unifiedChatInput, GoogleVertexAiChatCompletionModel model) { super(unifiedChatInput, model); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiChatCompletionResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiChatCompletionResponseEntityTests.java index 36a365170eb3a..c8573b7708209 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiChatCompletionResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiChatCompletionResponseEntityTests.java @@ -61,30 +61,6 @@ public void testFromResponse_CreatesResultsForMultipleChunks() throws IOExceptio assertThat(chatCompletionResults.getResults().getFirst().content(), is("Hello World")); } - public void testFromResponse_HandlesPartWithMissingText() throws IOException { - // Since text is optionalConstructorArg, missing text results in null, which is skipped by extractText - String responseJson = """ - [ - { - 
"candidates": [ - { - "content": { - "parts": [ { "not_text": "hello" } ] - } - } - ] - } - ] - """; - - ChatCompletionResults chatCompletionResults = GoogleVertexAiChatCompletionResponseEntity.fromResponse( - mock(Request.class), - new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) - ); - - assertThat(chatCompletionResults.getResults().size(), is(1)); - assertThat(chatCompletionResults.getResults().getFirst().content(), is("")); - } public void testFromResponse_FailsWhenChunkMissingCandidates() { // Parser ignores unknown fields, but expects 'candidates' for the constructor @@ -97,7 +73,7 @@ public void testFromResponse_FailsWhenChunkMissingCandidates() { """; var thrownException = expectThrows( - IllegalArgumentException.class, // ConstructingObjectParser throws this when required args are missing + IllegalArgumentException.class, () -> GoogleVertexAiChatCompletionResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) From cbb387f0b96c9998f475f3d9eab73874e321c8f0 Mon Sep 17 00:00:00 2001 From: lhoet Date: Wed, 14 May 2025 11:50:58 -0300 Subject: [PATCH 28/38] Fixed GoogleAIService get configuration tests --- .../googlevertexai/GoogleVertexAiSecretSettings.java | 3 ++- .../googlevertexai/GoogleVertexAiService.java | 1 - .../googlevertexai/GoogleVertexAiServiceTests.java | 12 ++++++------ 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettings.java index 9a39e200368cf..b68f5b25ff9e2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettings.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettings.java @@ -124,7 +124,8 @@ public static Map get() { var configurationMap = new HashMap(); configurationMap.put( SERVICE_ACCOUNT_JSON, - new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK)).setDescription( + new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.CHAT_COMPLETION)) + .setDescription( "API Key for the provider you're connecting to." ) .setLabel("Credentials JSON") diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 8f67a615d2c0a..849f498e1a0c6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -196,7 +196,6 @@ public EnumSet supportedTaskTypes() { return supportedTaskTypes; } - // TODO: Since we added a task type we need to change this? 
@Override public TransportVersion getMinimalSupportedVersion() { return TransportVersions.V_8_15_0; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java index 30cc730f8f935..41295ccac1cef 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java @@ -933,7 +933,7 @@ public void testGetConfiguration() throws Exception { { "service": "googlevertexai", "name": "Google Vertex AI", - "task_types": ["text_embedding", "rerank"], + "task_types": ["text_embedding", "rerank", "chat_completion"], "configurations": { "service_account_json": { "description": "API Key for the provider you're connecting to.", @@ -942,7 +942,7 @@ public void testGetConfiguration() throws Exception { "sensitive": true, "updatable": true, "type": "str", - "supported_task_types": ["text_embedding", "rerank"] + "supported_task_types": ["text_embedding", "rerank", "chat_completion"] }, "project_id": { "description": "The GCP Project ID which has Vertex AI API(s) enabled. For more information on the URL, refer to the {geminiVertexAIDocs}.", @@ -951,7 +951,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["text_embedding", "rerank"] + "supported_task_types": ["text_embedding", "rerank", "chat_completion"] }, "location": { "description": "Please provide the GCP region where the Vertex AI API(s) is enabled. 
For more information, refer to the {geminiVertexAIDocs}.", @@ -960,7 +960,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["text_embedding", "rerank"] + "supported_task_types": ["text_embedding", "rerank", "chat_completion"] }, "rate_limit.requests_per_minute": { "description": "Minimize the number of rate limit errors.", @@ -969,7 +969,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "int", - "supported_task_types": ["text_embedding", "rerank"] + "supported_task_types": ["text_embedding", "rerank", "chat_completion"] }, "model_id": { "description": "ID of the LLM you're using.", @@ -978,7 +978,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["text_embedding", "rerank"] + "supported_task_types": ["text_embedding", "rerank", "chat_completion"] } } } From 7e1c970a98e0cd8dde72e60755efa3d07013e117 Mon Sep 17 00:00:00 2001 From: lhoet Date: Wed, 14 May 2025 12:22:40 -0300 Subject: [PATCH 29/38] GoogleVertexAiCompletion action tests --- ...oogleVertexAiCompletionRequestManager.java | 2 +- ...texAiUnifiedChatCompletionActionTests.java | 131 ++++++++++++++++++ 2 files changed, 132 insertions(+), 1 deletion(-) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiUnifiedChatCompletionActionTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiCompletionRequestManager.java index 172f5f0f43e5d..678ec08c8b937 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiCompletionRequestManager.java @@ -39,7 +39,7 @@ private static ResponseHandler createGoogleVertexAiResponseHandler() { private final GoogleVertexAiChatCompletionModel model; - private GoogleVertexAiCompletionRequestManager(GoogleVertexAiChatCompletionModel model, ThreadPool threadPool) { + public GoogleVertexAiCompletionRequestManager(GoogleVertexAiChatCompletionModel model, ThreadPool threadPool) { super(threadPool, model, RateLimitGrouping.of(model)); this.model = model; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiUnifiedChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiUnifiedChatCompletionActionTests.java new file mode 100644 index 0000000000000..38fd50c0b090d --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiUnifiedChatCompletionActionTests.java @@ -0,0 +1,131 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.googlevertexai.action; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiCompletionRequestManager; +import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModelTests; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static 
org.mockito.Mockito.mock; + +public class GoogleVertexAiUnifiedChatCompletionActionTests extends ESTestCase { + + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + private static UnifiedChatInput createUnifiedChatInput(List messages, String role) { + boolean stream = true; + return new UnifiedChatInput(messages, role, stream); + } + + // Successful case would typically be tested via end-to-end notebook tests in AppEx repo + + public void testExecute_ThrowsElasticsearchException() { + var sender = mock(Sender.class); + doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any()); + + var action = createAction("us-central1", "test-project-id", "chat-bison", sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(createUnifiedChatInput(List.of("test query"), "user"), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(thrownException.getMessage(), is("failed")); + } + + public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() { + var sender = mock(Sender.class); + + doAnswer(invocation -> { + ActionListener listenerArg = invocation.getArgument(3); + listenerArg.onFailure(new IllegalStateException("failed")); + return Void.TYPE; + }).when(sender).send(any(), any(), any(), any()); + + var action = 
createAction("us-central1", "test-project-id", "chat-bison", sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(createUnifiedChatInput(List.of("test query"), "user"), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(thrownException.getMessage(), is("Failed to send Google Vertex AI chat completion request. Cause: failed")); + } + + public void testExecute_ThrowsException() { + var sender = mock(Sender.class); + doThrow(new IllegalArgumentException("failed")).when(sender).send(any(), any(), any(), any()); + + var action = createAction("us-central1", "test-project-id", "chat-bison", sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(createUnifiedChatInput(List.of("test query"), "user"), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(thrownException.getMessage(), is("Failed to send Google Vertex AI chat completion request. 
Cause: failed")); + } + + private ExecutableAction createAction(String location, String projectId, String actualModelId, Sender sender) { + var model = GoogleVertexAiChatCompletionModelTests.createCompletionModel( + projectId, + location, + actualModelId, + "api-key", + new RateLimitSettings(100) + ); + + var requestManager = new GoogleVertexAiCompletionRequestManager(model, threadPool); + var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Google Vertex AI chat completion"); + return new SenderExecutableAction(sender, requestManager, failedToSendRequestErrorMessage); + } + +} From 5ab716f40e5a6d0beff407b182df0916ff094749 Mon Sep 17 00:00:00 2001 From: lhoet Date: Thu, 15 May 2025 10:11:04 -0300 Subject: [PATCH 30/38] Formatting --- .../GoogleVertexAiSecretSettings.java | 4 +- .../googlevertexai/GoogleVertexAiService.java | 1 - ...iUnifiedChatCompletionResponseHandler.java | 4 +- ...ogleVertexAiUnifiedStreamingProcessor.java | 62 ++++++++----------- .../action/GoogleVertexAiActionCreator.java | 1 - .../GoogleVertexAiChatCompletionModel.java | 10 +-- ...VertexAiChatCompletionServiceSettings.java | 10 +-- ...xAiUnifiedChatCompletionRequestEntity.java | 14 ++--- ...eVertexAiChatCompletionResponseEntity.java | 2 +- .../JsonArrayPartsEventParserTests.java | 1 - .../GoogleVertexAiServiceTests.java | 3 +- ...iedChatCompletionResponseHandlerTests.java | 8 --- ...ifiedChatCompletionRequestEntityTests.java | 2 - ...exAiUnifiedChatCompletionRequestTests.java | 1 - ...exAiChatCompletionResponseEntityTests.java | 1 - 15 files changed, 45 insertions(+), 79 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettings.java index b68f5b25ff9e2..1abf1db642932 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettings.java @@ -125,9 +125,7 @@ public static Map get() { configurationMap.put( SERVICE_ACCOUNT_JSON, new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.CHAT_COMPLETION)) - .setDescription( - "API Key for the provider you're connecting to." - ) + .setDescription("API Key for the provider you're connecting to.") .setLabel("Credentials JSON") .setRequired(true) .setSensitive(true) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 849f498e1a0c6..ea29bd9fbf5ca 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -75,7 +75,6 @@ public class GoogleVertexAiService extends SenderService { TaskType.CHAT_COMPLETION ); - public static final EnumSet VALID_INPUT_TYPE_VALUES = EnumSet.of( InputType.INGEST, InputType.SEARCH, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java index 3e549b30287e7..c8fd4e3238d4c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java @@ -7,10 +7,10 @@ package org.elasticsearch.xpack.inference.services.googlevertexai; -import org.elasticsearch.logging.LogManager; -import org.elasticsearch.logging.Logger; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessor.java index 1c824e1126fa7..d0743fe24b5ec 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessor.java @@ -185,33 +185,28 @@ private static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(delta, candidate.finishReason(), candidate.index()); } + @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>( - "google_vertexai_chat_completion_chunk", - true, - args -> { - List candidates = (List) args[0]; - UsageMetadata usage = (UsageMetadata) args[1]; - String modelversion = (String) args[2]; - String responseId = (String) args[3]; - - - boolean candidatesIsEmpty = candidates == null || candidates.isEmpty(); - List choices = candidatesIsEmpty - ? 
Collections.emptyList() - : candidates.stream().map(GoogleVertexAiChatCompletionChunkParser::candidateToChoice).toList(); - - - return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk( - responseId, - choices, - modelversion, - CHAT_COMPLETION_CHUNK, - usageMetadataToChunk(usage) - ); - } - ); + new ConstructingObjectParser<>("google_vertexai_chat_completion_chunk", true, args -> { + List candidates = (List) args[0]; + UsageMetadata usage = (UsageMetadata) args[1]; + String modelversion = (String) args[2]; + String responseId = (String) args[3]; + + boolean candidatesIsEmpty = candidates == null || candidates.isEmpty(); + List choices = candidatesIsEmpty + ? Collections.emptyList() + : candidates.stream().map(GoogleVertexAiChatCompletionChunkParser::candidateToChoice).toList(); + + return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk( + responseId, + choices, + modelversion, + CHAT_COMPLETION_CHUNK, + usageMetadataToChunk(usage) + ); + }); static { PARSER.declareObjectArray( @@ -238,16 +233,12 @@ public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk parse(XC private record Candidate(Content content, String finishReason, int index) {} private static class CandidateParser { - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - "candidate", - true, - args -> { - var content = (Content) args[0]; - var finishReason = (String) args[1]; - var index = args[2] == null ? 0 : (int) args[2]; - return new Candidate(content, finishReason, index); - } - ); + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("candidate", true, args -> { + var content = (Content) args[0]; + var finishReason = (String) args[1]; + var index = args[2] == null ? 
0 : (int) args[2]; + return new Candidate(content, finishReason, index); + }); static { PARSER.declareObject( @@ -345,6 +336,7 @@ public static FunctionCall parse(XContentParser parser) throws IOException { return PARSER.parse(parser, null); } } + private record UsageMetadata(int promptTokenCount, int candidatesTokenCount, int totalTokenCount) {} private static class UsageMetadataParser { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionCreator.java index 5f1a703496ced..80128efee33f3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionCreator.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.services.googlevertexai.action; -import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.sender.Sender; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java index 3574f9a051acb..b351028c1413d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java @@ -23,9 +23,9 @@ import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUtils; import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleDiscoveryEngineRateLimitServiceSettings; +import java.net.URI; import java.net.URISyntaxException; import java.util.Map; -import java.net.URI; import java.util.Objects; import static org.elasticsearch.core.Strings.format; @@ -90,14 +90,6 @@ public static GoogleVertexAiChatCompletionModel of(GoogleVertexAiChatCompletionM ); } - public GoogleVertexAiChatCompletionModel( - ModelConfigurations configurations, - ModelSecrets secrets, - GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings - ) { - super(configurations, secrets, rateLimitServiceSettings); - } - @Override public ExecutableAction accept(GoogleVertexAiActionVisitor visitor, Map taskSettings) { return visitor.create(this, taskSettings); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionServiceSettings.java index a59954d6a51a7..8902a0a332a63 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionServiceSettings.java @@ -7,17 +7,17 @@ package org.elasticsearch.xpack.inference.services.googlevertexai.completion; -import org.elasticsearch.common.ValidationException; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; -import 
org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiService; import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleDiscoveryEngineRateLimitServiceSettings; import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; @@ -27,9 +27,9 @@ import java.util.Map; import java.util.Objects; -import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.PROJECT_ID; -import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.LOCATION; import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.LOCATION; +import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.PROJECT_ID; public class GoogleVertexAiChatCompletionServiceSettings extends FilteredXContentObject implements diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java index 2d64c25c2aef1..36f80667c521c 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java @@ -73,13 +73,12 @@ private String messageRoleToGoogleVertexAiSupportedRole(String messageRole) thro return messageRoleLowered; } - var errorMessage = - format( - "Role [%s] not supported by Google VertexAI ChatCompletion. Supported roles: [%s, %s]", - messageRole, - USER_ROLE, - MODEL_ROLE - ); + var errorMessage = format( + "Role [%s] not supported by Google VertexAI ChatCompletion. Supported roles: [%s, %s]", + messageRole, + USER_ROLE, + MODEL_ROLE + ); throw new ElasticsearchStatusException(errorMessage, RestStatus.BAD_REQUEST); } @@ -194,7 +193,6 @@ private void buildTools(XContentBuilder builder) throws IOException { throw new ElasticsearchStatusException(errorMessage, RestStatus.BAD_REQUEST); } - builder.startObject(); builder.field(FUNCTION_NAME, function.name()); if (Strings.hasText(function.description())) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiChatCompletionResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiChatCompletionResponseEntity.java index 3022dde69c939..f6f1c7827e35c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiChatCompletionResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiChatCompletionResponseEntity.java @@ -21,9 +21,9 @@ import java.util.List; import java.util.Optional; +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; 
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; -import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; public class GoogleVertexAiChatCompletionResponseEntity { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/JsonArrayPartsEventParserTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/JsonArrayPartsEventParserTests.java index 4e0b505473a6b..ecc21b5a06ed6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/JsonArrayPartsEventParserTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/JsonArrayPartsEventParserTests.java @@ -103,7 +103,6 @@ public void testParse_multipleObjectsSomeSplit_returnsPartsIncrementally() { assertTrue("Expected no more parts from empty call", parser.parse(new byte[0]).isEmpty()); } - public void testParse_nestedObjects_extractsTopLevelObject() { JsonArrayPartsEventParser parser = new JsonArrayPartsEventParser(); String json = "{\"outer_key\":{\"inner_key\":\"value\"},\"another_key\":\"val\"}"; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java index 41295ccac1cef..7ff8bf7c144f4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java @@ -6,8 +6,10 @@ */ package 
org.elasticsearch.xpack.inference.services.googlevertexai; + import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.Settings; @@ -48,7 +50,6 @@ import org.hamcrest.Matchers; import org.junit.After; import org.junit.Before; -import org.elasticsearch.action.support.PlainActionFuture; import java.io.IOException; import java.util.HashMap; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandlerTests.java index 96195986574f1..ea60c4653589e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandlerTests.java @@ -12,10 +12,8 @@ import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Strings; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentFactory; -import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.retry.RetryException; @@ -23,17 +21,11 @@ import java.io.IOException; import 
java.nio.charset.StandardCharsets; -import java.util.concurrent.Flow; import static org.elasticsearch.ExceptionsHelper.unwrapCause; import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; -import static org.hamcrest.Matchers.containsString; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.isA; -import static org.hamcrest.Matchers.notNullValue; -import static org.hamcrest.Matchers.nullValue; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java index e8305ca4a44bd..af4bc4f9abad0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java @@ -26,8 +26,6 @@ import static org.elasticsearch.xpack.inference.Utils.assertJsonEquals; import static org.hamcrest.Matchers.containsString; -import static org.hamcrest.Matchers.endsWith; -import static org.hamcrest.Matchers.is; public class GoogleVertexAiUnifiedChatCompletionRequestEntityTests extends ESTestCase { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestTests.java index 
c0eee7fb885a0..f30232ff14a58 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestTests.java @@ -32,7 +32,6 @@ public class GoogleVertexAiUnifiedChatCompletionRequestTests extends ESTestCase private static final String AUTH_HEADER_VALUE = "Bearer foo"; - public void testCreateRequest_Default() throws IOException { var modelId = "gemini-pro"; var projectId = "test-project"; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiChatCompletionResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiChatCompletionResponseEntityTests.java index c8573b7708209..239b7b48633ce 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiChatCompletionResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiChatCompletionResponseEntityTests.java @@ -61,7 +61,6 @@ public void testFromResponse_CreatesResultsForMultipleChunks() throws IOExceptio assertThat(chatCompletionResults.getResults().getFirst().content(), is("Hello World")); } - public void testFromResponse_FailsWhenChunkMissingCandidates() { // Parser ignores unknown fields, but expects 'candidates' for the constructor String responseJson = """ From 28aa4640e1512f3cc9cd05a8710e30bec921ac7b Mon Sep 17 00:00:00 2001 From: lhoet Date: Thu, 15 May 2025 10:16:08 -0300 Subject: [PATCH 31/38] Code style fix --- .../GoogleVertexAiChatCompletionModel.java | 1 - ...eVertexAiUnifiedChatCompletionRequest.java | 2 +- 
...xAiUnifiedChatCompletionRequestEntity.java | 15 +-- ...ifiedChatCompletionRequestEntityTests.java | 103 +++++------------- 4 files changed, 33 insertions(+), 88 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java index b351028c1413d..fa211f00a7750 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java @@ -17,7 +17,6 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiModel; -import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiRateLimitServiceSettings; import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings; import org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionVisitor; import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUtils; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequest.java index 6541569887c1a..7b20e71099e66 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequest.java @@ -35,7 +35,7 @@ public GoogleVertexAiUnifiedChatCompletionRequest(UnifiedChatInput unifiedChatIn public HttpRequest createHttpRequest() { HttpPost httpPost = new HttpPost(model.uri()); - var requestEntity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + var requestEntity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); ByteArrayEntity byteEntity = new ByteArrayEntity(Strings.toString(requestEntity).getBytes(StandardCharsets.UTF_8)); httpPost.setEntity(byteEntity); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java index 36f80667c521c..c375f503aaa24 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java @@ -19,7 +19,6 @@ import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; -import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel; import java.io.IOException; import java.util.Map; @@ -55,18 +54,16 @@ public class GoogleVertexAiUnifiedChatCompletionRequestEntity 
implements ToXCont private static final String FUNCTION_CALL_ARGS = "args"; private final UnifiedChatInput unifiedChatInput; - private final GoogleVertexAiChatCompletionModel model; // TODO: This is not being used? private static final String USER_ROLE = "user"; private static final String MODEL_ROLE = "model"; private static final String STOP_SEQUENCES = "stopSequences"; - public GoogleVertexAiUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, GoogleVertexAiChatCompletionModel model) { + public GoogleVertexAiUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput) { this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput); - this.model = Objects.requireNonNull(model); } - private String messageRoleToGoogleVertexAiSupportedRole(String messageRole) throws IOException { + private String messageRoleToGoogleVertexAiSupportedRole(String messageRole) { var messageRoleLowered = messageRole.toLowerCase(); if (messageRoleLowered.equals(USER_ROLE) || messageRoleLowered.equals(MODEL_ROLE)) { @@ -98,8 +95,7 @@ private void validateAndAddContentObjectsToBuilder(XContentBuilder builder, Unif return; // VertexAI API does not support empty text parts } - // We are only supporting Text messages but VertexAI supports more types: - // https://cloud.google.com/vertex-ai/docs/reference/rest/v1/Content?_gl=1*q4uxnh*_up*MQ..&gclid=CjwKCAjwwqfABhBcEiwAZJjC3uBQNP9KUMZX8AGXvFXP2rIEQSfCX9RLP5gjzx5r-4xz1daBSxM7GBoCY64QAvD_BwE&gclsrc=aw.ds#Part + // We are only supporting Text messages for now builder.startObject(); builder.field(TEXT, contentObject.text()); builder.endObject(); @@ -215,9 +211,7 @@ private void buildToolConfig(XContentBuilder builder) throws IOException { UnifiedCompletionRequest.ToolChoiceObject toolChoice; switch (request.toolChoice()) { - case UnifiedCompletionRequest.ToolChoiceObject toolChoiceObject -> { - toolChoice = toolChoiceObject; - } + case UnifiedCompletionRequest.ToolChoiceObject toolChoiceObject -> toolChoice = 
toolChoiceObject; case UnifiedCompletionRequest.ToolChoiceString toolChoiceString -> { if (toolChoiceString.value().equals(TOOL_MODE_AUTO)) { return; @@ -303,7 +297,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws buildToolConfig(builder); builder.endObject(); - var s = Strings.toString(builder); return builder; } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java index af4bc4f9abad0..cc45b2f286c83 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java @@ -50,11 +50,7 @@ public void testBasicSerialization_SingleMessage() throws IOException { var unifiedRequest = UnifiedCompletionRequest.of(messageList); UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); // stream doesn't affect VertexAI request body - var model = createModel(); - GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( - unifiedChatInput, - model - ); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); XContentBuilder builder = JsonXContent.contentBuilder(); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -98,10 +94,7 @@ public void testSerialization_MultipleMessages() throws IOException { UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, false); GoogleVertexAiChatCompletionModel model = 
createModel(); - GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( - unifiedChatInput, - model - ); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); XContentBuilder builder = JsonXContent.contentBuilder(); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -159,10 +152,7 @@ public void testSerialization_Tools() throws IOException { UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, false); GoogleVertexAiChatCompletionModel model = createModel(); - GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( - unifiedChatInput, - model - ); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); XContentBuilder builder = JsonXContent.contentBuilder(); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -229,10 +219,7 @@ public void testSerialization_ToolsChoice() throws IOException { UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, false); GoogleVertexAiChatCompletionModel model = createModel(); - GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( - unifiedChatInput, - model - ); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); XContentBuilder builder = JsonXContent.contentBuilder(); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -289,10 +276,7 @@ public void testSerialization_WithAllGenerationConfig() throws IOException { UnifiedChatInput unifiedChatInput = new UnifiedChatInput(completionRequestWithGenerationConfig, true); GoogleVertexAiChatCompletionModel model = createModel(); - GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( - 
unifiedChatInput, - model - ); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); XContentBuilder builder = JsonXContent.contentBuilder(); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -338,10 +322,7 @@ public void testSerialization_WithSomeGenerationConfig() throws IOException { UnifiedChatInput unifiedChatInput = new UnifiedChatInput(completionRequestWithGenerationConfig, true); GoogleVertexAiChatCompletionModel model = createModel(); - GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( - unifiedChatInput, - model - ); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); XContentBuilder builder = JsonXContent.contentBuilder(); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -377,10 +358,7 @@ public void testSerialization_NoGenerationConfig() throws IOException { UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); GoogleVertexAiChatCompletionModel model = createModel(); - GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( - unifiedChatInput, - model - ); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); XContentBuilder builder = JsonXContent.contentBuilder(); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -417,10 +395,7 @@ public void testSerialization_WithContentObjects() throws IOException { UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); GoogleVertexAiChatCompletionModel model = createModel(); - GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( - unifiedChatInput, - model - ); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new 
GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); XContentBuilder builder = JsonXContent.contentBuilder(); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -454,13 +429,12 @@ public void testError_UnsupportedRole() throws IOException { UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, false); GoogleVertexAiChatCompletionModel model = createModel(); - GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( - unifiedChatInput, - model - ); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); - XContentBuilder builder = JsonXContent.contentBuilder(); - var statusException = assertThrows(ElasticsearchStatusException.class, () -> entity.toXContent(builder, ToXContent.EMPTY_PARAMS)); + ElasticsearchStatusException statusException; + try (XContentBuilder builder = JsonXContent.contentBuilder()) { + statusException = assertThrows(ElasticsearchStatusException.class, () -> entity.toXContent(builder, ToXContent.EMPTY_PARAMS)); + } assertEquals(RestStatus.BAD_REQUEST, statusException.status()); assertThat(statusException.toString(), containsString("Role [system] not supported by Google VertexAI ChatCompletion")); @@ -478,13 +452,12 @@ public void testError_UnsupportedContentObjectType() throws IOException { UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, false); GoogleVertexAiChatCompletionModel model = createModel(); - GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( - unifiedChatInput, - model - ); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); - XContentBuilder builder = JsonXContent.contentBuilder(); - var statusException = assertThrows(ElasticsearchStatusException.class, () -> entity.toXContent(builder, 
ToXContent.EMPTY_PARAMS)); + ElasticsearchStatusException statusException; + try (XContentBuilder builder = JsonXContent.contentBuilder()) { + statusException = assertThrows(ElasticsearchStatusException.class, () -> entity.toXContent(builder, ToXContent.EMPTY_PARAMS)); + } assertEquals(RestStatus.BAD_REQUEST, statusException.status()); assertThat(statusException.toString(), containsString("Type [image_url] not supported by Google VertexAI ChatCompletion")); @@ -574,11 +547,7 @@ public void testParseAllFields() throws IOException { ); UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, true); - var model = createModel(); - GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( - unifiedChatInput, - model - ); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); XContentBuilder builder = JsonXContent.contentBuilder(); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -632,11 +601,7 @@ public void testParseFunctionCallNoContent() throws IOException { ); UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, true); - var model = createModel(); - GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( - unifiedChatInput, - model - ); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); XContentBuilder builder = JsonXContent.contentBuilder(); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -717,10 +682,8 @@ public void testParseFunctionCallWithEmptyStringContent() throws IOException { for (var request : requests) { UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, true); - var model = createModel(); GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( - unifiedChatInput, - model + 
unifiedChatInput ); XContentBuilder builder = JsonXContent.contentBuilder(); @@ -764,11 +727,7 @@ public void testParseToolChoiceString() throws IOException { ); UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, true); - var model = createModel(); - GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( - unifiedChatInput, - model - ); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); XContentBuilder builder = JsonXContent.contentBuilder(); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -797,14 +756,12 @@ public void testParseToolChoiceInvalid_throwElasticSearchStatusException() throw ); UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, true); - var model = createModel(); - GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( - unifiedChatInput, - model - ); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); - XContentBuilder builder = JsonXContent.contentBuilder(); - var statusException = expectThrows(ElasticsearchStatusException.class, () -> entity.toXContent(builder, ToXContent.EMPTY_PARAMS)); + ElasticsearchStatusException statusException; + try (XContentBuilder builder = JsonXContent.contentBuilder()) { + statusException = expectThrows(ElasticsearchStatusException.class, () -> entity.toXContent(builder, ToXContent.EMPTY_PARAMS)); + } assertThat( statusException.toString(), containsString("Tool choice value [unsupported] not supported by Google VertexAI ChatCompletion.") @@ -884,11 +841,7 @@ public void testParseMultipleTools() throws IOException { ); UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, true); - var model = createModel(); - GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new 
GoogleVertexAiUnifiedChatCompletionRequestEntity( - unifiedChatInput, - model - ); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); XContentBuilder builder = JsonXContent.contentBuilder(); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); From 22793919b8ee9c80a2dd42921b6a3ad5b7fa99dc Mon Sep 17 00:00:00 2001 From: lhoet Date: Thu, 15 May 2025 10:17:24 -0300 Subject: [PATCH 32/38] Removed unnused variables --- ...eVertexAiUnifiedChatCompletionRequestEntityTests.java | 9 --------- 1 file changed, 9 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java index cc45b2f286c83..dca7d0e74cbbb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java @@ -92,7 +92,6 @@ public void testSerialization_MultipleMessages() throws IOException { var unifiedRequest = UnifiedCompletionRequest.of(messages); UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, false); - GoogleVertexAiChatCompletionModel model = createModel(); GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); @@ -150,7 +149,6 @@ public void testSerialization_Tools() throws IOException { null ); UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, false); - GoogleVertexAiChatCompletionModel model = createModel(); 
GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); @@ -217,7 +215,6 @@ public void testSerialization_ToolsChoice() throws IOException { null ); UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, false); - GoogleVertexAiChatCompletionModel model = createModel(); GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); @@ -274,7 +271,6 @@ public void testSerialization_WithAllGenerationConfig() throws IOException { ); UnifiedChatInput unifiedChatInput = new UnifiedChatInput(completionRequestWithGenerationConfig, true); - GoogleVertexAiChatCompletionModel model = createModel(); GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); @@ -320,7 +316,6 @@ public void testSerialization_WithSomeGenerationConfig() throws IOException { ); UnifiedChatInput unifiedChatInput = new UnifiedChatInput(completionRequestWithGenerationConfig, true); - GoogleVertexAiChatCompletionModel model = createModel(); GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); @@ -356,7 +351,6 @@ public void testSerialization_NoGenerationConfig() throws IOException { var unifiedRequest = UnifiedCompletionRequest.of(List.of(message)); UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - GoogleVertexAiChatCompletionModel model = createModel(); GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); @@ -393,7 +387,6 @@ public void testSerialization_WithContentObjects() throws IOException { var unifiedRequest = UnifiedCompletionRequest.of(messageList); UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - GoogleVertexAiChatCompletionModel model = createModel(); 
GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); @@ -427,7 +420,6 @@ public void testError_UnsupportedRole() throws IOException { ); var unifiedRequest = UnifiedCompletionRequest.of(List.of(message)); UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, false); - GoogleVertexAiChatCompletionModel model = createModel(); GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); @@ -450,7 +442,6 @@ public void testError_UnsupportedContentObjectType() throws IOException { ); var unifiedRequest = UnifiedCompletionRequest.of(List.of(message)); UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, false); - GoogleVertexAiChatCompletionModel model = createModel(); GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); From 85af5c039b29b61a86a0ec5388fdad8d5d3a08d7 Mon Sep 17 00:00:00 2001 From: lhoet Date: Thu, 15 May 2025 12:07:16 -0300 Subject: [PATCH 33/38] Function call id fixed --- .../GoogleVertexAiUnifiedChatCompletionRequestEntity.java | 2 ++ .../GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java | 1 + 2 files changed, 3 insertions(+) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java index c375f503aaa24..b685b39bf6574 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java @@ -52,6 +52,7 @@ public class GoogleVertexAiUnifiedChatCompletionRequestEntity implements ToXCont private static final String FUNCTION_CALL = "functionCall"; private static final String FUNCTION_CALL_NAME = "name"; private static final String FUNCTION_CALL_ARGS = "args"; + private static final String FUNCTION_CALL_ID = "id"; private final UnifiedChatInput unifiedChatInput; @@ -152,6 +153,7 @@ private void buildContents(XContentBuilder builder) throws IOException { builder.startObject(); builder.startObject(FUNCTION_CALL); builder.field(FUNCTION_CALL_NAME, toolCall.function().name()); + builder.field(FUNCTION_CALL_ID, toolCall.id()); builder.field(FUNCTION_CALL_ARGS, jsonStringToMap(toolCall.function().arguments())); builder.endObject(); builder.endObject(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java index dca7d0e74cbbb..0276f59943c4d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java @@ -609,6 +609,7 @@ public void testParseFunctionCallWithEmptyStringContent() throws IOException { "role": "model", "parts": [ { "functionCall" : { + "id" : "call_62136354", "name": "get_delivery_date", "args": { "order_id" : "order_12345" From 16c01b0f3d9d85d262cb67ff47191f91b51a3689 Mon Sep 17 00:00:00 2001 From: lhoet Date: Thu, 15 May 2025 
14:25:24 -0300 Subject: [PATCH 34/38] Bugfix --- .../GoogleVertexAiUnifiedStreamingProcessor.java | 16 +++++++++++----- ...leVertexAiUnifiedStreamingProcessorTests.java | 4 +--- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessor.java index d0743fe24b5ec..551890bd9eca3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessor.java @@ -31,6 +31,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.concurrent.LinkedBlockingDeque; import java.util.function.BiFunction; @@ -343,11 +344,16 @@ private static class UsageMetadataParser { private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( USAGE_METADATA_FIELD, true, - args -> new UsageMetadata( - args[0] == null ? 0 : (int) args[0], - args[1] == null ? 0 : (int) args[1], - args[2] == null ? 0 : (int) args[2] - ) + args -> { + if (Objects.isNull(args[0]) && Objects.isNull(args[1]) && Objects.isNull(args[2])) { + return null; + } + return new UsageMetadata( + args[0] == null ? 0 : (int) args[0], + args[1] == null ? 0 : (int) args[1], + args[2] == null ? 
0 : (int) args[2] + ); + } ); static { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessorTests.java index b5833823569f6..4915313e36d6a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessorTests.java @@ -110,7 +110,6 @@ public void testJsonLiteral_optionalTopLevelFieldsMissing() { assertEquals("STOP", choice.finishReason()); assertEquals(0, choice.index()); assertNull(choice.delta().toolCalls()); - assertNull(chunk.usage()); } catch (IOException e) { fail("IOException during test: " + e.getMessage()); @@ -192,10 +191,9 @@ public void testJsonLiteral_multipleTextParts() { assertEquals("This is the first part. 
This is the second part.", choice.delta().content()); assertEquals("STOP", choice.finishReason()); assertEquals(0, choice.index()); - assertNull(choice.delta().toolCalls()); // No function calls in this test case + assertNull(choice.delta().toolCalls()); assertNull(chunk.model()); assertNull(chunk.usage()); - } catch (IOException e) { fail("IOException during test: " + e.getMessage()); } From 6cc165b97dd28cbbdd9f152881f77b30630b0267 Mon Sep 17 00:00:00 2001 From: lhoet Date: Fri, 16 May 2025 09:42:28 -0300 Subject: [PATCH 35/38] Testfix --- .../googlevertexai/GoogleVertexAiService.java | 2 +- ...ifiedChatCompletionRequestEntityTests.java | 21 ++++++++----------- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 344adf144c606..cfdb548b7917e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -383,7 +383,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( LOCATION, - new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING)).setDescription( + new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.CHAT_COMPLETION)).setDescription( "Please provide the GCP region where the Vertex AI API(s) is enabled. " + "For more information, refer to the {geminiVertexAIDocs}." 
) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java index 0276f59943c4d..707513ddff6b0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java @@ -423,10 +423,8 @@ public void testError_UnsupportedRole() throws IOException { GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); - ElasticsearchStatusException statusException; - try (XContentBuilder builder = JsonXContent.contentBuilder()) { - statusException = assertThrows(ElasticsearchStatusException.class, () -> entity.toXContent(builder, ToXContent.EMPTY_PARAMS)); - } + var builder = JsonXContent.contentBuilder(); + var statusException = assertThrows(ElasticsearchStatusException.class, () -> entity.toXContent(builder, ToXContent.EMPTY_PARAMS)); assertEquals(RestStatus.BAD_REQUEST, statusException.status()); assertThat(statusException.toString(), containsString("Role [system] not supported by Google VertexAI ChatCompletion")); @@ -445,10 +443,8 @@ public void testError_UnsupportedContentObjectType() throws IOException { GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); - ElasticsearchStatusException statusException; - try (XContentBuilder builder = JsonXContent.contentBuilder()) { - statusException = assertThrows(ElasticsearchStatusException.class, () -> entity.toXContent(builder, 
ToXContent.EMPTY_PARAMS)); - } + var builder = JsonXContent.contentBuilder(); + var statusException = assertThrows(ElasticsearchStatusException.class, () -> entity.toXContent(builder, ToXContent.EMPTY_PARAMS)); assertEquals(RestStatus.BAD_REQUEST, statusException.status()); assertThat(statusException.toString(), containsString("Type [image_url] not supported by Google VertexAI ChatCompletion")); @@ -463,6 +459,7 @@ public void testParseAllFields() throws IOException { "parts": [ { "text": "some text" }, { "functionCall" : { + "id": "call_62136354", "name": "get_delivery_date", "args": { "order_id" : "order_12345" @@ -555,6 +552,7 @@ public void testParseFunctionCallNoContent() throws IOException { "role": "model", "parts": [ { "functionCall" : { + "id": "call_62136354", "name": "get_delivery_date", "args": { "order_id" : "order_12345" @@ -750,10 +748,9 @@ public void testParseToolChoiceInvalid_throwElasticSearchStatusException() throw UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, true); GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); - ElasticsearchStatusException statusException; - try (XContentBuilder builder = JsonXContent.contentBuilder()) { - statusException = expectThrows(ElasticsearchStatusException.class, () -> entity.toXContent(builder, ToXContent.EMPTY_PARAMS)); - } + XContentBuilder builder = JsonXContent.contentBuilder(); + var statusException = expectThrows(ElasticsearchStatusException.class, () -> entity.toXContent(builder, ToXContent.EMPTY_PARAMS)); + assertThat( statusException.toString(), containsString("Tool choice value [unsupported] not supported by Google VertexAI ChatCompletion.") From c02012295607acba5b5f54b0385bc16d1c916921 Mon Sep 17 00:00:00 2001 From: Salvador Beltran Date: Fri, 16 May 2025 09:53:22 -0600 Subject: [PATCH 36/38] Unit tests Implemented basic unit testing. Will improve in the next commit. 
As of now, we want to find a way to mock certain parts of the initialization of the Google VertexAI service that trigger the authorization decorator, without using tools like powermock or changing too much the code. --- .../elastic/ElasticInferenceServiceTests.java | 55 ++++++++++++++++ .../GoogleVertexAiServiceTests.java | 63 +++++++++++++++++-- 2 files changed, 113 insertions(+), 5 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index ba92bf399f99c..f779b0247e79e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -1299,6 +1299,61 @@ private InferenceEventsAssertion testUnifiedStream(int responseCode, String resp } } + public void testUnifiedCompletionInfer_WithGoogleVertexAiModel() throws IOException { + var elasticInferenceServiceURL = getUrl(webServer); + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = createService(senderFactory, elasticInferenceServiceURL)) { + // Mock a successful streaming response + String responseJson = """ + data: {"id":"1","object":"completion","created":1677858242,"model":"my-model-id", + "choices":[{"finish_reason":null,"index":0,"delta":{"role":"assistant","content":"Hello"}}]} + + data: {"id":"2","object":"completion","created":1677858242,"model":"my-model-id", + "choices":[{"finish_reason":"stop","index":0,"delta":{"content":" world!"}}]} + + data: [DONE] + + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + // Create completion model + var model = new 
ElasticInferenceServiceCompletionModel( + "id", + TaskType.CHAT_COMPLETION, + "elastic", + new ElasticInferenceServiceCompletionServiceSettings("gemini-2.0-flash-001", new RateLimitSettings(100)), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + ElasticInferenceServiceComponents.of(elasticInferenceServiceURL) + ); + + var request = UnifiedCompletionRequest.of( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("Hello"), "user", null, null)) + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + + try { + service.unifiedCompletionInfer(model, request, InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + // We don't need to check the actual response as we're only testing header propagation + listener.actionGet(TIMEOUT); + + // Verify the request was sent + assertThat(webServer.requests(), hasSize(1)); + var httpRequest = webServer.requests().getFirst(); + + // Check that the Gemini API was called. + assertThat(httpRequest.getBody().toString(), equalTo( "{\"messages\":[{\"content\":\"Hello\",\"role\":\"user\"}],\"n\":1,\"stream\":true,\"stream_options\":{\"include_usage\":true},\"model\":\"gemini-2.0-flash-001\"}")); + } finally { + // Clean up the thread context + threadPool.getThreadContext().stashContext(); + } + } + } + private void ensureAuthorizationCallFinished(ElasticInferenceService service) { service.onNodeStarted(); service.waitForFirstAuthorizationToComplete(TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java index 8ea1c12ea9e4a..a16f1edf0a962 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java @@ -30,6 +30,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModelTests; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings; @@ -46,13 +47,17 @@ import java.util.HashMap; import java.util.Map; +import static java.util.concurrent.TimeUnit.MINUTES; import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; +import static org.elasticsearch.inference.TaskType.CHAT_COMPLETION; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.mockito.Mockito.mock; @@ -78,6 +83,54 @@ public void shutdown() throws IOException { 
webServer.close(); } + public void testParseRequestConfig_CreateGoogleVertexAiChatCompletionModel() throws IOException { + var projectId = "project"; + var location = "location"; + var modelId = "model"; + var serviceAccountJson = """ + { + "some json" + } + """; + + try (var service = createGoogleVertexAiService()) { + ActionListener modelListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(GoogleVertexAiChatCompletionModel.class)); + + var vertexAIModel = (GoogleVertexAiChatCompletionModel) model; + + assertThat(vertexAIModel.getServiceSettings().modelId(), is(modelId)); + assertThat(vertexAIModel.getServiceSettings().location(), is(location)); + assertThat(vertexAIModel.getServiceSettings().projectId(), is(projectId)); + assertThat(vertexAIModel.getSecretSettings().serviceAccountJson().toString(), is(serviceAccountJson)); + assertThat(vertexAIModel.getConfigurations().getTaskType(), equalTo(CHAT_COMPLETION)); + assertThat(vertexAIModel.getServiceSettings().rateLimitSettings().requestsPerTimeUnit(), equalTo(1000L)); + assertThat(vertexAIModel.getServiceSettings().rateLimitSettings().timeUnit(), equalTo(MINUTES)); + + }, e -> fail("Model parsing should succeeded, but failed: " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.CHAT_COMPLETION, + getRequestConfigMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + modelId, + GoogleVertexAiServiceFields.LOCATION, + location, + GoogleVertexAiServiceFields.PROJECT_ID, + projectId + ) + ), + getTaskSettingsMapEmpty(), + getSecretSettingsMap(serviceAccountJson) + ), + modelListener + ); + } + } + public void testParseRequestConfig_CreatesGoogleVertexAiEmbeddingsModel() throws IOException { var projectId = "project"; var location = "location"; @@ -871,7 +924,7 @@ public void testGetConfiguration() throws Exception { { "service": "googlevertexai", "name": "Google Vertex AI", - "task_types": ["text_embedding", "rerank"], + "task_types": ["text_embedding", "rerank", 
"chat_completion"], "configurations": { "service_account_json": { "description": "API Key for the provider you're connecting to.", @@ -889,7 +942,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["text_embedding", "rerank"] + "supported_task_types": ["text_embedding", "rerank", "chat_completion"] }, "location": { "description": "Please provide the GCP region where the Vertex AI API(s) is enabled. For more information, refer to the {geminiVertexAIDocs}.", @@ -898,7 +951,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["text_embedding"] + "supported_task_types": ["text_embedding", "rerank", "chat_completion"] }, "rate_limit.requests_per_minute": { "description": "Minimize the number of rate limit errors.", @@ -907,7 +960,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "int", - "supported_task_types": ["text_embedding", "rerank"] + "supported_task_types": ["text_embedding", "rerank", "chat_completion"] }, "model_id": { "description": "ID of the LLM you're using.", @@ -916,7 +969,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["text_embedding", "rerank"] + "supported_task_types": ["text_embedding", "rerank", "chat_completion"] } } } From 86336590dc37e3018d4acdf5a63a569d5411df82 Mon Sep 17 00:00:00 2001 From: Salvador Beltran Date: Fri, 16 May 2025 10:06:58 -0600 Subject: [PATCH 37/38] Update ElasticInferenceServiceTests.java --- .../services/elastic/ElasticInferenceServiceTests.java | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index f779b0247e79e..32e90157cf9c2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -1318,7 +1318,7 @@ public void testUnifiedCompletionInfer_WithGoogleVertexAiModel() throws IOExcept webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - // Create completion model + // Create chat completion model var model = new ElasticInferenceServiceCompletionModel( "id", TaskType.CHAT_COMPLETION, @@ -1337,8 +1337,6 @@ public void testUnifiedCompletionInfer_WithGoogleVertexAiModel() throws IOExcept try { service.unifiedCompletionInfer(model, request, InferenceAction.Request.DEFAULT_TIMEOUT, listener); - - // We don't need to check the actual response as we're only testing header propagation listener.actionGet(TIMEOUT); // Verify the request was sent From 06020cc6b3b3ca6078b0d04ed1725722b27fb9a9 Mon Sep 17 00:00:00 2001 From: Salvador Beltran Date: Fri, 16 May 2025 10:52:28 -0600 Subject: [PATCH 38/38] Update GoogleVertexAiServiceTests.java Implemented a test case for persisted config with secrets. 
--- .../GoogleVertexAiServiceTests.java | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java index 4d857e21e1dc9..3bb0a2adfbd49 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java @@ -39,6 +39,7 @@ import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel; +import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModelTests; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings; @@ -491,6 +492,53 @@ public void testParsePersistedConfigWithSecrets_CreatesGoogleVertexAiEmbeddingsM } } + public void testParsePersistedConfigWithSecrets_CreatesGoogleVertexAiChatCompletionModel() throws IOException { + var projectId = "project"; + var location = "location"; + var modelId = "model"; + var autoTruncate = true; + var serviceAccountJson = """ + { + "some json" + } + """; + + try (var service = createGoogleVertexAiService()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + modelId, + 
GoogleVertexAiServiceFields.LOCATION, + location, + GoogleVertexAiServiceFields.PROJECT_ID, + projectId + ) + ), + getTaskSettingsMap(autoTruncate, InputType.INGEST), + getSecretSettingsMap(serviceAccountJson) + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.CHAT_COMPLETION, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(GoogleVertexAiChatCompletionModel.class)); + + var chatCompletionModel = (GoogleVertexAiChatCompletionModel) model; + assertThat(chatCompletionModel.getServiceSettings().modelId(), is(modelId)); + assertThat(chatCompletionModel.getServiceSettings().location(), is(location)); + assertThat(chatCompletionModel.getServiceSettings().projectId(), is(projectId)); + assertThat(chatCompletionModel.getSecretSettings().serviceAccountJson().toString(), is(serviceAccountJson)); + assertThat(chatCompletionModel.getConfigurations().getTaskType(), equalTo(CHAT_COMPLETION)); + assertThat(chatCompletionModel.getServiceSettings().rateLimitSettings().requestsPerTimeUnit(), equalTo(1000L)); + assertThat(chatCompletionModel.getServiceSettings().rateLimitSettings().timeUnit(), equalTo(MINUTES)); + } + } + public void testParsePersistedConfigWithSecrets_CreatesAGoogleVertexAiEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { var projectId = "project"; var location = "location";