-
Notifications
You must be signed in to change notification settings - Fork 25.6k
Implemented ChatCompletion task for Google VertexAI with Gemini Models #128105
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
00a6636
9be2a44
c2387e8
50770ea
42cbbe2
fe8e336
7c24f93
2140d05
42dd376
0863316
f080e96
8f6648f
848dc7a
59862c6
93a7ca7
7b99b1d
c05655f
acc864f
c371073
efb90ba
bb68715
1ead8c5
ad9f0e1
f4057f3
38b9ca4
2e8dbee
ddd19c5
88a2780
b841e4e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| pr: 128105 | ||
| summary: "Google VertexAI integration now supports chat_completion task" | ||
| area: Inference | ||
| type: enhancement | ||
| issues: [ ] | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -254,6 +254,7 @@ static TransportVersion def(int id) { | |
| public static final TransportVersion ESQL_FIELD_ATTRIBUTE_DROP_TYPE = def(9_075_0_00); | ||
| public static final TransportVersion ESQL_TIME_SERIES_SOURCE_STATUS = def(9_076_0_00); | ||
| public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME = def(9_077_0_00); | ||
| public static final TransportVersion ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED = def(9_078_0_00); | ||
|
||
|
|
||
| /* | ||
| * STOP! READ THIS FIRST! No, really, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,86 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the Elastic License | ||
| * 2.0; you may not use this file except in compliance with the Elastic License | ||
| * 2.0. | ||
| */ | ||
|
|
||
| package org.elasticsearch.xpack.inference.external.response.streaming; | ||
|
|
||
| import java.io.ByteArrayOutputStream; | ||
| import java.io.IOException; | ||
| import java.io.UncheckedIOException; | ||
| import java.util.ArrayDeque; | ||
| import java.util.Arrays; | ||
| import java.util.Deque; | ||
|
|
||
/**
 * Parses a stream of bytes that form a JSON array, where each element of the array
 * is a JSON object. This parser extracts each complete JSON object from the array
 * and emits it as a byte array.
 *
 * <p>Example of an expected stream:
 * <pre>
 * Chunk 1: [{"key":"val1"}
 * Chunk 2: ,{"key2":"val2"}
 * Chunk 3: ,{"key3":"val3"}, {"some":"object"}]
 * </pre>
 *
 * This parser would emit four byte arrays, with data:
 * <pre>
 * 1. {"key":"val1"}
 * 2. {"key2":"val2"}
 * 3. {"key3":"val3"}
 * 4. {"some":"object"}
 * </pre>
 *
 * <p>Object boundaries are found by counting curly braces. Braces that appear inside JSON
 * string literals (including strings that contain escaped quotes) are treated as ordinary
 * text and do not affect the count.
 *
 * <p>An instance keeps the bytes of a trailing, incomplete object between calls to
 * {@link #parse(byte[])} and performs no synchronization of its own.
 */
public class JsonArrayPartsEventParser {

    // Buffer to hold bytes from the previous call if they formed an incomplete JSON object.
    // When non-empty, its first byte is always the opening '{' of that object.
    private final ByteArrayOutputStream incompletePart = new ByteArrayOutputStream();

    /**
     * Consumes the next chunk of the stream and returns every JSON object completed by it.
     *
     * @param newBytes the next chunk of the stream; {@code null} or an empty array is ignored
     * @return the completed objects in stream order, each as its own byte array; empty if the
     *         chunk did not complete any object
     * @throws UncheckedIOException if combining the buffered bytes with the new chunk fails
     */
    public Deque<byte[]> parse(byte[] newBytes) {
        if (newBytes == null || newBytes.length == 0) {
            return new ArrayDeque<>(0);
        }

        // Prepend whatever was left over from the previous chunk so that scanning restarts
        // at the opening brace of the unfinished object.
        ByteArrayOutputStream currentStream = new ByteArrayOutputStream();
        try {
            currentStream.write(incompletePart.toByteArray());
            currentStream.write(newBytes);
        } catch (IOException e) {
            throw new UncheckedIOException("Error handling byte array streams", e);
        }
        incompletePart.reset();

        byte[] dataToProcess = currentStream.toByteArray();
        return parseInternal(dataToProcess);
    }

    /**
     * Scans {@code data} for top-level JSON objects, collecting each complete one and
     * buffering the trailing incomplete one (if any) for the next call.
     */
    private Deque<byte[]> parseInternal(byte[] data) {
        int braceLevel = 0;
        int objectStartIndex = -1;
        // String state is only entered while inside an object, so a pending string always
        // implies braceLevel > 0 and the unfinished object is re-scanned from its '{'.
        boolean inString = false;
        boolean escaped = false;
        Deque<byte[]> completedObjects = new ArrayDeque<>();

        for (int i = 0; i < data.length; i++) {
            // Compare raw bytes: the structural characters are single-byte ASCII and can never
            // match a byte of a multi-byte UTF-8 sequence (those are all negative as a byte).
            byte b = data[i];

            if (inString) {
                if (escaped) {
                    // The byte after a backslash is part of the escape, whatever it is.
                    escaped = false;
                } else if (b == '\\') {
                    escaped = true;
                } else if (b == '"') {
                    inString = false;
                }
                continue;
            }

            if (b == '{') {
                if (braceLevel == 0) {
                    objectStartIndex = i;
                }
                braceLevel++;
            } else if (b == '}') {
                // A stray closing brace outside of any object is ignored.
                if (braceLevel > 0) {
                    braceLevel--;
                    if (braceLevel == 0) {
                        completedObjects.offer(Arrays.copyOfRange(data, objectStartIndex, i + 1));
                        objectStartIndex = -1;
                    }
                }
            } else if (b == '"' && braceLevel > 0) {
                inString = true;
            }
        }

        if (braceLevel > 0) {
            // Keep the unfinished object, starting at its opening brace, for the next chunk.
            incompletePart.write(data, objectStartIndex, data.length - objectStartIndex);
        }
        return completedObjects;
    }
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,37 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the Elastic License | ||
| * 2.0; you may not use this file except in compliance with the Elastic License | ||
| * 2.0. | ||
| */ | ||
|
|
||
| package org.elasticsearch.xpack.inference.external.response.streaming; | ||
|
|
||
| import org.elasticsearch.xpack.inference.common.DelegatingProcessor; | ||
| import org.elasticsearch.xpack.inference.external.http.HttpResult; | ||
|
|
||
| import java.util.Deque; | ||
|
|
||
| public class JsonArrayPartsEventProcessor extends DelegatingProcessor<HttpResult, Deque<byte[]>> { | ||
| private final JsonArrayPartsEventParser jsonArrayPartsEventParser; | ||
|
|
||
| public JsonArrayPartsEventProcessor(JsonArrayPartsEventParser jsonArrayPartsEventParser) { | ||
| this.jsonArrayPartsEventParser = jsonArrayPartsEventParser; | ||
| } | ||
|
|
||
| @Override | ||
| public void next(HttpResult item) { | ||
| if (item.isBodyEmpty()) { | ||
| upstream().request(1); | ||
| return; | ||
| } | ||
|
|
||
| var response = jsonArrayPartsEventParser.parse(item.body()); | ||
| if (response.isEmpty()) { | ||
| upstream().request(1); | ||
| return; | ||
| } | ||
|
|
||
| downstream().onNext(response); | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,73 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the Elastic License | ||
| * 2.0; you may not use this file except in compliance with the Elastic License | ||
| * 2.0. | ||
| */ | ||
|
|
||
| package org.elasticsearch.xpack.inference.services.googlevertexai; | ||
|
|
||
| import org.apache.logging.log4j.LogManager; | ||
| import org.apache.logging.log4j.Logger; | ||
| import org.elasticsearch.action.ActionListener; | ||
| import org.elasticsearch.inference.InferenceServiceResults; | ||
| import org.elasticsearch.threadpool.ThreadPool; | ||
| import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; | ||
| import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; | ||
| import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest; | ||
| import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; | ||
| import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; | ||
| import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel; | ||
| import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUnifiedChatCompletionRequest; | ||
| import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiChatCompletionResponseEntity; | ||
|
|
||
| import java.util.Objects; | ||
| import java.util.function.Supplier; | ||
|
|
||
| public class GoogleVertexAiCompletionRequestManager extends GoogleVertexAiRequestManager { | ||
|
||
|
|
||
| private static final Logger logger = LogManager.getLogger(GoogleVertexAiCompletionRequestManager.class); | ||
|
|
||
| private static final ResponseHandler HANDLER = createGoogleVertexAiResponseHandler(); | ||
|
|
||
| private static ResponseHandler createGoogleVertexAiResponseHandler() { | ||
| return new GoogleVertexAiUnifiedChatCompletionResponseHandler( | ||
| "Google Vertex AI chat completion", | ||
| GoogleVertexAiChatCompletionResponseEntity::fromResponse | ||
|
||
| ); | ||
| } | ||
|
|
||
| private final GoogleVertexAiChatCompletionModel model; | ||
|
|
||
| public GoogleVertexAiCompletionRequestManager(GoogleVertexAiChatCompletionModel model, ThreadPool threadPool) { | ||
| super(threadPool, model, RateLimitGrouping.of(model)); | ||
| this.model = model; | ||
| } | ||
|
|
||
| record RateLimitGrouping(int projectIdHash) { | ||
| public static RateLimitGrouping of(GoogleVertexAiChatCompletionModel model) { | ||
| Objects.requireNonNull(model); | ||
| return new RateLimitGrouping(model.rateLimitServiceSettings().projectId().hashCode()); | ||
| } | ||
| } | ||
|
|
||
| public static GoogleVertexAiCompletionRequestManager of(GoogleVertexAiChatCompletionModel model, ThreadPool threadPool) { | ||
| Objects.requireNonNull(model); | ||
| Objects.requireNonNull(threadPool); | ||
|
|
||
| return new GoogleVertexAiCompletionRequestManager(model, threadPool); | ||
| } | ||
|
|
||
| @Override | ||
| public void execute( | ||
| InferenceInputs inferenceInputs, | ||
| RequestSender requestSender, | ||
| Supplier<Boolean> hasRequestCompletedFunction, | ||
| ActionListener<InferenceServiceResults> listener | ||
| ) { | ||
|
|
||
| var chatInputs = (UnifiedChatInput) inferenceInputs; | ||
| var request = new GoogleVertexAiUnifiedChatCompletionRequest(chatInputs, model); | ||
| execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.