Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
00a6636
Implemented ChatCompletion task for Google VertexAI with Gemini Models
lhoet-google Apr 29, 2025
9be2a44
changelog
lhoet-google May 16, 2025
c2387e8
System Instruction bugfix
lhoet-google May 19, 2025
50770ea
Mapping role assistant -> model in vertex ai chat completion request …
lhoet-google May 19, 2025
42cbbe2
GoogleVertexAI chat completion using SSE events. Removed JsonArrayEve…
lhoet-google May 20, 2025
fe8e336
Removed buffer from GoogleVertexAiUnifiedStreamingProcessor
lhoet-google May 20, 2025
7c24f93
Casting inference inputs with `castoTo`
lhoet-google May 21, 2025
2140d05
Registered GoogleVertexAiChatCompletionServiceSettings in InferenceNa…
lhoet-google May 21, 2025
42dd376
Changed transport version to 8_19 for vertexai chatcompletion
lhoet-google May 21, 2025
0863316
Fix to transport version. Moved ML_INFERENCE_VERTEXAI_CHATCOMPLETION_…
lhoet-google May 21, 2025
f080e96
VertexAI Chat completion request entity jsonStringToMap using `ensure…
lhoet-google May 21, 2025
8f6648f
Fixed TransportVersions. Left vertexAi chat completion 8_19 and added…
lhoet-google May 22, 2025
848dc7a
Refactor switch statements by if-else for older java compatibility. I…
lhoet-google May 22, 2025
59862c6
Removed GoogleVertexAiChatCompletionResponseEntity and refactored cod…
lhoet-google May 22, 2025
93a7ca7
Removed redundant test `testUnifiedCompletionInfer_WithGoogleVertexAi…
lhoet-google May 22, 2025
7b99b1d
Returning whole body when fail to parse response from VertexAI
lhoet-google May 22, 2025
c05655f
Refactor use GenericRequestManager instead of GoogleVertexAiCompletio…
lhoet-google May 23, 2025
acc864f
Refactored to constructorArg for mandatory args in GoogleVertexAiUnif…
lhoet-google May 26, 2025
c371073
Changed transport version in GoogleVertexAiChatCompletionServiceSettings
lhoet-google May 26, 2025
efb90ba
Bugfix in tool calling with role tool
lhoet-google May 26, 2025
bb68715
Merge branch 'main' into google-vertexai-chatcompletion
lhoet-google May 26, 2025
1ead8c5
GoogleVertexAiModel added documentation info on rateLimitGroupingHash
leo-hoet May 27, 2025
ad9f0e1
Merge branch 'main' into google-vertexai-chatcompletion
leo-hoet May 27, 2025
f4057f3
Merge branch 'main' into google-vertexai-chatcompletion
jonathan-buttner May 28, 2025
38b9ca4
[CI] Auto commit changes from spotless
May 28, 2025
2e8dbee
Fix: using Locale.ROOT when calling toLowerCase
leo-hoet May 28, 2025
ddd19c5
Fix: Renamed test class to match convention & modified use of forbidd…
leo-hoet May 28, 2025
88a2780
Fix: Failing test in InferenceServicesIT
leo-hoet May 29, 2025
b841e4e
Merge branch 'main' into google-vertexai-chatcompletion
leo-hoet May 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/128105.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 128105
summary: "Google VertexAI integration now supports chat_completion task"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
summary: "Google VertexAI integration now supports chat_completion task"
summary: "Adding Google VertexAI chat completion integration"

area: Inference
type: enhancement
issues: [ ]
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_FIELD_ATTRIBUTE_DROP_TYPE = def(9_075_0_00);
public static final TransportVersion ESQL_TIME_SERIES_SOURCE_STATUS = def(9_076_0_00);
public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME = def(9_077_0_00);
public static final TransportVersion ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED = def(9_078_0_00);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll want to backport this to 8.19. To do that we need to reserve a transport version for 8.19 but in the main branch.

Let's add another transport version similar to what I did here: https://github.com/elastic/elasticsearch/pull/126805/files#diff-85e782e9e33a0f8ca8e99b41c17f9d04e3a7981d435abf44a3aa5d954a47cd8fR175

public static final TransportVersion ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED_8_19 = def(8_841_0_30);

Or whatever the latest version number is (it might be 30, or 31 etc).


/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.response.streaming;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque;

/**
* Parses a stream of bytes that form a JSON array, where each element of the array
* is a JSON object. This parser extracts each complete JSON object from the array
* and emits it as byte array.
*
* Example of an expected stream:
* Chunk 1: [{"key":"val1"}
* Chunk 2: ,{"key2":"val2"}
* Chunk 3: ,{"key3":"val3"}, {"some":"object"}]
*
* This parser would emit four byte arrays, with data:
* 1. {"key":"val1"}
* 2. {"key2":"val2"}
* 3. {"key3":"val3"}
* 4. {"some":"object"}
*/
public class JsonArrayPartsEventParser {

    // Buffer to hold bytes from the previous call if they formed an incomplete JSON object.
    // The incomplete tail is re-prepended (and re-scanned) on the next parse() call, so all
    // per-object scanning state can be recomputed from the object's first byte each time.
    private final ByteArrayOutputStream incompletePart = new ByteArrayOutputStream();

    /**
     * Consumes the next chunk of the stream and returns every JSON object completed so far.
     *
     * @param newBytes the next chunk of the byte stream; may be {@code null} or empty
     * @return the complete JSON objects found, in order; empty when no object completed yet
     */
    public Deque<byte[]> parse(byte[] newBytes) {
        if (newBytes == null || newBytes.length == 0) {
            return new ArrayDeque<>(0);
        }

        ByteArrayOutputStream currentStream = new ByteArrayOutputStream();
        try {
            currentStream.write(incompletePart.toByteArray());
            currentStream.write(newBytes);
        } catch (IOException e) {
            throw new UncheckedIOException("Error handling byte array streams", e);
        }
        incompletePart.reset();

        return parseInternal(currentStream.toByteArray());
    }

    /**
     * Scans {@code data} for balanced {@code {...}} spans and emits each as one object.
     * Braces that appear inside JSON string literals (including escaped quotes) are
     * ignored, so values such as {@code {"a":"}"}} are not split at the wrong byte.
     */
    private Deque<byte[]> parseInternal(byte[] data) {
        int localBraceLevel = 0;
        int objectStartIndex = -1;
        boolean inString = false;
        boolean escaped = false;
        Deque<byte[]> completedObjects = new ArrayDeque<>();

        for (int i = 0; i < data.length; i++) {
            // Safe byte->char cast: we only compare against ASCII characters, and UTF-8
            // continuation bytes of multi-byte code points never match ASCII values.
            char c = (char) data[i];

            if (inString) {
                // Inside a string literal: only an unescaped closing quote matters.
                if (escaped) {
                    escaped = false;
                } else if (c == '\\') {
                    escaped = true;
                } else if (c == '"') {
                    inString = false;
                }
                continue;
            }

            if (c == '"' && localBraceLevel > 0) {
                inString = true;
            } else if (c == '{') {
                if (localBraceLevel == 0) {
                    objectStartIndex = i;
                }
                localBraceLevel++;
            } else if (c == '}') {
                if (localBraceLevel > 0) {
                    localBraceLevel--;
                    if (localBraceLevel == 0) {
                        byte[] jsonObject = Arrays.copyOfRange(data, objectStartIndex, i + 1);
                        completedObjects.offer(jsonObject);
                        objectStartIndex = -1;
                    }
                }
            }
        }

        // An object was opened but not closed in this chunk: keep its bytes for the next call.
        if (localBraceLevel > 0) {
            incompletePart.write(data, objectStartIndex, data.length - objectStartIndex);
        }
        return completedObjects;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.response.streaming;

import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
import org.elasticsearch.xpack.inference.external.http.HttpResult;

import java.util.Deque;

/**
 * Feeds the body of each {@link HttpResult} chunk through a {@link JsonArrayPartsEventParser}
 * and emits every batch of complete JSON objects downstream. Chunks that produce nothing
 * (empty body, or only a partial object) trigger a request for the next chunk upstream
 * instead of emitting an empty batch.
 */
public class JsonArrayPartsEventProcessor extends DelegatingProcessor<HttpResult, Deque<byte[]>> {

    // Stateful parser; it buffers partial JSON objects between chunks.
    private final JsonArrayPartsEventParser jsonArrayPartsEventParser;

    public JsonArrayPartsEventProcessor(JsonArrayPartsEventParser jsonArrayPartsEventParser) {
        this.jsonArrayPartsEventParser = jsonArrayPartsEventParser;
    }

    @Override
    public void next(HttpResult item) {
        if (item.isBodyEmpty() == false) {
            var completedParts = jsonArrayPartsEventParser.parse(item.body());
            if (completedParts.isEmpty() == false) {
                downstream().onNext(completedParts);
                return;
            }
        }
        // Nothing to emit for this chunk; pull one more from upstream.
        upstream().request(1);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.googlevertexai;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUnifiedChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiChatCompletionResponseEntity;

import java.util.Objects;
import java.util.function.Supplier;

public class GoogleVertexAiCompletionRequestManager extends GoogleVertexAiRequestManager {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're trying to transition away from the request manager pattern to avoid the extra class since all the classes are pretty similar.

Here's an example of how we implemented it for voyageai: #124512

Here's how we do it for chat completions in openai: https://github.com/elastic/elasticsearch/blob/d2be03c946c94943dca8fe5da75a125fa70ddaa6/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreator.java

If we could switch to using a generic request manager that'd be great.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made the switch to use a generic request and it compiles and works fine. My only fear is that I had to change the base class of GoogleVertexAiModel from Model to RateLimitGroupingModel. Does that have any implication that I am not aware of?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome! No that should be fine. Thanks for making that change.


private static final Logger logger = LogManager.getLogger(GoogleVertexAiCompletionRequestManager.class);

private static final ResponseHandler HANDLER = createGoogleVertexAiResponseHandler();

private static ResponseHandler createGoogleVertexAiResponseHandler() {
return new GoogleVertexAiUnifiedChatCompletionResponseHandler(
"Google Vertex AI chat completion",
GoogleVertexAiChatCompletionResponseEntity::fromResponse
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Like I mentioned in another comment, I think we can expose a different constructor that doesn't require the parsing function to be passed in. If we do end up implementing the completion task type we'll probably need a separate ResponseHandler and that's when we'd need to pass in a parse function here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Outdated since I removed GoogleVertexAiCompletionRequestManager in favor of GenericRequestManager

);
}

private final GoogleVertexAiChatCompletionModel model;

public GoogleVertexAiCompletionRequestManager(GoogleVertexAiChatCompletionModel model, ThreadPool threadPool) {
super(threadPool, model, RateLimitGrouping.of(model));
this.model = model;
}

record RateLimitGrouping(int projectIdHash) {
public static RateLimitGrouping of(GoogleVertexAiChatCompletionModel model) {
Objects.requireNonNull(model);
return new RateLimitGrouping(model.rateLimitServiceSettings().projectId().hashCode());
}
}

public static GoogleVertexAiCompletionRequestManager of(GoogleVertexAiChatCompletionModel model, ThreadPool threadPool) {
Objects.requireNonNull(model);
Objects.requireNonNull(threadPool);

return new GoogleVertexAiCompletionRequestManager(model, threadPool);
}

@Override
public void execute(
InferenceInputs inferenceInputs,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {

var chatInputs = (UnifiedChatInput) inferenceInputs;
var request = new GoogleVertexAiUnifiedChatCompletionRequest(chatInputs, model);
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@

import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiErrorResponseEntity;

import java.util.function.Function;

import static org.elasticsearch.core.Strings.format;

public class GoogleVertexAiResponseHandler extends BaseResponseHandler {
Expand All @@ -24,6 +27,15 @@ public GoogleVertexAiResponseHandler(String requestType, ResponseParser parseFun
super(requestType, parseFunction, GoogleVertexAiErrorResponseEntity::fromResponse);
}

/**
 * Creates a handler with a caller-supplied error parser and an explicit streaming flag,
 * instead of the fixed {@code GoogleVertexAiErrorResponseEntity::fromResponse} error
 * parser used by the other constructor.
 *
 * @param requestType a description of the request type, forwarded to the base handler
 * @param parseFunction parses a successful response body
 * @param errorParseFunction maps a failed {@link HttpResult} to an {@link ErrorResponse}
 * @param canHandleStreamingResponses whether this handler can process streaming responses
 */
public GoogleVertexAiResponseHandler(
String requestType,
ResponseParser parseFunction,
Function<HttpResult, ErrorResponse> errorParseFunction,
boolean canHandleStreamingResponses
) {
super(requestType, parseFunction, errorParseFunction, canHandleStreamingResponses);
}

@Override
protected void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
if (result.isSuccessfulResponse()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,8 @@ public static Map<String, SettingsConfiguration> get() {
var configurationMap = new HashMap<String, SettingsConfiguration>();
configurationMap.put(
SERVICE_ACCOUNT_JSON,
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK)).setDescription(
"API Key for the provider you're connecting to."
)
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.CHAT_COMPLETION))
.setDescription("API Key for the provider you're connecting to.")
.setLabel("Credentials JSON")
.setRequired(true)
.setSensitive(true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
Expand All @@ -38,6 +39,7 @@
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel;
Expand All @@ -47,25 +49,31 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.EMBEDDING_MAX_BATCH_SIZE;
import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.LOCATION;
import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.PROJECT_ID;
import static org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator.COMPLETION_ERROR_PREFIX;

public class GoogleVertexAiService extends SenderService {

public static final String NAME = "googlevertexai";

private static final String SERVICE_NAME = "Google Vertex AI";
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK);
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
TaskType.TEXT_EMBEDDING,
TaskType.RERANK,
TaskType.CHAT_COMPLETION
);

public static final EnumSet<InputType> VALID_INPUT_TYPE_VALUES = EnumSet.of(
InputType.INGEST,
Expand All @@ -76,6 +84,11 @@ public class GoogleVertexAiService extends SenderService {
InputType.INTERNAL_SEARCH
);

/**
 * Streaming is supported only for the chat completion task type.
 */
@Override
public Set<TaskType> supportedStreamingTasks() {
return EnumSet.of(TaskType.CHAT_COMPLETION);
}

/**
 * Creates the Google Vertex AI inference service, delegating sender setup to the base class.
 *
 * @param factory creates the HTTP request sender used for all Vertex AI calls
 * @param serviceComponents shared service dependencies (thread pool, settings, etc. — see base class)
 */
public GoogleVertexAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);
}
Expand Down Expand Up @@ -220,7 +233,18 @@ protected void doUnifiedCompletionInfer(
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
throwUnsupportedUnifiedCompletionOperation(NAME);
if (model instanceof GoogleVertexAiChatCompletionModel == false) {
listener.onFailure(createInvalidModelException(model));
return;
}
var chatCompletionModel = (GoogleVertexAiChatCompletionModel) model;
var updatedChatCompletionModel = GoogleVertexAiChatCompletionModel.of(chatCompletionModel, inputs.getRequest());

var manager = GoogleVertexAiCompletionRequestManager.of(updatedChatCompletionModel, getServiceComponents().threadPool());

var errorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX);
var action = new SenderExecutableAction(getSender(), manager, errorMessage);
action.execute(inputs, timeout, listener);
}

@Override
Expand Down Expand Up @@ -320,6 +344,17 @@ private static GoogleVertexAiModel createModel(
secretSettings,
context
);

case CHAT_COMPLETION -> new GoogleVertexAiChatCompletionModel(
inferenceEntityId,
taskType,
NAME,
serviceSettings,
taskSettings,
secretSettings,
context
);

default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
};
}
Expand Down Expand Up @@ -348,7 +383,7 @@ public static InferenceServiceConfiguration get() {

configurationMap.put(
LOCATION,
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING)).setDescription(
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.CHAT_COMPLETION)).setDescription(
"Please provide the GCP region where the Vertex AI API(s) is enabled. "
+ "For more information, refer to the {geminiVertexAIDocs}."
)
Expand Down
Loading
Loading