Implemented completion task for Google VertexAI #128694
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| pr: 128694 | ||
| summary: "Adding Google VertexAI completion integration" | ||
| area: Inference | ||
| type: enhancement | ||
| issues: [ ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,10 +23,10 @@ | |
| import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; | ||
| import org.elasticsearch.xpack.inference.external.http.HttpResult; | ||
| import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; | ||
| - import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; | ||
| import org.elasticsearch.xpack.inference.external.request.Request; | ||
| import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser; | ||
| import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor; | ||
| + import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiCompletionResponseEntity; | ||
|
|
||
| import java.nio.charset.StandardCharsets; | ||
| import java.util.Locale; | ||
|
|
@@ -43,10 +43,8 @@ public class GoogleVertexAiUnifiedChatCompletionResponseHandler extends GoogleVe | |
| private static final String ERROR_MESSAGE_FIELD = "message"; | ||
| private static final String ERROR_STATUS_FIELD = "status"; | ||
|
|
||
| - private static final ResponseParser noopParseFunction = (a, b) -> null; | ||
|
|
||
| public GoogleVertexAiUnifiedChatCompletionResponseHandler(String requestType) { | ||
| - super(requestType, noopParseFunction, GoogleVertexAiErrorResponse::fromResponse, true); | ||
| + super(requestType, GoogleVertexAiCompletionResponseEntity::fromResponse, GoogleVertexAiErrorResponse::fromResponse, true); | ||
|
Review comment: To support streaming and non-streaming for
For example, take a look at openai: https://github.com/elastic/elasticsearch/blob/main/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiResponseHandler.java#L133-L140
These are the changes I think we need to make:
Let me know if you'd rather jump on a call to discuss this.
Review comment (follow-up): Sorry, one correction: the way we report errors for the "unified" (aka OpenAI) format is different from the "elasticsearch" way of returning errors. So for streaming
We might need to do some refactoring, but I would see if you could reuse
Reply: I think I managed to get this working as you suggested. The only hacky thing is that I have to add a method to the completion model |
||
| } | ||
|
|
||
| @Override | ||
|
|
@@ -63,7 +61,6 @@ public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpR | |
|
|
||
| @Override | ||
| protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { | ||
| - assert request.isStreaming() : "Only streaming requests support this format"; | ||
|
Review comment: Let's try to keep this for streaming only. |
||
| var responseStatusCode = result.response().getStatusLine().getStatusCode(); | ||
| var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode); | ||
| var restStatus = toRestStatus(responseStatusCode); | ||
|
|
||
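As a rough illustration of the review feedback on this handler (keep the "unified" OpenAI-style error for streaming requests, and report non-streaming errors the regular way), here is a minimal sketch. It is one possible reading of the comments, not the code that was merged. `ErrorFormatSketch`, `UnifiedStyleException`, and `PlainStatusException` are hypothetical stand-ins and do not exist in Elasticsearch under these names.

```java
// Hypothetical sketch: every type here is a stand-in, not an Elasticsearch class.
final class ErrorFormatSketch {

    // Stands in for an OpenAI-style ("unified") error used on the streaming path.
    static final class UnifiedStyleException extends RuntimeException {
        final int statusCode;

        UnifiedStyleException(int statusCode, String message) {
            super(message);
            this.statusCode = statusCode;
        }
    }

    // Stands in for the regular status-based error used on the non-streaming path.
    static final class PlainStatusException extends RuntimeException {
        final int statusCode;

        PlainStatusException(int statusCode, String message) {
            super(message);
            this.statusCode = statusCode;
        }
    }

    // Chooses the error shape from the request type instead of asserting that
    // every request is a streaming one.
    static RuntimeException buildError(boolean streaming, int statusCode, String message) {
        return streaming
            ? new UnifiedStyleException(statusCode, message)
            : new PlainStatusException(statusCode, message);
    }
}
```

Branching on the request type keeps one handler usable for both task types; whether the real handler should branch or be split into two classes is a design decision the thread leaves open.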
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,58 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the Elastic License | ||
| * 2.0; you may not use this file except in compliance with the Elastic License | ||
| * 2.0. | ||
| */ | ||
|
|
||
| package org.elasticsearch.xpack.inference.services.googlevertexai.completion; | ||
|
|
||
| import org.apache.http.client.utils.URIBuilder; | ||
| import org.elasticsearch.inference.TaskType; | ||
| import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; | ||
| import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUtils; | ||
|
|
||
| import java.net.URI; | ||
| import java.net.URISyntaxException; | ||
| import java.util.Map; | ||
|
|
||
| import static org.elasticsearch.core.Strings.format; | ||
|
|
||
| public class GoogleVertexAiCompletionModel extends GoogleVertexAiChatCompletionModel { | ||
|
|
||
| public GoogleVertexAiCompletionModel( | ||
| String inferenceEntityId, | ||
| TaskType taskType, | ||
| String service, | ||
| Map<String, Object> serviceSettings, | ||
| Map<String, Object> taskSettings, | ||
| Map<String, Object> secrets, | ||
| ConfigurationParseContext context | ||
| ) { | ||
| super(inferenceEntityId, taskType, service, serviceSettings, taskSettings, secrets, context); | ||
| try { | ||
| var modelServiceSettings = this.getServiceSettings(); | ||
| this.uri = buildUri(modelServiceSettings.location(), modelServiceSettings.projectId(), modelServiceSettings.modelId()); | ||
| } catch (URISyntaxException e) { | ||
| throw new RuntimeException(e); | ||
| } | ||
|
|
||
| } | ||
|
|
||
| public static URI buildUri(String location, String projectId, String model) throws URISyntaxException { | ||
| return new URIBuilder().setScheme("https") | ||
| .setHost(format("%s%s", location, GoogleVertexAiUtils.GOOGLE_VERTEX_AI_HOST_SUFFIX)) | ||
| .setPathSegments( | ||
| GoogleVertexAiUtils.V1, | ||
| GoogleVertexAiUtils.PROJECTS, | ||
| projectId, | ||
| GoogleVertexAiUtils.LOCATIONS, | ||
| GoogleVertexAiUtils.GLOBAL, | ||
| GoogleVertexAiUtils.PUBLISHERS, | ||
| GoogleVertexAiUtils.PUBLISHER_GOOGLE, | ||
| GoogleVertexAiUtils.MODELS, | ||
| format("%s:%s", model, GoogleVertexAiUtils.GENERATE_CONTENT) | ||
|
||
| ) | ||
| .build(); | ||
| } | ||
| } | ||
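To make the shape of the URI produced by `buildUri` easier to see, here is a JDK-only sketch that assembles the same endpoint. `VertexAiUriSketch` is a hypothetical helper, and the host suffix value is an assumption inferred from the expected URI in the test for this model, not read from `GoogleVertexAiUtils`.

```java
import java.net.URI;

// Hypothetical helper; it mirrors the path layout of buildUri without Apache's URIBuilder.
final class VertexAiUriSketch {

    // Assumed value of GoogleVertexAiUtils.GOOGLE_VERTEX_AI_HOST_SUFFIX.
    static final String HOST_SUFFIX = "-aiplatform.googleapis.com";

    static URI generateContentUri(String location, String projectId, String model) {
        // The host uses the configured location, while the path segment after
        // "locations" is fixed to "global", as in the diff above.
        String host = location + HOST_SUFFIX;
        String path = "/v1/projects/" + projectId
            + "/locations/global/publishers/google/models/" + model + ":generateContent";
        // URI.create throws IllegalArgumentException on malformed input.
        return URI.create("https://" + host + path);
    }

    public static void main(String[] args) {
        System.out.println(generateContentUri("us-central1", "test-project", "gemini-pro"));
    }
}
```

Unlike `URIBuilder.setPathSegments`, this sketch does not percent-encode its inputs, so it is only suitable for simple identifiers like the ones in the test.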
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,101 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the Elastic License | ||
| * 2.0; you may not use this file except in compliance with the Elastic License | ||
| * 2.0. | ||
| */ | ||
|
|
||
| package org.elasticsearch.xpack.inference.services.googlevertexai.response; | ||
|
|
||
| import org.elasticsearch.inference.InferenceServiceResults; | ||
| import org.elasticsearch.xcontent.XContentFactory; | ||
| import org.elasticsearch.xcontent.XContentParser; | ||
| import org.elasticsearch.xcontent.XContentParserConfiguration; | ||
| import org.elasticsearch.xcontent.XContentType; | ||
| import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; | ||
| import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; | ||
| import org.elasticsearch.xpack.inference.external.http.HttpResult; | ||
| import org.elasticsearch.xpack.inference.external.request.Request; | ||
| import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiUnifiedStreamingProcessor; | ||
|
|
||
| import java.io.IOException; | ||
| import java.nio.charset.StandardCharsets; | ||
|
|
||
| public class GoogleVertexAiCompletionResponseEntity { | ||
| /** | ||
| * Parses the response from Google Vertex AI's generateContent endpoint | ||
| * For a request like: | ||
| * <pre> | ||
| * <code> | ||
| * { | ||
| * "inputs": "Please summarize this text: some text" | ||
| * } | ||
| * </code> | ||
| * </pre> | ||
| * | ||
| * The response is a <a href="https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/GenerateContentResponse">GenerateContentResponse</a> object that looks like: | ||
| * | ||
| * <pre> | ||
| * <code> | ||
| * | ||
| * { | ||
| * "candidates": [ | ||
| * { | ||
| * "content": { | ||
| * "role": "model", | ||
| * "parts": [ | ||
| * { | ||
| * "text": "I am sorry, I cannot summarize the text because I do not have access to the text you are referring to." | ||
| * } | ||
| * ] | ||
| * }, | ||
| * "finishReason": "STOP", | ||
| * "avgLogprobs": -0.19326641248620074 | ||
| * } | ||
| * ], | ||
| * "usageMetadata": { | ||
| * "promptTokenCount": 71, | ||
| * "candidatesTokenCount": 23, | ||
| * "totalTokenCount": 94, | ||
| * "trafficType": "ON_DEMAND", | ||
| * "promptTokensDetails": [ | ||
| * { | ||
| * "modality": "TEXT", | ||
| * "tokenCount": 71 | ||
| * } | ||
| * ], | ||
| * "candidatesTokensDetails": [ | ||
| * { | ||
| * "modality": "TEXT", | ||
| * "tokenCount": 23 | ||
| * } | ||
| * ] | ||
| * }, | ||
| * "modelVersion": "gemini-2.0-flash-001", | ||
| * "createTime": "2025-05-28T15:08:20.049493Z", | ||
| * "responseId": "5CY3aNWCA6mm4_UPr-zduAE" | ||
| * } | ||
| * </code> | ||
| * </pre> | ||
| * | ||
| * @param request The original request made to the service. | ||
| * @param response The HTTP response whose body is parsed. | ||
| **/ | ||
| public static InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException { | ||
| var responseJson = new String(response.body(), StandardCharsets.UTF_8); | ||
|
|
||
| // Response from generateContent has the same shape as streamGenerateContent. We reuse the already implemented | ||
| // class to avoid code duplication | ||
|
|
||
| StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk; | ||
| try ( | ||
| XContentParser parser = XContentFactory.xContent(XContentType.JSON) | ||
| .createParser(XContentParserConfiguration.EMPTY, responseJson) | ||
| ) { | ||
| parser.nextToken(); | ||
|
||
| chunk = GoogleVertexAiUnifiedStreamingProcessor.GoogleVertexAiChatCompletionChunkParser.parse(parser); | ||
| } | ||
| var results = chunk.choices().stream().map(choice -> choice.delta().content()).map(ChatCompletionResults.Result::new).toList(); | ||
|
|
||
| return new ChatCompletionResults(results); | ||
| } | ||
| } | ||
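The last step of `fromResponse` maps the parsed chunk's choices to completion results. The sketch below isolates that mapping with hypothetical record types (`Delta`, `Choice`, `Chunk`, `Result`) standing in for the real `StreamingUnifiedChatCompletionResults` and `ChatCompletionResults` classes. The null filter is an addition of this sketch, assuming a choice may carry no text; the diff above does not include it.

```java
import java.util.List;
import java.util.Objects;

// Hypothetical stand-ins for the chunk and result types used in fromResponse.
final class CompletionMappingSketch {

    record Delta(String content) {}

    record Choice(Delta delta) {}

    record Chunk(List<Choice> choices) {}

    record Result(String content) {}

    // Reuses the streaming chunk shape for a non-streaming response:
    // each choice's delta content becomes one completion result.
    static List<Result> toResults(Chunk chunk) {
        return chunk.choices()
            .stream()
            .map(choice -> choice.delta().content())
            .filter(Objects::nonNull) // assumption: skip choices without text
            .map(Result::new)
            .toList();
    }
}
```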
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,53 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the Elastic License | ||
| * 2.0; you may not use this file except in compliance with the Elastic License | ||
| * 2.0. | ||
| */ | ||
|
|
||
| package org.elasticsearch.xpack.inference.services.googlevertexai.completion; | ||
|
|
||
| import org.elasticsearch.core.Strings; | ||
| import org.elasticsearch.inference.TaskType; | ||
| import org.elasticsearch.test.ESTestCase; | ||
| import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; | ||
|
|
||
| import java.net.URI; | ||
| import java.net.URISyntaxException; | ||
| import java.util.HashMap; | ||
| import java.util.Map; | ||
|
|
||
| import static org.hamcrest.Matchers.is; | ||
|
|
||
| public class GoogleVertexAiCompletionModelTests extends ESTestCase { | ||
|
|
||
| private static final String DEFAULT_PROJECT_ID = "test-project"; | ||
| private static final String DEFAULT_LOCATION = "us-central1"; | ||
| private static final String DEFAULT_MODEL_ID = "gemini-pro"; | ||
|
|
||
| public void testCreateModel() throws URISyntaxException { | ||
| var model = createCompletionModel(DEFAULT_PROJECT_ID, DEFAULT_LOCATION, DEFAULT_MODEL_ID); | ||
| URI expectedUri = new URI( | ||
| Strings.format( | ||
| "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:generateContent", | ||
| DEFAULT_LOCATION, | ||
| DEFAULT_PROJECT_ID, | ||
| DEFAULT_MODEL_ID | ||
|
|
||
| ) | ||
| ); | ||
| assertThat(model.uri(), is(expectedUri)); | ||
| } | ||
|
|
||
| private static GoogleVertexAiCompletionModel createCompletionModel(String projectId, String location, String modelId) { | ||
| return new GoogleVertexAiCompletionModel( | ||
| "google-vertex-ai-chat-test-id", | ||
| TaskType.CHAT_COMPLETION, | ||
| "google_vertex_ai", | ||
| new HashMap<>(Map.of("project_id", projectId, "location", location, "model_id", modelId)), | ||
| new HashMap<>(), | ||
| new HashMap<>(Map.of("service_account_json", "{}")), | ||
| ConfigurationParseContext.PERSISTENT | ||
| ); | ||
| } | ||
| } |
Review comment: Hmm, we can't rely on `model` being only a `GoogleVertexAiCompletionModel`; it could be a completion, rerank, or embedding model. Do we need it as a completion model here?
Reply: Got it. My understanding was that this method was only called by the `completion` task, which is why I was downcasting it. Will revert this change.
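For context on the downcasting concern: if a completion-specific value were ever needed at a call site that can receive several model types, a guarded `instanceof` pattern avoids a `ClassCastException`. The sketch below uses hypothetical types (`Model`, `CompletionModel`, `RerankModel`, `EmbeddingsModel`) rather than the real Elasticsearch model hierarchy, and it illustrates the alternative only; the author's stated plan is to revert the cast.

```java
// Hypothetical model hierarchy; the real classes live in the inference plugin.
final class ModelDispatchSketch {

    interface Model {}

    record CompletionModel(String uri) implements Model {}

    record RerankModel() implements Model {}

    record EmbeddingsModel() implements Model {}

    // Only touches completion-specific state when the model really is a completion model.
    static String describe(Model model) {
        if (model instanceof CompletionModel completion) {
            return "completion:" + completion.uri();
        }
        return "other";
    }
}
```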