-
Notifications
You must be signed in to change notification settings - Fork 25.6k
Implemented completion task for Google VertexAI #128694
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
a9df8e3
80af6d3
ca1b6d5
b6f5e34
ce6d45f
6cf0c0b
7eabd29
55d8650
bf27166
ab1fe7a
20f1914
30f53cd
5d20b35
2b865af
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,22 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the Elastic License | ||
| * 2.0; you may not use this file except in compliance with the Elastic License | ||
| * 2.0. | ||
| */ | ||
|
|
||
| package org.elasticsearch.xpack.inference.services.googlevertexai; | ||
|
|
||
| import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiCompletionResponseEntity; | ||
|
|
||
| public class GoogleVertexAiChatCompletionResponseHandler extends GoogleVertexAiResponseHandler { | ||
|
|
||
| public GoogleVertexAiChatCompletionResponseHandler(String requestType) { | ||
| super( | ||
| requestType, | ||
| GoogleVertexAiCompletionResponseEntity::fromResponse, | ||
| GoogleVertexAiUnifiedChatCompletionResponseHandler.GoogleVertexAiErrorResponse::fromResponse, | ||
| true | ||
| ); | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,69 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the Elastic License | ||
| * 2.0; you may not use this file except in compliance with the Elastic License | ||
| * 2.0. | ||
| */ | ||
|
|
||
| package org.elasticsearch.xpack.inference.services.googlevertexai; | ||
|
|
||
| import org.elasticsearch.ElasticsearchStatusException; | ||
| import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; | ||
| import org.elasticsearch.inference.InferenceServiceResults; | ||
| import org.elasticsearch.rest.RestStatus; | ||
| import org.elasticsearch.xcontent.XContentFactory; | ||
| import org.elasticsearch.xcontent.XContentParser; | ||
| import org.elasticsearch.xcontent.XContentParserConfiguration; | ||
| import org.elasticsearch.xcontent.XContentType; | ||
| import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; | ||
| import org.elasticsearch.xpack.inference.common.DelegatingProcessor; | ||
| import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; | ||
|
|
||
| import java.io.IOException; | ||
| import java.util.Deque; | ||
| import java.util.Objects; | ||
| import java.util.stream.Stream; | ||
|
|
||
| import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; | ||
| import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; | ||
|
|
||
| /** | ||
| * Stream processor that converts batches of server-sent events from Google VertexAI into | ||
| * {@link StreamingChatCompletionResults.Results}. | ||
| * <p> | ||
| * Each event's data is parsed as a JSON chat-completion chunk, and only the non-empty text | ||
| * content of each choice's delta is forwarded downstream. | ||
| */ | ||
| public class GoogleVertexAiStreamingProcessor extends DelegatingProcessor<Deque<ServerSentEvent>, InferenceServiceResults.Result> { | ||
|
|
||
|     /** | ||
|      * Parses the queued events and either publishes the extracted results downstream or, when | ||
|      * nothing was extracted, asks upstream for one more item. | ||
|      * | ||
|      * @param item the server-sent events to process | ||
|      */ | ||
|     @Override | ||
|     protected void next(Deque<ServerSentEvent> item) throws Exception { | ||
|         var config = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); | ||
|         var parsed = parseEvent(item, GoogleVertexAiStreamingProcessor::parse, config); | ||
|
|
||
|         if (parsed.isEmpty() == false) { | ||
|             downstream().onNext(new StreamingChatCompletionResults.Results(parsed)); | ||
|             return; | ||
|         } | ||
|         // Nothing usable in this batch: request one more item from upstream. | ||
|         upstream().request(1); | ||
|     } | ||
|
|
||
|     /** | ||
|      * Parses a single server-sent event into a stream of completion results. | ||
|      * | ||
|      * @param parserConfig configuration used to create the JSON parser | ||
|      * @param event        the event whose data holds one JSON object | ||
|      * @return one result per choice that has a non-null delta with non-null, non-empty content | ||
|      * @throws ElasticsearchStatusException with {@code INTERNAL_SERVER_ERROR} if reading the JSON fails with an {@link IOException} | ||
|      */ | ||
|     public static Stream<StreamingChatCompletionResults.Result> parse(XContentParserConfiguration parserConfig, ServerSentEvent event) { | ||
|         String payload = event.data(); | ||
|         try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, payload)) { | ||
|             moveToFirstToken(parser); | ||
|             // The payload must be a JSON object before the chunk parser takes over. | ||
|             ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); | ||
|
|
||
|             var parsedChunk = GoogleVertexAiUnifiedStreamingProcessor.GoogleVertexAiChatCompletionChunkParser.parse(parser); | ||
|             var deltas = parsedChunk.choices().stream().map(c -> c.delta()).filter(Objects::nonNull); | ||
|
|
||
|             return deltas.map(d -> d.content()) | ||
|                 .filter(Objects::nonNull) | ||
|                 .filter(text -> text.isEmpty() == false) | ||
|                 .map(StreamingChatCompletionResults.Result::new); | ||
|         } catch (IOException e) { | ||
|             throw new ElasticsearchStatusException( | ||
|                 "Failed to parse event from inference provider: {}", | ||
|                 RestStatus.INTERNAL_SERVER_ERROR, | ||
|                 e, | ||
|                 event | ||
|             ); | ||
|         } | ||
|     } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,6 +7,8 @@ | |
|
|
||
| package org.elasticsearch.xpack.inference.services.googlevertexai.action; | ||
|
|
||
| import org.elasticsearch.ElasticsearchStatusException; | ||
| import org.elasticsearch.rest.RestStatus; | ||
| import org.elasticsearch.xpack.inference.external.action.ExecutableAction; | ||
| import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; | ||
| import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; | ||
|
|
@@ -16,14 +18,17 @@ | |
| import org.elasticsearch.xpack.inference.external.http.sender.Sender; | ||
| import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; | ||
| import org.elasticsearch.xpack.inference.services.ServiceComponents; | ||
| import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiChatCompletionResponseHandler; | ||
| import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiEmbeddingsRequestManager; | ||
| import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiRerankRequestManager; | ||
| import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiUnifiedChatCompletionResponseHandler; | ||
| import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel; | ||
| import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiCompletionModel; | ||
| import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel; | ||
| import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUnifiedChatCompletionRequest; | ||
| import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel; | ||
|
|
||
| import java.net.URISyntaxException; | ||
| import java.util.Map; | ||
| import java.util.Objects; | ||
|
|
||
|
|
@@ -36,9 +41,10 @@ public class GoogleVertexAiActionCreator implements GoogleVertexAiActionVisitor | |
|
|
||
| private final ServiceComponents serviceComponents; | ||
|
|
||
| static final ResponseHandler COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler( | ||
| static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler( | ||
|
||
| "Google VertexAI chat completion" | ||
| ); | ||
| static final ResponseHandler CHAT_COMPLETION_HANDLER = new GoogleVertexAiChatCompletionResponseHandler("Google VertexAI completion"); | ||
|
||
| static final String USER_ROLE = "user"; | ||
|
|
||
| public GoogleVertexAiActionCreator(Sender sender, ServiceComponents serviceComponents) { | ||
|
|
@@ -72,11 +78,31 @@ public ExecutableAction create(GoogleVertexAiChatCompletionModel model, Map<Stri | |
| var manager = new GenericRequestManager<>( | ||
| serviceComponents.threadPool(), | ||
| model, | ||
| COMPLETION_HANDLER, | ||
| UNIFIED_CHAT_COMPLETION_HANDLER, | ||
| inputs -> new GoogleVertexAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model), | ||
| ChatCompletionInput.class | ||
| ); | ||
|
|
||
| return new SingleInputSenderExecutableAction(sender, manager, failedToSendRequestErrorMessage, COMPLETION_ERROR_PREFIX); | ||
| } | ||
|
|
||
| /** | ||
|  * Builds the executable action for a Google VertexAI completion model. | ||
|  * <p> | ||
|  * Each time a request is created, the model's URI is first refreshed through | ||
|  * {@code model.updateUri(...)} using the inputs' {@code stream()} value. A | ||
|  * {@link URISyntaxException} from that step is rethrown as an | ||
|  * {@link ElasticsearchStatusException} with {@code INTERNAL_SERVER_ERROR}. | ||
|  * | ||
|  * @param model        the completion model the request is sent for | ||
|  * @param taskSettings task settings; not read by this overload | ||
|  * @return an action that sends a single chat-completion input through the sender | ||
|  */ | ||
| @Override | ||
| public ExecutableAction create(GoogleVertexAiCompletionModel model, Map<String, Object> taskSettings) { | ||
|     var sendFailureMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX); | ||
|
|
||
|     var requestManager = new GenericRequestManager<>( | ||
|         serviceComponents.threadPool(), | ||
|         model, | ||
|         CHAT_COMPLETION_HANDLER, | ||
|         chatInput -> { | ||
|             // NOTE(review): this mutates the model's URI on every request - confirm it is safe | ||
|             // if the same model instance can serve requests concurrently. | ||
|             try { | ||
|                 model.updateUri(chatInput.stream()); | ||
|             } catch (URISyntaxException uriError) { | ||
|                 throw new ElasticsearchStatusException( | ||
|                     "Error constructing URI for Google VertexAI completion", | ||
|                     RestStatus.INTERNAL_SERVER_ERROR, | ||
|                     uriError | ||
|                 ); | ||
|             } | ||
|             var unifiedInput = new UnifiedChatInput(chatInput, USER_ROLE); | ||
|             return new GoogleVertexAiUnifiedChatCompletionRequest(unifiedInput, model); | ||
|         }, | ||
|         ChatCompletionInput.class | ||
|     ); | ||
|
|
||
|     return new SingleInputSenderExecutableAction(sender, requestManager, sendFailureMessage, COMPLETION_ERROR_PREFIX); | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,7 +9,9 @@ | |
|
|
||
| import org.apache.http.client.utils.URIBuilder; | ||
| import org.elasticsearch.inference.TaskType; | ||
| import org.elasticsearch.xpack.inference.external.action.ExecutableAction; | ||
| import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; | ||
| import org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionVisitor; | ||
| import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUtils; | ||
|
|
||
| import java.net.URI; | ||
|
|
@@ -39,6 +41,25 @@ public GoogleVertexAiCompletionModel( | |
|
|
||
| } | ||
|
|
||
| /** | ||
|  * Points this model at the Google VertexAI endpoint that matches the requested response mode. | ||
|  * | ||
|  * @param isStream {@code true} to use the URI built by {@code GoogleVertexAiChatCompletionModel.buildUri}, | ||
|  *                 {@code false} to use the URI built by {@code GoogleVertexAiCompletionModel.buildUri} | ||
|  * @throws URISyntaxException if the selected URI cannot be built | ||
|  */ | ||
| public void updateUri(boolean isStream) throws URISyntaxException { | ||
|     var region = getServiceSettings().location(); | ||
|     var project = getServiceSettings().projectId(); | ||
|     var modelId = getServiceSettings().modelId(); | ||
|
|
||
|     // Google VertexAI generates streaming responses using another API, so this method is | ||
|     // called before making the request to be sure the right API is targeted. | ||
|     // NOTE(review): the uri field is overwritten in place - confirm this is safe when one | ||
|     // model instance handles streaming and non-streaming requests at the same time. | ||
|     this.uri = isStream | ||
|         ? GoogleVertexAiChatCompletionModel.buildUri(region, project, modelId) | ||
|         : GoogleVertexAiCompletionModel.buildUri(region, project, modelId); | ||
| } | ||
|
|
||
| /** Visitor hook: passes this model and the task settings to {@code visitor.create} and returns the resulting action. */ @Override | ||
| public ExecutableAction accept(GoogleVertexAiActionVisitor visitor, Map<String, Object> taskSettings) { | ||
| return visitor.create(this, taskSettings); | ||
| } | ||
|
|
||
| public static URI buildUri(String location, String projectId, String model) throws URISyntaxException { | ||
| return new URIBuilder().setScheme("https") | ||
| .setHost(format("%s%s", location, GoogleVertexAiUtils.GOOGLE_VERTEX_AI_HOST_SUFFIX)) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure whether I should keep this class or delete it and use the base class instead.
Also, I am using
`GoogleVertexAiUnifiedChatCompletionResponseHandler.GoogleVertexAiErrorResponse::fromResponse`. I preferred to use it that way rather than moving it into a common class shared between `GoogleVertexAiUnifiedChatCompletionResponseHandler` and `GoogleVertexAiChatCompletionResponseHandler`, to avoid extending the class hierarchy, but let me know if you think otherwise.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Up to you, I'd probably use the base class. If you want to keep this one, then I don't think we need to accept the
`requestType`. I think we can set it in this class directly.
Nice! That looks good.