Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
00a6636
Implemented ChatCompletion task for Google VertexAI with Gemini Models
lhoet-google Apr 29, 2025
9be2a44
changelog
lhoet-google May 16, 2025
c2387e8
System Instruction bugfix
lhoet-google May 19, 2025
50770ea
Mapping role assistant -> model in vertex ai chat completion request …
lhoet-google May 19, 2025
42cbbe2
GoogleVertexAI chat completion using SSE events. Removed JsonArrayEve…
lhoet-google May 20, 2025
fe8e336
Removed buffer from GoogleVertexAiUnifiedStreamingProcessor
lhoet-google May 20, 2025
7c24f93
Casting inference inputs with `castoTo`
lhoet-google May 21, 2025
2140d05
Registered GoogleVertexAiChatCompletionServiceSettings in InferenceNa…
lhoet-google May 21, 2025
42dd376
Changed transport version to 8_19 for vertexai chatcompletion
lhoet-google May 21, 2025
0863316
Fix to transport version. Moved ML_INFERENCE_VERTEXAI_CHATCOMPLETION_…
lhoet-google May 21, 2025
f080e96
VertexAI Chat completion request entity jsonStringToMap using `ensure…
lhoet-google May 21, 2025
8f6648f
Fixed TransportVersions. Left vertexAi chat completion 8_19 and added…
lhoet-google May 22, 2025
848dc7a
Refactor switch statements by if-else for older java compatibility. I…
lhoet-google May 22, 2025
59862c6
Removed GoogleVertexAiChatCompletionResponseEntity and refactored cod…
lhoet-google May 22, 2025
93a7ca7
Removed redundant test `testUnifiedCompletionInfer_WithGoogleVertexAi…
lhoet-google May 22, 2025
7b99b1d
Returning whole body when fail to parse response from VertexAI
lhoet-google May 22, 2025
c05655f
Refactor use GenericRequestManager instead of GoogleVertexAiCompletio…
lhoet-google May 23, 2025
acc864f
Refactored to constructorArg for mandatory args in GoogleVertexAiUnif…
lhoet-google May 26, 2025
c371073
Changed transport version in GoogleVertexAiChatCompletionServiceSettings
lhoet-google May 26, 2025
efb90ba
Bugfix in tool calling with role tool
lhoet-google May 26, 2025
bb68715
Merge branch 'main' into google-vertexai-chatcompletion
lhoet-google May 26, 2025
1ead8c5
GoogleVertexAiModel added documentation info on rateLimitGroupingHash
leo-hoet May 27, 2025
ad9f0e1
Merge branch 'main' into google-vertexai-chatcompletion
leo-hoet May 27, 2025
f4057f3
Merge branch 'main' into google-vertexai-chatcompletion
jonathan-buttner May 28, 2025
38b9ca4
[CI] Auto commit changes from spotless
May 28, 2025
2e8dbee
Fix: using Locale.ROOT when calling toLowerCase
leo-hoet May 28, 2025
ddd19c5
Fix: Renamed test class to match convention & modified use of forbidd…
leo-hoet May 28, 2025
88a2780
Fix: Failing test in InferenceServicesIT
leo-hoet May 29, 2025
b841e4e
Merge branch 'main' into google-vertexai-chatcompletion
leo-hoet May 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME_8_19 = def(8_841_0_34);
public static final TransportVersion RERANKER_FAILURES_ALLOWED_8_19 = def(8_841_0_35);
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED_8_19 = def(8_841_0_36);
public static final TransportVersion ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED_8_19 = def(8_841_0_37);
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
Expand Down Expand Up @@ -260,7 +261,11 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_FIELD_ATTRIBUTE_DROP_TYPE = def(9_075_0_00);
public static final TransportVersion ESQL_TIME_SERIES_SOURCE_STATUS = def(9_076_0_00);
public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME = def(9_077_0_00);
public static final TransportVersion ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED_8_19 = def(8_841_0_30);
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_CHAT_COMPLETION_ADDED = def(9_078_0_00);
public static final TransportVersion NODES_STATS_SUPPORTS_MULTI_PROJECT = def(9_079_0_00);
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED = def(9_080_0_00);
public static final TransportVersion SETTINGS_IN_DATA_STREAMS_DRY_RUN = def(9_081_0_00);
public static final TransportVersion ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED = def(9_082_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,20 @@

package org.elasticsearch.xpack.inference.services.googlevertexai;

import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionVisitor;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

import java.net.URI;
import java.util.Map;
import java.util.Objects;

public abstract class GoogleVertexAiModel extends Model {
public abstract class GoogleVertexAiModel extends RateLimitGroupingModel {

private final GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings;

Expand Down Expand Up @@ -58,4 +59,15 @@ public GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings() {
public URI uri() {
return uri;
}

@Override
public int rateLimitGroupingHash() {
// In VertexAI rate limiting is scoped to the project and the model. URI already has this information so we are using that
return Objects.hash(uri);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to clarify, it's not based on the service account key information too?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a link to the docs that indicates this?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great! Will do. https://ai.google.dev/gemini-api/docs/rate-limits

Rate limits are applied per project, not per API key.

Also on the VertexAI quotas https://cloud.google.com/vertex-ai/docs/quotas#request_quotas

The following quotas apply to Vertex AI requests for a given project and supported region...

Some resources may not be affected by the region, but I choose to be conservative and go with a safe default

}

@Override
public RateLimitSettings rateLimitSettings() {
return rateLimitServiceSettings().rateLimitSettings();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
Expand All @@ -42,6 +44,7 @@
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUnifiedChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

Expand Down Expand Up @@ -84,6 +87,9 @@ public class GoogleVertexAiService extends SenderService {
InputType.INTERNAL_SEARCH
);

private final ResponseHandler COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler(
"Google VertexAI chat completion"
);
@Override
public Set<TaskType> supportedStreamingTasks() {
return EnumSet.of(TaskType.CHAT_COMPLETION);
Expand Down Expand Up @@ -240,7 +246,13 @@ protected void doUnifiedCompletionInfer(
var chatCompletionModel = (GoogleVertexAiChatCompletionModel) model;
var updatedChatCompletionModel = GoogleVertexAiChatCompletionModel.of(chatCompletionModel, inputs.getRequest());

var manager = GoogleVertexAiCompletionRequestManager.of(updatedChatCompletionModel, getServiceComponents().threadPool());
var manager = new GenericRequestManager<>(
getServiceComponents().threadPool(),
updatedChatCompletionModel,
COMPLETION_HANDLER,
(unifiedChatInput) -> new GoogleVertexAiUnifiedChatCompletionRequest(unifiedChatInput, updatedChatCompletionModel),
UnifiedChatInput.class
);

var errorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX);
var action = new SenderExecutableAction(getSender(), manager, errorMessage);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.xpack.inference.services.googlevertexai;

import org.elasticsearch.common.Strings;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.logging.LogManager;
Expand All @@ -27,6 +28,7 @@
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor;

import java.nio.charset.StandardCharsets;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
Expand All @@ -41,8 +43,10 @@ public class GoogleVertexAiUnifiedChatCompletionResponseHandler extends GoogleVe
private static final String ERROR_MESSAGE_FIELD = "message";
private static final String ERROR_STATUS_FIELD = "status";

public GoogleVertexAiUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
super(requestType, parseFunction, GoogleVertexAiErrorResponse::fromResponse, true);
private static final ResponseParser noopParseFunction = (a, b) -> null;

public GoogleVertexAiUnifiedChatCompletionResponseHandler(String requestType) {
super(requestType, noopParseFunction, GoogleVertexAiErrorResponse::fromResponse, true);
}

@Override
Expand Down Expand Up @@ -141,9 +145,9 @@ static ErrorResponse fromResponse(HttpResult response) {
) {
return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
} catch (Exception e) {
logger.warn("Failed to parse Google Vertex AI error response body", e);
var resultAsString = new String(response.body(), StandardCharsets.UTF_8);
return new ErrorResponse(Strings.format("Unable to parse the Google Vertex AI error, response body: [%s]", resultAsString));
}
return ErrorResponse.UNDEFINED_ERROR;
}

static ErrorResponse fromString(String response) {
Expand All @@ -153,9 +157,8 @@ static ErrorResponse fromString(String response) {
) {
return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
} catch (Exception e) {
logger.warn("Failed to parse Google Vertex AI error string", e);
return new ErrorResponse(Strings.format("Unable to parse the Google Vertex AI error, response body: [%s]", response));
}
return ErrorResponse.UNDEFINED_ERROR;
}

private final int code;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,17 +192,17 @@ private static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice

static {
PARSER.declareObjectArray(
ConstructingObjectParser.optionalConstructorArg(),
ConstructingObjectParser.constructorArg(),
(p, c) -> CandidateParser.parse(p),
new ParseField(CANDIDATES_FIELD)
);
PARSER.declareObject(
ConstructingObjectParser.optionalConstructorArg(),
ConstructingObjectParser.constructorArg(),
(p, c) -> UsageMetadataParser.parse(p),
new ParseField(USAGE_METADATA_FIELD)
);
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(MODEL_VERSION_FIELD));
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(RESPONSE_ID_FIELD));
PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(MODEL_VERSION_FIELD));
PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(RESPONSE_ID_FIELD));
}

public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk parse(XContentParser parser) throws IOException {
Expand All @@ -224,7 +224,7 @@ private static class CandidateParser {

static {
PARSER.declareObject(
ConstructingObjectParser.optionalConstructorArg(),
ConstructingObjectParser.constructorArg(),
(p, c) -> ContentParser.parse(p),
new ParseField(CONTENT_FIELD)
);
Expand All @@ -248,9 +248,9 @@ private static class ContentParser {
);

static {
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ROLE_FIELD));
PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(ROLE_FIELD));
PARSER.declareObjectArray(
ConstructingObjectParser.optionalConstructorArg(),
ConstructingObjectParser.constructorArg(),
(p, c) -> PartParser.parse(p),
new ParseField(PARTS_FIELD)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,19 @@

import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiCompletionRequestManager;
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiEmbeddingsRequestManager;
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiRerankRequestManager;
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiUnifiedChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUnifiedChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel;

import java.util.Map;
Expand All @@ -30,6 +36,11 @@ public class GoogleVertexAiActionCreator implements GoogleVertexAiActionVisitor

private final ServiceComponents serviceComponents;

static final ResponseHandler COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler(
"Google VertexAI chat completion"
);
static final String USER_ROLE = "user";

public GoogleVertexAiActionCreator(Sender sender, ServiceComponents serviceComponents) {
this.sender = Objects.requireNonNull(sender);
this.serviceComponents = Objects.requireNonNull(serviceComponents);
Expand All @@ -56,8 +67,16 @@ public ExecutableAction create(GoogleVertexAiRerankModel model, Map<String, Obje

@Override
public ExecutableAction create(GoogleVertexAiChatCompletionModel model, Map<String, Object> taskSettings) {
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Google Vertex AI chat completion");
var requestManager = GoogleVertexAiCompletionRequestManager.of(model, serviceComponents.threadPool());
return new SenderExecutableAction(sender, requestManager, failedToSendRequestErrorMessage);

var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX);
var manager = new GenericRequestManager<>(
serviceComponents.threadPool(),
model,
COMPLETION_HANDLER,
inputs -> new GoogleVertexAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model),
ChatCompletionInput.class
);

return new SingleInputSenderExecutableAction(sender, manager, failedToSendRequestErrorMessage, COMPLETION_ERROR_PREFIX);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ public String getWriteableName() {

@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED_8_19;
return TransportVersions.ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED;
}

@Override
Expand Down
Loading
Loading