Skip to content

Commit d7acf80

Browse files
leo-hoet, lhoet-google, jonathan-buttner, and elasticsearchmachine
authored and committed
Implemented ChatCompletion task for Google VertexAI with Gemini Models (elastic#128105)
* Implemented ChatCompletion task for Google VertexAI with Gemini Models
* changelog
* System Instruction bugfix
* Mapping role assistant -> model in vertex ai chat completion request for compatibility
* GoogleVertexAI chat completion using SSE events. Removed JsonArrayEventParser
* Removed buffer from GoogleVertexAiUnifiedStreamingProcessor
* Casting inference inputs with `castTo`
* Registered GoogleVertexAiChatCompletionServiceSettings in InferenceNamedWriteablesProvider. Added InferenceSettingsTests
* Changed transport version to 8_19 for vertexai chatcompletion
* Fix to transport version. Moved ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED to the right location
* VertexAI Chat completion request entity jsonStringToMap using `ensureExpectedToken`
* Fixed TransportVersions. Left vertexAi chat completion 8_19 and added new one for ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED
* Replaced switch statements with if-else for older Java compatibility. Improved indentation via `{}`
* Removed GoogleVertexAiChatCompletionResponseEntity and refactored code around it.
* Removed redundant test `testUnifiedCompletionInfer_WithGoogleVertexAiModel`
* Returning whole body when failing to parse response from VertexAI
* Refactor: use GenericRequestManager instead of GoogleVertexAiCompletionRequestManager
* Refactored to constructorArg for mandatory args in GoogleVertexAiUnifiedStreamingProcessor
* Changed transport version in GoogleVertexAiChatCompletionServiceSettings
* Bugfix in tool calling with role tool
* GoogleVertexAiModel added documentation info on rateLimitGroupingHash
* [CI] Auto commit changes from spotless
* Fix: using Locale.ROOT when calling toLowerCase
* Fix: Renamed test class to match convention & modified use of forbidden api
* Fix: Failing test in InferenceServicesIT
---------
Co-authored-by: lhoet <[email protected]>
Co-authored-by: Jonathan Buttner <[email protected]>
Co-authored-by: elasticsearchmachine <[email protected]>
1 parent abec0ce commit d7acf80

25 files changed

+3275
-19
lines changed

docs/changelog/128105.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 128105
2+
summary: "Adding Google VertexAI chat completion integration"
3+
area: Inference
4+
type: enhancement
5+
issues: [ ]

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ static TransportVersion def(int id) {
181181
public static final TransportVersion RERANKER_FAILURES_ALLOWED_8_19 = def(8_841_0_35);
182182
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED_8_19 = def(8_841_0_36);
183183
public static final TransportVersion ML_INFERENCE_SAGEMAKER_CHAT_COMPLETION_8_19 = def(8_841_0_37);
184+
public static final TransportVersion ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED_8_19 = def(8_841_0_38);
184185
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
185186
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
186187
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
@@ -266,6 +267,8 @@ static TransportVersion def(int id) {
266267
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED = def(9_080_0_00);
267268
public static final TransportVersion SETTINGS_IN_DATA_STREAMS_DRY_RUN = def(9_081_0_00);
268269
public static final TransportVersion ML_INFERENCE_SAGEMAKER_CHAT_COMPLETION = def(9_082_0_00);
270+
public static final TransportVersion ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED = def(9_083_0_00);
271+
269272
/*
270273
* STOP! READ THIS FIRST! No, really,
271274
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,14 +151,22 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
151151

152152
public void testGetServicesWithChatCompletionTaskType() throws IOException {
153153
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
154-
assertThat(services.size(), equalTo(6));
154+
assertThat(services.size(), equalTo(7));
155155

156156
var providers = providers(services);
157157

158158
assertThat(
159159
providers,
160160
containsInAnyOrder(
161-
List.of("deepseek", "elastic", "openai", "streaming_completion_test_service", "hugging_face", "amazon_sagemaker").toArray()
161+
List.of(
162+
"deepseek",
163+
"elastic",
164+
"openai",
165+
"streaming_completion_test_service",
166+
"hugging_face",
167+
"amazon_sagemaker",
168+
"googlevertexai"
169+
).toArray()
162170
)
163171
);
164172
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionServiceSettings;
7474
import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsServiceSettings;
7575
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings;
76+
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionServiceSettings;
7677
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings;
7778
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings;
7879
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankServiceSettings;
@@ -453,6 +454,15 @@ private static void addGoogleVertexAiNamedWriteables(List<NamedWriteableRegistry
453454
GoogleVertexAiRerankTaskSettings::new
454455
)
455456
);
457+
458+
namedWriteables.add(
459+
new NamedWriteableRegistry.Entry(
460+
ServiceSettings.class,
461+
GoogleVertexAiChatCompletionServiceSettings.NAME,
462+
GoogleVertexAiChatCompletionServiceSettings::new
463+
)
464+
);
465+
456466
}
457467

458468
private static void addInternalNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiModel.java

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,20 @@
77

88
package org.elasticsearch.xpack.inference.services.googlevertexai;
99

10-
import org.elasticsearch.inference.Model;
1110
import org.elasticsearch.inference.ModelConfigurations;
1211
import org.elasticsearch.inference.ModelSecrets;
1312
import org.elasticsearch.inference.ServiceSettings;
1413
import org.elasticsearch.inference.TaskSettings;
1514
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
15+
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
1616
import org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionVisitor;
17+
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
1718

1819
import java.net.URI;
1920
import java.util.Map;
2021
import java.util.Objects;
2122

22-
public abstract class GoogleVertexAiModel extends Model {
23+
public abstract class GoogleVertexAiModel extends RateLimitGroupingModel {
2324

2425
private final GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings;
2526

@@ -58,4 +59,18 @@ public GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings() {
5859
public URI uri() {
5960
return uri;
6061
}
62+
63+
@Override
64+
public int rateLimitGroupingHash() {
65+
// In VertexAI rate limiting is scoped to the project, region and model. URI already has this information so we are using that.
66+
// API Key does not affect the quota
67+
// https://ai.google.dev/gemini-api/docs/rate-limits
68+
// https://cloud.google.com/vertex-ai/docs/quotas
69+
return Objects.hash(uri);
70+
}
71+
72+
@Override
73+
public RateLimitSettings rateLimitSettings() {
74+
return rateLimitServiceSettings().rateLimitSettings();
75+
}
6176
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiResponseHandler.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,14 @@
99

1010
import org.elasticsearch.xpack.inference.external.http.HttpResult;
1111
import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
12+
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
1213
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
1314
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
1415
import org.elasticsearch.xpack.inference.external.request.Request;
1516
import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiErrorResponseEntity;
1617

18+
import java.util.function.Function;
19+
1720
import static org.elasticsearch.core.Strings.format;
1821

1922
public class GoogleVertexAiResponseHandler extends BaseResponseHandler {
@@ -24,6 +27,15 @@ public GoogleVertexAiResponseHandler(String requestType, ResponseParser parseFun
2427
super(requestType, parseFunction, GoogleVertexAiErrorResponseEntity::fromResponse);
2528
}
2629

30+
public GoogleVertexAiResponseHandler(
31+
String requestType,
32+
ResponseParser parseFunction,
33+
Function<HttpResult, ErrorResponse> errorParseFunction,
34+
boolean canHandleStreamingResponses
35+
) {
36+
super(requestType, parseFunction, errorParseFunction, canHandleStreamingResponses);
37+
}
38+
2739
@Override
2840
protected void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
2941
if (result.isSuccessfulResponse()) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettings.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,8 @@ public static Map<String, SettingsConfiguration> get() {
124124
var configurationMap = new HashMap<String, SettingsConfiguration>();
125125
configurationMap.put(
126126
SERVICE_ACCOUNT_JSON,
127-
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK)).setDescription(
128-
"API Key for the provider you're connecting to."
129-
)
127+
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.CHAT_COMPLETION))
128+
.setDescription("API Key for the provider you're connecting to.")
130129
.setLabel("Credentials JSON")
131130
.setRequired(true)
132131
.setSensitive(true)

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@
2929
import org.elasticsearch.rest.RestStatus;
3030
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
3131
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
32+
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
33+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
3234
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
35+
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
3336
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
3437
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
3538
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
@@ -38,34 +41,42 @@
3841
import org.elasticsearch.xpack.inference.services.ServiceComponents;
3942
import org.elasticsearch.xpack.inference.services.ServiceUtils;
4043
import org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator;
44+
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel;
4145
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
4246
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings;
47+
import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUnifiedChatCompletionRequest;
4348
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel;
4449
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
4550

4651
import java.util.EnumSet;
4752
import java.util.HashMap;
4853
import java.util.List;
4954
import java.util.Map;
55+
import java.util.Set;
5056

57+
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
5158
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
5259
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
5360
import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
5461
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
5562
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
5663
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
5764
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
58-
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
5965
import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.EMBEDDING_MAX_BATCH_SIZE;
6066
import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.LOCATION;
6167
import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.PROJECT_ID;
68+
import static org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator.COMPLETION_ERROR_PREFIX;
6269

6370
public class GoogleVertexAiService extends SenderService {
6471

6572
public static final String NAME = "googlevertexai";
6673

6774
private static final String SERVICE_NAME = "Google Vertex AI";
68-
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK);
75+
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
76+
TaskType.TEXT_EMBEDDING,
77+
TaskType.RERANK,
78+
TaskType.CHAT_COMPLETION
79+
);
6980

7081
public static final EnumSet<InputType> VALID_INPUT_TYPE_VALUES = EnumSet.of(
7182
InputType.INGEST,
@@ -76,6 +87,15 @@ public class GoogleVertexAiService extends SenderService {
7687
InputType.INTERNAL_SEARCH
7788
);
7889

90+
private final ResponseHandler COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler(
91+
"Google VertexAI chat completion"
92+
);
93+
94+
@Override
95+
public Set<TaskType> supportedStreamingTasks() {
96+
return EnumSet.of(TaskType.CHAT_COMPLETION);
97+
}
98+
7999
public GoogleVertexAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
80100
super(factory, serviceComponents);
81101
}
@@ -220,7 +240,24 @@ protected void doUnifiedCompletionInfer(
220240
TimeValue timeout,
221241
ActionListener<InferenceServiceResults> listener
222242
) {
223-
throwUnsupportedUnifiedCompletionOperation(NAME);
243+
if (model instanceof GoogleVertexAiChatCompletionModel == false) {
244+
listener.onFailure(createInvalidModelException(model));
245+
return;
246+
}
247+
var chatCompletionModel = (GoogleVertexAiChatCompletionModel) model;
248+
var updatedChatCompletionModel = GoogleVertexAiChatCompletionModel.of(chatCompletionModel, inputs.getRequest());
249+
250+
var manager = new GenericRequestManager<>(
251+
getServiceComponents().threadPool(),
252+
updatedChatCompletionModel,
253+
COMPLETION_HANDLER,
254+
(unifiedChatInput) -> new GoogleVertexAiUnifiedChatCompletionRequest(unifiedChatInput, updatedChatCompletionModel),
255+
UnifiedChatInput.class
256+
);
257+
258+
var errorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX);
259+
var action = new SenderExecutableAction(getSender(), manager, errorMessage);
260+
action.execute(inputs, timeout, listener);
224261
}
225262

226263
@Override
@@ -320,6 +357,17 @@ private static GoogleVertexAiModel createModel(
320357
secretSettings,
321358
context
322359
);
360+
361+
case CHAT_COMPLETION -> new GoogleVertexAiChatCompletionModel(
362+
inferenceEntityId,
363+
taskType,
364+
NAME,
365+
serviceSettings,
366+
taskSettings,
367+
secretSettings,
368+
context
369+
);
370+
323371
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
324372
};
325373
}
@@ -348,7 +396,7 @@ public static InferenceServiceConfiguration get() {
348396

349397
configurationMap.put(
350398
LOCATION,
351-
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING)).setDescription(
399+
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.CHAT_COMPLETION)).setDescription(
352400
"Please provide the GCP region where the Vertex AI API(s) is enabled. "
353401
+ "For more information, refer to the {geminiVertexAIDocs}."
354402
)

0 commit comments

Comments
 (0)