Skip to content

Commit 00a6636

Browse files
committed
Implemented ChatCompletion task for Google VertexAI with Gemini Models
1 parent e586a01 commit 00a6636

File tree

27 files changed

+3778
-13
lines changed

27 files changed

+3778
-13
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ static TransportVersion def(int id) {
254254
public static final TransportVersion ESQL_FIELD_ATTRIBUTE_DROP_TYPE = def(9_075_0_00);
255255
public static final TransportVersion ESQL_TIME_SERIES_SOURCE_STATUS = def(9_076_0_00);
256256
public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME = def(9_077_0_00);
257+
public static final TransportVersion ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED = def(9_078_0_00);
257258

258259
/*
259260
* STOP! READ THIS FIRST! No, really,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.response.streaming;
9+
10+
import java.io.ByteArrayOutputStream;
11+
import java.io.IOException;
12+
import java.io.UncheckedIOException;
13+
import java.util.ArrayDeque;
14+
import java.util.Arrays;
15+
import java.util.Deque;
16+
17+
/**
18+
* Parses a stream of bytes that form a JSON array, where each element of the array
19+
* is a JSON object. This parser extracts each complete JSON object from the array
20+
* and emits it as byte array.
21+
*
22+
* Example of an expected stream:
23+
* Chunk 1: [{"key":"val1"}
24+
* Chunk 2: ,{"key2":"val2"}
25+
* Chunk 3: ,{"key3":"val3"}, {"some":"object"}]
26+
*
27+
* This parser would emit four byte arrays, with data:
28+
* 1. {"key":"val1"}
29+
* 2. {"key2":"val2"}
30+
* 3. {"key3":"val3"}
31+
* 4. {"some":"object"}
32+
*/
33+
public class JsonArrayPartsEventParser {
34+
35+
// Buffer to hold bytes from the previous call if they formed an incomplete JSON object.
36+
private final ByteArrayOutputStream incompletePart = new ByteArrayOutputStream();
37+
38+
public Deque<byte[]> parse(byte[] newBytes) {
39+
if (newBytes == null || newBytes.length == 0) {
40+
return new ArrayDeque<>(0);
41+
}
42+
43+
ByteArrayOutputStream currentStream = new ByteArrayOutputStream();
44+
try {
45+
currentStream.write(incompletePart.toByteArray());
46+
currentStream.write(newBytes);
47+
} catch (IOException e) {
48+
throw new UncheckedIOException("Error handling byte array streams", e);
49+
}
50+
incompletePart.reset();
51+
52+
byte[] dataToProcess = currentStream.toByteArray();
53+
return parseInternal(dataToProcess);
54+
}
55+
56+
private Deque<byte[]> parseInternal(byte[] data) {
57+
int localBraceLevel = 0;
58+
int objectStartIndex = -1;
59+
Deque<byte[]> completedObjects = new ArrayDeque<>();
60+
61+
for (int i = 0; i < data.length; i++) {
62+
char c = (char) data[i];
63+
64+
if (c == '{') {
65+
if (localBraceLevel == 0) {
66+
objectStartIndex = i;
67+
}
68+
localBraceLevel++;
69+
} else if (c == '}') {
70+
if (localBraceLevel > 0) {
71+
localBraceLevel--;
72+
if (localBraceLevel == 0) {
73+
byte[] jsonObject = Arrays.copyOfRange(data, objectStartIndex, i + 1);
74+
completedObjects.offer(jsonObject);
75+
objectStartIndex = -1;
76+
}
77+
}
78+
}
79+
}
80+
81+
if (localBraceLevel > 0) {
82+
incompletePart.write(data, objectStartIndex, data.length - objectStartIndex);
83+
}
84+
return completedObjects;
85+
}
86+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.response.streaming;
9+
10+
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
11+
import org.elasticsearch.xpack.inference.external.http.HttpResult;
12+
13+
import java.util.Deque;
14+
15+
public class JsonArrayPartsEventProcessor extends DelegatingProcessor<HttpResult, Deque<byte[]>> {
16+
private final JsonArrayPartsEventParser jsonArrayPartsEventParser;
17+
18+
public JsonArrayPartsEventProcessor(JsonArrayPartsEventParser jsonArrayPartsEventParser) {
19+
this.jsonArrayPartsEventParser = jsonArrayPartsEventParser;
20+
}
21+
22+
@Override
23+
public void next(HttpResult item) {
24+
if (item.isBodyEmpty()) {
25+
upstream().request(1);
26+
return;
27+
}
28+
29+
var response = jsonArrayPartsEventParser.parse(item.body());
30+
if (response.isEmpty()) {
31+
upstream().request(1);
32+
return;
33+
}
34+
35+
downstream().onNext(response);
36+
}
37+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.googlevertexai;
9+
10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.action.ActionListener;
13+
import org.elasticsearch.inference.InferenceServiceResults;
14+
import org.elasticsearch.threadpool.ThreadPool;
15+
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
16+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
17+
import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest;
18+
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
19+
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
20+
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel;
21+
import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUnifiedChatCompletionRequest;
22+
import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiChatCompletionResponseEntity;
23+
24+
import java.util.Objects;
25+
import java.util.function.Supplier;
26+
27+
public class GoogleVertexAiCompletionRequestManager extends GoogleVertexAiRequestManager {
28+
29+
private static final Logger logger = LogManager.getLogger(GoogleVertexAiCompletionRequestManager.class);
30+
31+
private static final ResponseHandler HANDLER = createGoogleVertexAiResponseHandler();
32+
33+
private static ResponseHandler createGoogleVertexAiResponseHandler() {
34+
return new GoogleVertexAiUnifiedChatCompletionResponseHandler(
35+
"Google Vertex AI chat completion",
36+
GoogleVertexAiChatCompletionResponseEntity::fromResponse
37+
);
38+
}
39+
40+
private final GoogleVertexAiChatCompletionModel model;
41+
42+
public GoogleVertexAiCompletionRequestManager(GoogleVertexAiChatCompletionModel model, ThreadPool threadPool) {
43+
super(threadPool, model, RateLimitGrouping.of(model));
44+
this.model = model;
45+
}
46+
47+
record RateLimitGrouping(int projectIdHash) {
48+
public static RateLimitGrouping of(GoogleVertexAiChatCompletionModel model) {
49+
Objects.requireNonNull(model);
50+
return new RateLimitGrouping(model.rateLimitServiceSettings().projectId().hashCode());
51+
}
52+
}
53+
54+
public static GoogleVertexAiCompletionRequestManager of(GoogleVertexAiChatCompletionModel model, ThreadPool threadPool) {
55+
Objects.requireNonNull(model);
56+
Objects.requireNonNull(threadPool);
57+
58+
return new GoogleVertexAiCompletionRequestManager(model, threadPool);
59+
}
60+
61+
@Override
62+
public void execute(
63+
InferenceInputs inferenceInputs,
64+
RequestSender requestSender,
65+
Supplier<Boolean> hasRequestCompletedFunction,
66+
ActionListener<InferenceServiceResults> listener
67+
) {
68+
69+
var chatInputs = (UnifiedChatInput) inferenceInputs;
70+
var request = new GoogleVertexAiUnifiedChatCompletionRequest(chatInputs, model);
71+
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
72+
}
73+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiResponseHandler.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,14 @@
99

1010
import org.elasticsearch.xpack.inference.external.http.HttpResult;
1111
import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
12+
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
1213
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
1314
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
1415
import org.elasticsearch.xpack.inference.external.request.Request;
1516
import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiErrorResponseEntity;
1617

18+
import java.util.function.Function;
19+
1720
import static org.elasticsearch.core.Strings.format;
1821

1922
public class GoogleVertexAiResponseHandler extends BaseResponseHandler {
@@ -24,6 +27,15 @@ public GoogleVertexAiResponseHandler(String requestType, ResponseParser parseFun
2427
super(requestType, parseFunction, GoogleVertexAiErrorResponseEntity::fromResponse);
2528
}
2629

30+
public GoogleVertexAiResponseHandler(
31+
String requestType,
32+
ResponseParser parseFunction,
33+
Function<HttpResult, ErrorResponse> errorParseFunction,
34+
boolean canHandleStreamingResponses
35+
) {
36+
super(requestType, parseFunction, errorParseFunction, canHandleStreamingResponses);
37+
}
38+
2739
@Override
2840
protected void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
2941
if (result.isSuccessfulResponse()) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettings.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,8 @@ public static Map<String, SettingsConfiguration> get() {
124124
var configurationMap = new HashMap<String, SettingsConfiguration>();
125125
configurationMap.put(
126126
SERVICE_ACCOUNT_JSON,
127-
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK)).setDescription(
128-
"API Key for the provider you're connecting to."
129-
)
127+
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.CHAT_COMPLETION))
128+
.setDescription("API Key for the provider you're connecting to.")
130129
.setLabel("Credentials JSON")
131130
.setRequired(true)
132131
.setSensitive(true)

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import org.elasticsearch.rest.RestStatus;
3030
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
3131
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
32+
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
3233
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
3334
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
3435
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
@@ -38,6 +39,7 @@
3839
import org.elasticsearch.xpack.inference.services.ServiceComponents;
3940
import org.elasticsearch.xpack.inference.services.ServiceUtils;
4041
import org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator;
42+
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel;
4143
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
4244
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings;
4345
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel;
@@ -47,25 +49,31 @@
4749
import java.util.HashMap;
4850
import java.util.List;
4951
import java.util.Map;
52+
import java.util.Set;
5053

54+
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
5155
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
5256
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
5357
import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
5458
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
5559
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
5660
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
5761
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
58-
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
5962
import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.EMBEDDING_MAX_BATCH_SIZE;
6063
import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.LOCATION;
6164
import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.PROJECT_ID;
65+
import static org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator.COMPLETION_ERROR_PREFIX;
6266

6367
public class GoogleVertexAiService extends SenderService {
6468

6569
public static final String NAME = "googlevertexai";
6670

6771
private static final String SERVICE_NAME = "Google Vertex AI";
68-
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK);
72+
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
73+
TaskType.TEXT_EMBEDDING,
74+
TaskType.RERANK,
75+
TaskType.CHAT_COMPLETION
76+
);
6977

7078
public static final EnumSet<InputType> VALID_INPUT_TYPE_VALUES = EnumSet.of(
7179
InputType.INGEST,
@@ -76,6 +84,11 @@ public class GoogleVertexAiService extends SenderService {
7684
InputType.INTERNAL_SEARCH
7785
);
7886

87+
@Override
88+
public Set<TaskType> supportedStreamingTasks() {
89+
return EnumSet.of(TaskType.CHAT_COMPLETION);
90+
}
91+
7992
public GoogleVertexAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
8093
super(factory, serviceComponents);
8194
}
@@ -220,7 +233,18 @@ protected void doUnifiedCompletionInfer(
220233
TimeValue timeout,
221234
ActionListener<InferenceServiceResults> listener
222235
) {
223-
throwUnsupportedUnifiedCompletionOperation(NAME);
236+
if (model instanceof GoogleVertexAiChatCompletionModel == false) {
237+
listener.onFailure(createInvalidModelException(model));
238+
return;
239+
}
240+
var chatCompletionModel = (GoogleVertexAiChatCompletionModel) model;
241+
var updatedChatCompletionModel = GoogleVertexAiChatCompletionModel.of(chatCompletionModel, inputs.getRequest());
242+
243+
var manager = GoogleVertexAiCompletionRequestManager.of(updatedChatCompletionModel, getServiceComponents().threadPool());
244+
245+
var errorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX);
246+
var action = new SenderExecutableAction(getSender(), manager, errorMessage);
247+
action.execute(inputs, timeout, listener);
224248
}
225249

226250
@Override
@@ -320,6 +344,17 @@ private static GoogleVertexAiModel createModel(
320344
secretSettings,
321345
context
322346
);
347+
348+
case CHAT_COMPLETION -> new GoogleVertexAiChatCompletionModel(
349+
inferenceEntityId,
350+
taskType,
351+
NAME,
352+
serviceSettings,
353+
taskSettings,
354+
secretSettings,
355+
context
356+
);
357+
323358
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
324359
};
325360
}
@@ -348,7 +383,7 @@ public static InferenceServiceConfiguration get() {
348383

349384
configurationMap.put(
350385
LOCATION,
351-
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING)).setDescription(
386+
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.CHAT_COMPLETION)).setDescription(
352387
"Please provide the GCP region where the Vertex AI API(s) is enabled. "
353388
+ "For more information, refer to the {geminiVertexAIDocs}."
354389
)

0 commit comments

Comments
 (0)