Implemented completion task for Google VertexAI #128694
Merged: jonathan-buttner merged 14 commits into elastic:main from leo-hoet:google-vertexai-completion on Jun 9, 2025
Changes from 11 commits
Commits
a9df8e3 Google Vertex AI completion model, response entity and tests (leo-hoet)
80af6d3 Fixed GoogleVertexAiServiceTest for Service configuration (leo-hoet)
ca1b6d5 Changelog (leo-hoet)
b6f5e34 Removed downcasting and using `moveToFirstToken` (leo-hoet)
ce6d45f Create GoogleVertexAiChatCompletionResponseHandler for streaming and … (leo-hoet)
6cf0c0b Added unit tests (leo-hoet)
7eabd29 PR feedback (leo-hoet)
55d8650 Removed googlevertexaicompletion model. Using just GoogleVertexAiChat… (leo-hoet)
bf27166 Renamed uri -> nonStreamingUri. Added streamingUri and getters in Goo… (leo-hoet)
ab1fe7a Moved rateLimitGroupHashing to subclasses of GoogleVertexAiModel (leo-hoet)
20f1914 Merge branch 'main' into google-vertexai-completion (lhoet-google)
30f53cd Fixed rate limit has of GoogleVertexAiRerankModel and refactored uri … (leo-hoet)
5d20b35 Merge branch 'google-vertexai-completion' of github.com:leo-hoet/elas… (leo-hoet)
2b865af Merge branch 'main' into google-vertexai-completion (jonathan-buttner)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| pr: 128694 | ||
| summary: "Adding Google VertexAI completion integration" | ||
| area: Inference | ||
| type: enhancement | ||
| issues: [ ] |
64 changes: 64 additions & 0 deletions
...asticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiStreamingProcessor.java
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,64 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the Elastic License | ||
| * 2.0; you may not use this file except in compliance with the Elastic License | ||
| * 2.0. | ||
| */ | ||
| | ||
| package org.elasticsearch.xpack.inference.services.googlevertexai; | ||
| | ||
| import org.elasticsearch.ElasticsearchStatusException; | ||
| import org.elasticsearch.common.Strings; | ||
| import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; | ||
| import org.elasticsearch.inference.InferenceServiceResults; | ||
| import org.elasticsearch.rest.RestStatus; | ||
| import org.elasticsearch.xcontent.XContentFactory; | ||
| import org.elasticsearch.xcontent.XContentParser; | ||
| import org.elasticsearch.xcontent.XContentParserConfiguration; | ||
| import org.elasticsearch.xcontent.XContentType; | ||
| import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; | ||
| import org.elasticsearch.xpack.inference.common.DelegatingProcessor; | ||
| import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; | ||
| | ||
| import java.io.IOException; | ||
| import java.util.Deque; | ||
| import java.util.Objects; | ||
| import java.util.stream.Stream; | ||
| | ||
| public class GoogleVertexAiStreamingProcessor extends DelegatingProcessor<Deque<ServerSentEvent>, InferenceServiceResults.Result> { | ||
| | ||
| @Override | ||
| protected void next(Deque<ServerSentEvent> item) throws Exception { | ||
| var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); | ||
| var results = parseEvent(item, GoogleVertexAiStreamingProcessor::parse, parserConfig); | ||
| | ||
| if (results.isEmpty()) { | ||
| upstream().request(1); | ||
| } else { | ||
| downstream().onNext(new StreamingChatCompletionResults.Results(results)); | ||
| } | ||
| } | ||
| | ||
| public static Stream<StreamingChatCompletionResults.Result> parse(XContentParserConfiguration parserConfig, ServerSentEvent event) { | ||
| String data = event.data(); | ||
| try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, data)) { | ||
| var chunk = GoogleVertexAiUnifiedStreamingProcessor.GoogleVertexAiChatCompletionChunkParser.parse(jsonParser); | ||
| | ||
| return chunk.choices() | ||
| .stream() | ||
| .map(choice -> choice.delta()) | ||
| .filter(Objects::nonNull) | ||
| .map(delta -> delta.content()) | ||
| .filter(content -> Strings.isNullOrEmpty(content) == false) | ||
| .map(StreamingChatCompletionResults.Result::new); | ||
| | ||
| } catch (IOException e) { | ||
| throw new ElasticsearchStatusException( | ||
| "Failed to parse event from inference provider: {}", | ||
| RestStatus.INTERNAL_SERVER_ERROR, | ||
| e, | ||
| event | ||
| ); | ||
| } | ||
| } | ||
| } |
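The filter chain in `parse` can be tried in isolation. The following is only a sketch under stated assumptions: `Delta`, `Choice`, `Chunk`, and `DeltaContentSketch` are hypothetical stand-ins for the parsed chunk types, not the classes from the inference plugin, and a plain null/empty check replaces `Strings.isNullOrEmpty`.

```java
import java.util.List;
import java.util.Objects;

// Hypothetical stand-ins for the parsed chunk types; the real classes live in the inference plugin.
record Delta(String content) {}
record Choice(Delta delta) {}
record Chunk(List<Choice> choices) {}

class DeltaContentSketch {
    // Keeps only non-empty text fragments, mirroring the order of the filters in parse().
    static List<String> extractContents(Chunk chunk) {
        return chunk.choices()
            .stream()
            .map(Choice::delta)
            .filter(Objects::nonNull)                                          // skip choices without a delta
            .map(Delta::content)
            .filter(content -> content != null && content.isEmpty() == false)  // skip null or empty text
            .toList();
    }

    public static void main(String[] args) {
        var chunk = new Chunk(List.of(
            new Choice(new Delta("Hel")),
            new Choice(null),
            new Choice(new Delta("")),
            new Choice(new Delta(null)),
            new Choice(new Delta("lo"))
        ));
        System.out.println(extractContents(chunk)); // prints [Hel, lo]
    }
}
```

When every choice is filtered out, the resulting list is empty, which corresponds to the branch in `next` that requests one more item from upstream instead of emitting an empty result downstream.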
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -23,10 +23,10 @@ | ||
| import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; | ||
| import org.elasticsearch.xpack.inference.external.http.HttpResult; | ||
| import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; | ||
| - import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; | ||
| import org.elasticsearch.xpack.inference.external.request.Request; | ||
| import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser; | ||
| import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor; | ||
| + import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiCompletionResponseEntity; | ||
| | ||
| import java.nio.charset.StandardCharsets; | ||
| import java.util.Locale; | ||
| @@ -43,10 +43,8 @@ public class GoogleVertexAiUnifiedChatCompletionResponseHandler extends GoogleVe | ||
| private static final String ERROR_MESSAGE_FIELD = "message"; | ||
| private static final String ERROR_STATUS_FIELD = "status"; | ||
| | ||
| - private static final ResponseParser noopParseFunction = (a, b) -> null; | ||
| - | ||
| public GoogleVertexAiUnifiedChatCompletionResponseHandler(String requestType) { | ||
| - super(requestType, noopParseFunction, GoogleVertexAiErrorResponse::fromResponse, true); | ||
| + super(requestType, GoogleVertexAiCompletionResponseEntity::fromResponse, GoogleVertexAiErrorResponse::fromResponse, true); | ||
| } | ||
| | ||
| @Override | ||
| @@ -64,6 +62,7 @@ public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpR | ||
| @Override | ||
| protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { | ||
| assert request.isStreaming() : "Only streaming requests support this format"; | ||
| | ||
| Review comment: Let's try to keep this for streaming only | ||
| | ||
| var responseStatusCode = result.response().getStatusLine().getStatusCode(); | ||
| var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode); | ||
| var restStatus = toRestStatus(responseStatusCode); | ||
| @@ -111,7 +110,7 @@ private static Exception buildMidStreamError(Request request, String message, Ex | ||
| } | ||
| } | ||
| | ||
| - private static class GoogleVertexAiErrorResponse extends ErrorResponse { | ||
| + public static class GoogleVertexAiErrorResponse extends ErrorResponse { | ||
| private static final Logger logger = LogManager.getLogger(GoogleVertexAiErrorResponse.class); | ||
| private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>( | ||
| "google_vertex_ai_error_wrapper", | ||
| @@ -138,7 +137,7 @@ private static class GoogleVertexAiErrorResponse extends ErrorResponse { | ||
| ); | ||
| } | ||
| | ||
| - static ErrorResponse fromResponse(HttpResult response) { | ||
| + public static ErrorResponse fromResponse(HttpResult response) { | ||
| try ( | ||
| XContentParser parser = XContentFactory.xContent(XContentType.JSON) | ||
| .createParser(XContentParserConfiguration.EMPTY, response.body()) | ||
To support streaming and non-streaming for `completion` I think we'll need a slightly different inheritance hierarchy.
For example take a look at openai: https://github.com/elastic/elasticsearch/blob/main/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiResponseHandler.java#L133-L140
These are the changes I think we need to make:
- `parseResults` in `GoogleVertexAiResponseHandler`, it can be identical to `GoogleVertexAiUnifiedChatCompletionResponseHandler`, except we'll need to return `StreamingChatCompletionResults` instead of the unified version
- `buildMidstreamError` or some other functionality from the unified response handler up to a class that both completion and chat completion extend that might be better
- `GoogleVertexAiActionCreator` we'll create a new response handler that leverages `GoogleVertexAiCompletionResponseEntity::fromResponse` for the non-streaming case
Let me know if you'd rather jump on a call to discuss this

Sorry, one correction: the way we report errors for the "unified" (aka OpenAI) format is different from the "elasticsearch" way of returning errors. So for streaming `completion` we don't want to follow what we're doing in `GoogleVertexAiUnifiedChatCompletionResponseHandler`, because that returns the errors in OpenAI format. I would try to follow what we're doing in the `OpenAiResponseHandler` that I linked. Hopefully we don't need to create a whole new streaming processor, though.
We might need to do some refactoring, but I would see if you could reuse `GoogleVertexAiUnifiedStreamingProcessor` for the parsing logic. We'll need to return a different result, though (specifically `StreamingChatCompletionResults`).
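The shape of the hierarchy discussed in these comments (shared parsing logic in a parent class, with each handler returning its own result type) can be sketched roughly as follows. Every class, record, and method name here is hypothetical and exists only for illustration; the `|`-separated payload is a toy format, not the Vertex AI wire format.

```java
import java.util.Arrays;
import java.util.List;

// Every name here is hypothetical; none of these types exist in Elasticsearch.
record PlainChunk(List<String> texts) {}
record UnifiedChunk(String role, List<String> texts) {}

// Shared parent: owns the parsing logic that both handlers reuse.
abstract class SharedStreamingHandlerSketch<R> {
    // Splits a toy "|"-separated payload into non-empty fragments.
    protected List<String> parseFragments(String payload) {
        return Arrays.stream(payload.split("\\|")).filter(s -> s.isEmpty() == false).toList();
    }

    abstract R toResult(String payload);
}

// Completion: wraps the fragments in the plain result type.
class CompletionHandlerSketch extends SharedStreamingHandlerSketch<PlainChunk> {
    @Override
    PlainChunk toResult(String payload) {
        return new PlainChunk(parseFragments(payload));
    }
}

// Chat completion: same parsing, but a richer "unified" result type.
class ChatCompletionHandlerSketch extends SharedStreamingHandlerSketch<UnifiedChunk> {
    @Override
    UnifiedChunk toResult(String payload) {
        return new UnifiedChunk("assistant", parseFragments(payload));
    }
}

class HierarchyDemo {
    public static void main(String[] args) {
        System.out.println(new CompletionHandlerSketch().toResult("Hel|lo").texts());    // prints [Hel, lo]
        System.out.println(new ChatCompletionHandlerSketch().toResult("Hel|lo").role()); // prints assistant
    }
}
```

Error reporting would differ between the two subclasses in the same way; that part is left out of the sketch because the formats are not shown on this page.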

I think I managed to get this working as you suggested. The only hacky thing is that I had to add a method to the completion model, `updateUri`, that is called before making the request to ensure we are calling the right API. Take a look and let me know what you think. (It's still missing unit tests.)
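One of the commit messages on this PR mentions renaming `uri` to `nonStreamingUri` and adding a `streamingUri` with getters. A minimal sketch of that idea, assuming the model holds both URIs and the caller picks one per request instead of mutating a single field, might look like this. The class name, the `uriFor` method, and the example URLs are all hypothetical.

```java
import java.net.URI;

// Hypothetical model holder; the real model class in the PR is not shown on this page.
class CompletionModelSketch {
    private final URI nonStreamingUri;
    private final URI streamingUri;

    CompletionModelSketch(URI nonStreamingUri, URI streamingUri) {
        this.nonStreamingUri = nonStreamingUri;
        this.streamingUri = streamingUri;
    }

    URI nonStreamingUri() {
        return nonStreamingUri;
    }

    URI streamingUri() {
        return streamingUri;
    }

    // Picks the endpoint for a request without changing any state on the model.
    URI uriFor(boolean stream) {
        return stream ? streamingUri : nonStreamingUri;
    }

    public static void main(String[] args) {
        var model = new CompletionModelSketch(
            URI.create("https://example.invalid/model:generate"),       // illustrative URL only
            URI.create("https://example.invalid/model:streamGenerate")  // illustrative URL only
        );
        System.out.println(model.uriFor(true));  // prints https://example.invalid/model:streamGenerate
        System.out.println(model.uriFor(false)); // prints https://example.invalid/model:generate
    }
}
```

Keeping both URIs immutable avoids the ordering dependency of an `updateUri` call that must run before each request.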