-
Notifications
You must be signed in to change notification settings - Fork 25.6k
Implemented completion task for Google VertexAI #128694
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
a9df8e3
80af6d3
ca1b6d5
b6f5e34
ce6d45f
6cf0c0b
7eabd29
55d8650
bf27166
ab1fe7a
20f1914
30f53cd
5d20b35
2b865af
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| pr: 128694 | ||
| summary: "Adding Google VertexAI completion integration" | ||
| area: Inference | ||
| type: enhancement | ||
| issues: [ ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,22 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the Elastic License | ||
| * 2.0; you may not use this file except in compliance with the Elastic License | ||
| * 2.0. | ||
| */ | ||
|
|
||
| package org.elasticsearch.xpack.inference.services.googlevertexai; | ||
|
|
||
| import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiCompletionResponseEntity; | ||
|
|
||
| /** Response handler for the Google VertexAI completion task; it adds no logic of its own and only wires the parsers below into GoogleVertexAiResponseHandler. */ public class GoogleVertexAiChatCompletionResponseHandler extends GoogleVertexAiResponseHandler { | ||
|
|
||
| /** Creates the handler; requestType is forwarded unchanged as the first superclass constructor argument. */ public GoogleVertexAiChatCompletionResponseHandler(String requestType) { | ||
| super( | ||
| requestType, | ||
| /* response parser handed to the superclass */ GoogleVertexAiCompletionResponseEntity::fromResponse, | ||
| /* error parser reused from the unified chat completion handler's nested error-response class */ GoogleVertexAiUnifiedChatCompletionResponseHandler.GoogleVertexAiErrorResponse::fromResponse, | ||
| /* NOTE(review): the meaning of this flag is defined by GoogleVertexAiResponseHandler, which is not visible here - confirm */ true | ||
| ); | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,69 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the Elastic License | ||
| * 2.0; you may not use this file except in compliance with the Elastic License | ||
| * 2.0. | ||
| */ | ||
|
|
||
| package org.elasticsearch.xpack.inference.services.googlevertexai; | ||
|
|
||
| import org.elasticsearch.ElasticsearchStatusException; | ||
| import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; | ||
| import org.elasticsearch.inference.InferenceServiceResults; | ||
| import org.elasticsearch.rest.RestStatus; | ||
| import org.elasticsearch.xcontent.XContentFactory; | ||
| import org.elasticsearch.xcontent.XContentParser; | ||
| import org.elasticsearch.xcontent.XContentParserConfiguration; | ||
| import org.elasticsearch.xcontent.XContentType; | ||
| import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; | ||
| import org.elasticsearch.xpack.inference.common.DelegatingProcessor; | ||
| import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; | ||
|
|
||
| import java.io.IOException; | ||
| import java.util.Deque; | ||
| import java.util.Objects; | ||
| import java.util.stream.Stream; | ||
|
|
||
| import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; | ||
| import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; | ||
|
|
||
| /** Stream processor that turns batches of server-sent events into StreamingChatCompletionResults.Results for the downstream subscriber. */ public class GoogleVertexAiStreamingProcessor extends DelegatingProcessor<Deque<ServerSentEvent>, InferenceServiceResults.Result> { | ||
|
|
||
||
| /** Parses one batch of events; non-empty results are pushed downstream, otherwise one more item is requested from upstream. */ @Override | ||
| protected void next(Deque<ServerSentEvent> item) throws Exception { | ||
| /* parser configuration that uses LoggingDeprecationHandler as its deprecation handler */ var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); | ||
| /* NOTE(review): parseEvent is not defined in this class - presumably inherited from DelegatingProcessor; it receives parse(...) as the per-event parser */ var results = parseEvent(item, GoogleVertexAiStreamingProcessor::parse, parserConfig); | ||
|
|
||
||
| if (results.isEmpty()) { | ||
| /* nothing to emit for this batch, so ask upstream for one more item */ upstream().request(1); | ||
| } else { | ||
| downstream().onNext(new StreamingChatCompletionResults.Results(results)); | ||
| } | ||
| } | ||
|
|
||
||
| /** Parses a single event's data as JSON and returns one Result per choice whose delta has non-empty content. An IOException raised while parsing is rethrown as an ElasticsearchStatusException with INTERNAL_SERVER_ERROR. */ public static Stream<StreamingChatCompletionResults.Result> parse(XContentParserConfiguration parserConfig, ServerSentEvent event) { | ||
| String data = event.data(); | ||
| try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, data)) { | ||
| moveToFirstToken(jsonParser); | ||
|
||
| /* the payload must begin with a JSON object */ ensureExpectedToken(XContentParser.Token.START_OBJECT, jsonParser.currentToken(), jsonParser); | ||
|
|
||
||
| /* chunk parsing is delegated to the chunk parser nested in the unified streaming processor */ var chunk = GoogleVertexAiUnifiedStreamingProcessor.GoogleVertexAiChatCompletionChunkParser.parse(jsonParser); | ||
|
|
||
||
| return chunk.choices() | ||
| .stream() | ||
| .map(choice -> choice.delta()) | ||
| /* skip choices that carry no delta */ .filter(Objects::nonNull) | ||
| .map(delta -> delta.content()) | ||
| /* skip null or empty content */ .filter(content -> content != null && content.isEmpty() == false) | ||
|
||
| .map(StreamingChatCompletionResults.Result::new); | ||
|
|
||
||
| } catch (IOException e) { | ||
| /* wrap the parser I/O failure, keeping the cause and passing the event as the message argument */ throw new ElasticsearchStatusException( | ||
| "Failed to parse event from inference provider: {}", | ||
| RestStatus.INTERNAL_SERVER_ERROR, | ||
| e, | ||
| event | ||
| ); | ||
| } | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,10 +23,10 @@ | |
| import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; | ||
| import org.elasticsearch.xpack.inference.external.http.HttpResult; | ||
| import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; | ||
| import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; | ||
| import org.elasticsearch.xpack.inference.external.request.Request; | ||
| import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser; | ||
| import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor; | ||
| import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiCompletionResponseEntity; | ||
|
|
||
| import java.nio.charset.StandardCharsets; | ||
| import java.util.Locale; | ||
|
|
@@ -43,10 +43,8 @@ public class GoogleVertexAiUnifiedChatCompletionResponseHandler extends GoogleVe | |
| private static final String ERROR_MESSAGE_FIELD = "message"; | ||
| private static final String ERROR_STATUS_FIELD = "status"; | ||
|
|
||
| private static final ResponseParser noopParseFunction = (a, b) -> null; | ||
|
|
||
| public GoogleVertexAiUnifiedChatCompletionResponseHandler(String requestType) { | ||
| super(requestType, noopParseFunction, GoogleVertexAiErrorResponse::fromResponse, true); | ||
| super(requestType, GoogleVertexAiCompletionResponseEntity::fromResponse, GoogleVertexAiErrorResponse::fromResponse, true); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. To support streaming and non-streaming for For example take a look at openai: https://github.com/elastic/elasticsearch/blob/main/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiResponseHandler.java#L133-L140 These are the changes I think we need to make:
Let me know if you'd rather jump on a call to discuss this. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sorry, one correction: the way we report errors for the "unified" aka openai format is different from the "elasticsearch" way of returning errors. So for streaming We might need to do some refactoring but I would see if you could reuse There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think I managed to get this working as you suggested. The only hacky thing is that I have to add a method to the completion model |
||
| } | ||
|
|
||
| @Override | ||
|
|
@@ -64,6 +62,7 @@ public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpR | |
| @Override | ||
| protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { | ||
| assert request.isStreaming() : "Only streaming requests support this format"; | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's try to keep this for streaming only |
||
|
|
||
| var responseStatusCode = result.response().getStatusLine().getStatusCode(); | ||
| var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode); | ||
| var restStatus = toRestStatus(responseStatusCode); | ||
|
|
@@ -111,7 +110,7 @@ private static Exception buildMidStreamError(Request request, String message, Ex | |
| } | ||
| } | ||
|
|
||
| private static class GoogleVertexAiErrorResponse extends ErrorResponse { | ||
| public static class GoogleVertexAiErrorResponse extends ErrorResponse { | ||
| private static final Logger logger = LogManager.getLogger(GoogleVertexAiErrorResponse.class); | ||
| private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>( | ||
| "google_vertex_ai_error_wrapper", | ||
|
|
@@ -138,7 +137,7 @@ private static class GoogleVertexAiErrorResponse extends ErrorResponse { | |
| ); | ||
| } | ||
|
|
||
| static ErrorResponse fromResponse(HttpResult response) { | ||
| public static ErrorResponse fromResponse(HttpResult response) { | ||
| try ( | ||
| XContentParser parser = XContentFactory.xContent(XContentType.JSON) | ||
| .createParser(XContentParserConfiguration.EMPTY, response.body()) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,6 +7,8 @@ | |
|
|
||
| package org.elasticsearch.xpack.inference.services.googlevertexai.action; | ||
|
|
||
| import org.elasticsearch.ElasticsearchStatusException; | ||
| import org.elasticsearch.rest.RestStatus; | ||
| import org.elasticsearch.xpack.inference.external.action.ExecutableAction; | ||
| import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; | ||
| import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; | ||
|
|
@@ -16,14 +18,17 @@ | |
| import org.elasticsearch.xpack.inference.external.http.sender.Sender; | ||
| import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; | ||
| import org.elasticsearch.xpack.inference.services.ServiceComponents; | ||
| import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiChatCompletionResponseHandler; | ||
| import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiEmbeddingsRequestManager; | ||
| import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiRerankRequestManager; | ||
| import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiUnifiedChatCompletionResponseHandler; | ||
| import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel; | ||
| import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiCompletionModel; | ||
| import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel; | ||
| import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUnifiedChatCompletionRequest; | ||
| import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel; | ||
|
|
||
| import java.net.URISyntaxException; | ||
| import java.util.Map; | ||
| import java.util.Objects; | ||
|
|
||
|
|
@@ -36,9 +41,10 @@ public class GoogleVertexAiActionCreator implements GoogleVertexAiActionVisitor | |
|
|
||
| private final ServiceComponents serviceComponents; | ||
|
|
||
| static final ResponseHandler COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler( | ||
| static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler( | ||
|
||
| "Google VertexAI chat completion" | ||
| ); | ||
| static final ResponseHandler CHAT_COMPLETION_HANDLER = new GoogleVertexAiChatCompletionResponseHandler("Google VertexAI completion"); | ||
|
||
| static final String USER_ROLE = "user"; | ||
|
|
||
| public GoogleVertexAiActionCreator(Sender sender, ServiceComponents serviceComponents) { | ||
|
|
@@ -72,11 +78,31 @@ public ExecutableAction create(GoogleVertexAiChatCompletionModel model, Map<Stri | |
| var manager = new GenericRequestManager<>( | ||
| serviceComponents.threadPool(), | ||
| model, | ||
| COMPLETION_HANDLER, | ||
| UNIFIED_CHAT_COMPLETION_HANDLER, | ||
| inputs -> new GoogleVertexAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model), | ||
| ChatCompletionInput.class | ||
| ); | ||
|
|
||
| return new SingleInputSenderExecutableAction(sender, manager, failedToSendRequestErrorMessage, COMPLETION_ERROR_PREFIX); | ||
| } | ||
|
|
||
| /** Builds the executable action for a GoogleVertexAiCompletionModel: a GenericRequestManager using CHAT_COMPLETION_HANDLER, wrapped in a SingleInputSenderExecutableAction. NOTE(review): taskSettings is not used in this body - confirm that is intended. */ @Override | ||
| public ExecutableAction create(GoogleVertexAiCompletionModel model, Map<String, Object> taskSettings) { | ||
| var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX); | ||
|
|
||
||
| /* the lambda is the request creator: it turns the inputs into the outgoing request */ var manager = new GenericRequestManager<>(serviceComponents.threadPool(), model, CHAT_COMPLETION_HANDLER, inputs -> { | ||
| try { | ||
| /* NOTE(review): this calls updateUri on the captured model each time a request is built; presumably it selects the endpoint based on inputs.stream() - verify against GoogleVertexAiCompletionModel */ model.updateUri(inputs.stream()); | ||
| } catch (URISyntaxException e) { | ||
| /* surface the URI failure as an ElasticsearchStatusException with INTERNAL_SERVER_ERROR, keeping the cause */ throw new ElasticsearchStatusException( | ||
| "Error constructing URI for Google VertexAI completion", | ||
| RestStatus.INTERNAL_SERVER_ERROR, | ||
| e | ||
| ); | ||
| } | ||
| /* same request type as the chat-completion create(...) above, with the inputs wrapped using USER_ROLE */ return new GoogleVertexAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model); | ||
|
||
| }, ChatCompletionInput.class); | ||
|
|
||
||
| return new SingleInputSenderExecutableAction(sender, manager, failedToSendRequestErrorMessage, COMPLETION_ERROR_PREFIX); | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure if I should keep this class or delete it and use the base class instead.
Also, I am using
GoogleVertexAiUnifiedChatCompletionResponseHandler.GoogleVertexAiErrorResponse::fromResponse. I preferred to use it that way and avoid putting that in a common class between GoogleVertexAiUnifiedChatCompletionResponseHandler and GoogleVertexAiChatCompletionResponseHandler to avoid extending the class hierarchy, but let me know if you think otherwise. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Up to you, I'd probably use the base class. If you want to keep this one, then I don't think we need to accept the
requestType. I think we can set it in this class directly. Nice! That looks good.