Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
d429a31
VertexAI chat completion response entity with tests
lhoet-google Apr 29, 2025
00bfdb0
Modified build gradle to include google vertexai sdk
lhoet-google Apr 29, 2025
2378270
Google vertex ai chat completion model with tests
lhoet-google Apr 29, 2025
1f00974
Google vertex ai chat completion request with tests
lhoet-google Apr 30, 2025
970ab3c
TransportVersion
lhoet-google Apr 30, 2025
5428074
ChatCompletion TaskSettings & ServiceSettings
lhoet-google Apr 30, 2025
ee44f22
ChatCompletionRequestManager & tests
lhoet-google Apr 30, 2025
8160c2b
VertexAI Service and related classes. WIP & missing tests
lhoet-google Apr 30, 2025
ff68fbe
VertexAi ChatCompletion task settings fix.
lhoet-google May 5, 2025
29c7093
JsonArrayParts event processor & parser
lhoet-google May 6, 2025
bfd75b0
AI Service and service tests
lhoet-google May 6, 2025
2ebfac9
Unified chat completion response and request handlers. Also working w…
lhoet-google May 6, 2025
679ea80
StreamingProcessor now supports tools. Added more tests
lhoet-google May 8, 2025
e611cc3
More tests for streaming processor
lhoet-google May 8, 2025
87e428a
Request entity tests
lhoet-google May 12, 2025
193d06d
Google vertexai unified chat completion entity now accepting tools an…
lhoet-google May 12, 2025
813a2e8
Serializing function call message
lhoet-google May 12, 2025
f1ab8cc
Response handler with tests
lhoet-google May 12, 2025
23c7d92
VertexAI chat completion req entity bugfixes
lhoet-google May 13, 2025
c45d23f
Bugfix in vertex ai unified chat completion req entity
lhoet-google May 13, 2025
a820d83
Bugfix in vertex ai unified streaming processor
lhoet-google May 13, 2025
d2f09cf
Removed google aiplatform sdk
lhoet-google May 13, 2025
bda94de
Renamed file to match class name for JsonArrayPartsEventParser
lhoet-google May 13, 2025
5dee072
Updated rate limit settings for vertex ai
lhoet-google May 13, 2025
2f75788
Deleted GoogleVertexAiChatCompletionTaskSettings
lhoet-google May 13, 2025
b50c911
VertexAI Unified chat completion request tests
lhoet-google May 14, 2025
d6ae90f
Fixed some tests
lhoet-google May 14, 2025
cbb387f
Fixed GoogleAIService get configuration tests
lhoet-google May 14, 2025
7e1c970
GoogleVertexAiCompletion action tests
lhoet-google May 14, 2025
5ab716f
Formatting
lhoet-google May 15, 2025
28aa464
Code style fix
lhoet-google May 15, 2025
2279391
Removed unused variables
lhoet-google May 15, 2025
85af5c0
Function call id fixed
lhoet-google May 15, 2025
16c01b0
Bugfix
lhoet-google May 15, 2025
1732244
Merge branch 'main' into vertexai-chatcompletion
lhoet-google May 16, 2025
6cc165b
Testfix
lhoet-google May 16, 2025
c020122
Unit tests
beltrangs May 16, 2025
7821d58
Merge branch 'vertexai-chatcompletion' into google-chat-completion-tests
beltrangs May 16, 2025
8633659
Update ElasticInferenceServiceTests.java
beltrangs May 16, 2025
06020cc
Update GoogleVertexAiServiceTests.java
beltrangs May 16, 2025
5a2cfe5
Merge pull request #2 from beltrangslilly/google-chat-completion-tests
leo-hoet May 16, 2025
0cf1f3f
Merge branch 'vertexai-chatcompletion' of github.com:lhoet-google/ela…
lhoet-google May 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_FIELD_ATTRIBUTE_DROP_TYPE = def(9_075_0_00);
public static final TransportVersion ESQL_TIME_SERIES_SOURCE_STATUS = def(9_076_0_00);
public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME = def(9_077_0_00);
public static final TransportVersion ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED = def(9_078_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.response.streaming;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque;

/**
* Parses a stream of bytes that form a JSON array, where each element of the array
* is a JSON object. This parser extracts each complete JSON object from the array
* and emits it as byte array.
*
* Example of an expected stream:
* Chunk 1: [{"key":"val1"}
* Chunk 2: ,{"key2":"val2"}
* Chunk 3: ,{"key3":"val3"}, {"some":"object"}]
*
* This parser would emit four byte arrays, with data:
* 1. {"key":"val1"}
* 2. {"key2":"val2"}
* 3. {"key3":"val3"}
* 4. {"some":"object"}
*/
public class JsonArrayPartsEventParser {

// Buffer to hold bytes from the previous call if they formed an incomplete JSON object.
private final ByteArrayOutputStream incompletePart = new ByteArrayOutputStream();

public Deque<byte[]> parse(byte[] newBytes) {
if (newBytes == null || newBytes.length == 0) {
return new ArrayDeque<>(0);
}

ByteArrayOutputStream currentStream = new ByteArrayOutputStream();
try {
currentStream.write(incompletePart.toByteArray());
currentStream.write(newBytes);
} catch (IOException e) {
throw new UncheckedIOException("Error handling byte array streams", e);
}
incompletePart.reset();

byte[] dataToProcess = currentStream.toByteArray();
return parseInternal(dataToProcess);
}

private Deque<byte[]> parseInternal(byte[] data) {
int localBraceLevel = 0;
int objectStartIndex = -1;
Deque<byte[]> completedObjects = new ArrayDeque<>();

for (int i = 0; i < data.length; i++) {
char c = (char) data[i];

if (c == '{') {
if (localBraceLevel == 0) {
objectStartIndex = i;
}
localBraceLevel++;
} else if (c == '}') {
if (localBraceLevel > 0) {
localBraceLevel--;
if (localBraceLevel == 0) {
byte[] jsonObject = Arrays.copyOfRange(data, objectStartIndex, i + 1);
completedObjects.offer(jsonObject);
objectStartIndex = -1;
}
}
}
}

if (localBraceLevel > 0) {
incompletePart.write(data, objectStartIndex, data.length - objectStartIndex);
}
return completedObjects;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.response.streaming;

import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
import org.elasticsearch.xpack.inference.external.http.HttpResult;

import java.util.Deque;

/**
 * Processor that feeds the body of each {@link HttpResult} to a
 * {@link JsonArrayPartsEventParser} and forwards the resulting parts downstream.
 * When an item has an empty body, or parsing it yields no parts, one more item is
 * requested from upstream instead.
 */
public class JsonArrayPartsEventProcessor extends DelegatingProcessor<HttpResult, Deque<byte[]>> {
    private final JsonArrayPartsEventParser parser;

    public JsonArrayPartsEventProcessor(JsonArrayPartsEventParser jsonArrayPartsEventParser) {
        this.parser = jsonArrayPartsEventParser;
    }

    @Override
    public void next(HttpResult item) {
        if (item.isBodyEmpty() == false) {
            Deque<byte[]> parts = parser.parse(item.body());
            if (parts.isEmpty() == false) {
                downstream().onNext(parts);
                return;
            }
        }

        // Nothing to emit for this item, so ask upstream for the next one.
        upstream().request(1);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.googlevertexai;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUnifiedChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiChatCompletionResponseEntity;

import java.util.Objects;
import java.util.function.Supplier;

public class GoogleVertexAiCompletionRequestManager extends GoogleVertexAiRequestManager {

private static final Logger logger = LogManager.getLogger(GoogleVertexAiCompletionRequestManager.class);

private static final ResponseHandler HANDLER = createGoogleVertexAiResponseHandler();

/**
 * Builds the response handler used for chat completion requests. Responses are
 * parsed with {@code GoogleVertexAiChatCompletionResponseEntity::fromResponse}.
 */
private static ResponseHandler createGoogleVertexAiResponseHandler() {
    final String requestType = "Google Vertex AI chat completion";
    return new GoogleVertexAiUnifiedChatCompletionResponseHandler(requestType, GoogleVertexAiChatCompletionResponseEntity::fromResponse);
}

private final GoogleVertexAiChatCompletionModel model;

/**
 * Creates a request manager for the given chat completion model.
 * The rate limit grouping passed to the superclass is derived from the model
 * via {@code RateLimitGrouping.of(model)}.
 *
 * @param model      the chat completion model that requests are built for
 * @param threadPool the thread pool handed to the superclass
 */
public GoogleVertexAiCompletionRequestManager(GoogleVertexAiChatCompletionModel model, ThreadPool threadPool) {
super(threadPool, model, RateLimitGrouping.of(model));
this.model = model;
}

/**
 * Rate limit grouping key built from the hash code of the model's project id.
 */
record RateLimitGrouping(int projectIdHash) {
    /**
     * Derives the grouping for {@code model}.
     *
     * @throws NullPointerException if {@code model} is null
     */
    public static RateLimitGrouping of(GoogleVertexAiChatCompletionModel model) {
        var rateLimitSettings = Objects.requireNonNull(model).rateLimitServiceSettings();
        int hashOfProjectId = rateLimitSettings.projectId().hashCode();
        return new RateLimitGrouping(hashOfProjectId);
    }
}

/**
 * Static factory that rejects null arguments before constructing the manager.
 *
 * @throws NullPointerException if {@code model} or {@code threadPool} is null
 */
public static GoogleVertexAiCompletionRequestManager of(GoogleVertexAiChatCompletionModel model, ThreadPool threadPool) {
    // Arguments are evaluated left to right, so the model is checked before the thread pool.
    return new GoogleVertexAiCompletionRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool));
}

@Override
public void execute(
InferenceInputs inferenceInputs,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {

var chatInputs = (UnifiedChatInput) inferenceInputs;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
- var chatInputs = (UnifiedChatInput) inferenceInputs;
+ var chatInputs = inferenceInputs.castTo(UnifiedChatInput.class);

If the types are somehow wrong, this will throw a decorated IllegalArgumentException rather than the ClassCastException

var request = new GoogleVertexAiUnifiedChatCompletionRequest(chatInputs, model);
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@

import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiErrorResponseEntity;

import java.util.function.Function;

import static org.elasticsearch.core.Strings.format;

public class GoogleVertexAiResponseHandler extends BaseResponseHandler {
Expand All @@ -24,6 +27,15 @@ public GoogleVertexAiResponseHandler(String requestType, ResponseParser parseFun
super(requestType, parseFunction, GoogleVertexAiErrorResponseEntity::fromResponse);
}

/**
 * Creates a handler with a caller-supplied error parser and an explicit streaming flag.
 * All four arguments are passed unchanged to the {@link BaseResponseHandler} constructor.
 *
 * @param requestType                 passed through as the request type
 * @param parseFunction               parser for responses
 * @param errorParseFunction          converts an {@link HttpResult} into an {@link ErrorResponse}
 * @param canHandleStreamingResponses passed through as the streaming-support flag
 */
public GoogleVertexAiResponseHandler(
String requestType,
ResponseParser parseFunction,
Function<HttpResult, ErrorResponse> errorParseFunction,
boolean canHandleStreamingResponses
) {
super(requestType, parseFunction, errorParseFunction, canHandleStreamingResponses);
}

@Override
protected void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
if (result.isSuccessfulResponse()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,8 @@ public static Map<String, SettingsConfiguration> get() {
var configurationMap = new HashMap<String, SettingsConfiguration>();
configurationMap.put(
SERVICE_ACCOUNT_JSON,
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK)).setDescription(
"API Key for the provider you're connecting to."
)
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.CHAT_COMPLETION))
.setDescription("API Key for the provider you're connecting to.")
.setLabel("Credentials JSON")
.setRequired(true)
.setSensitive(true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
Expand All @@ -38,6 +39,7 @@
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel;
Expand All @@ -47,25 +49,31 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.EMBEDDING_MAX_BATCH_SIZE;
import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.LOCATION;
import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.PROJECT_ID;
import static org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator.COMPLETION_ERROR_PREFIX;

public class GoogleVertexAiService extends SenderService {

public static final String NAME = "googlevertexai";

private static final String SERVICE_NAME = "Google Vertex AI";
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK);
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
TaskType.TEXT_EMBEDDING,
TaskType.RERANK,
TaskType.CHAT_COMPLETION
);

public static final EnumSet<InputType> VALID_INPUT_TYPE_VALUES = EnumSet.of(
InputType.INGEST,
Expand All @@ -76,6 +84,11 @@ public class GoogleVertexAiService extends SenderService {
InputType.INTERNAL_SEARCH
);

/**
 * Returns the task types this service reports as supporting streaming:
 * a set containing only {@link TaskType#CHAT_COMPLETION}.
 */
@Override
public Set<TaskType> supportedStreamingTasks() {
return EnumSet.of(TaskType.CHAT_COMPLETION);
}

/**
 * Creates the service, passing both arguments unchanged to the {@link SenderService} constructor.
 *
 * @param factory           the HTTP request sender factory
 * @param serviceComponents the service components
 */
public GoogleVertexAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);
}
Expand Down Expand Up @@ -220,7 +233,18 @@ protected void doUnifiedCompletionInfer(
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
throwUnsupportedUnifiedCompletionOperation(NAME);
if (model instanceof GoogleVertexAiChatCompletionModel == false) {
listener.onFailure(createInvalidModelException(model));
return;
}
var chatCompletionModel = (GoogleVertexAiChatCompletionModel) model;
var updatedChatCompletionModel = GoogleVertexAiChatCompletionModel.of(chatCompletionModel, inputs.getRequest());

var manager = GoogleVertexAiCompletionRequestManager.of(updatedChatCompletionModel, getServiceComponents().threadPool());

var errorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX);
var action = new SenderExecutableAction(getSender(), manager, errorMessage);
action.execute(inputs, timeout, listener);
}

@Override
Expand Down Expand Up @@ -320,6 +344,17 @@ private static GoogleVertexAiModel createModel(
secretSettings,
context
);

case CHAT_COMPLETION -> new GoogleVertexAiChatCompletionModel(
inferenceEntityId,
taskType,
NAME,
serviceSettings,
taskSettings,
secretSettings,
context
);

default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
};
}
Expand Down Expand Up @@ -348,7 +383,7 @@ public static InferenceServiceConfiguration get() {

configurationMap.put(
LOCATION,
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING)).setDescription(
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.CHAT_COMPLETION)).setDescription(
"Please provide the GCP region where the Vertex AI API(s) is enabled. "
+ "For more information, refer to the {geminiVertexAIDocs}."
)
Expand Down
Loading