Merged
docs/changelog/128105.yaml (5 additions, 0 deletions)
@@ -0,0 +1,5 @@
pr: 128105
summary: "Adding Google VertexAI chat completion integration"
area: Inference
type: enhancement
issues: [ ]
@@ -227,6 +227,7 @@ static TransportVersion def(int id) {
public static final TransportVersion RERANKER_FAILURES_ALLOWED_8_19 = def(8_841_0_35);
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED_8_19 = def(8_841_0_36);
public static final TransportVersion ML_INFERENCE_SAGEMAKER_CHAT_COMPLETION_8_19 = def(8_841_0_37);
public static final TransportVersion ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED_8_19 = def(8_841_0_38);

/*
* STOP! READ THIS FIRST! No, really,
@@ -151,14 +151,22 @@ public void testGetServicesWithCompletionTaskType() throws IOException {

public void testGetServicesWithChatCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
assertThat(services.size(), equalTo(6));
assertThat(services.size(), equalTo(7));

var providers = providers(services);

assertThat(
providers,
containsInAnyOrder(
List.of("deepseek", "elastic", "openai", "streaming_completion_test_service", "hugging_face", "amazon_sagemaker").toArray()
List.of(
"deepseek",
"elastic",
"openai",
"streaming_completion_test_service",
"hugging_face",
"amazon_sagemaker",
"googlevertexai"
).toArray()
)
);
}
@@ -73,6 +73,7 @@
import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankServiceSettings;
@@ -453,6 +454,15 @@ private static void addGoogleVertexAiNamedWriteables(List<NamedWriteableRegistry
GoogleVertexAiRerankTaskSettings::new
)
);

namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
GoogleVertexAiChatCompletionServiceSettings.NAME,
GoogleVertexAiChatCompletionServiceSettings::new
)
);

}

private static void addInternalNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
@@ -7,19 +7,20 @@

package org.elasticsearch.xpack.inference.services.googlevertexai;

import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionVisitor;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

import java.net.URI;
import java.util.Map;
import java.util.Objects;

public abstract class GoogleVertexAiModel extends Model {
public abstract class GoogleVertexAiModel extends RateLimitGroupingModel {

private final GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings;

@@ -58,4 +59,18 @@ public GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings() {
public URI uri() {
return uri;
}

@Override
public int rateLimitGroupingHash() {
// In Vertex AI, rate limiting is scoped to the project, region and model. The URI already encodes this information, so we hash that.
// The API key does not affect the quota.
// https://ai.google.dev/gemini-api/docs/rate-limits
// https://cloud.google.com/vertex-ai/docs/quotas
return Objects.hash(uri);
}

@Override
public RateLimitSettings rateLimitSettings() {
return rateLimitServiceSettings().rateLimitSettings();
}
}
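
A minimal standalone sketch of the grouping behaviour above (the class name and endpoint URLs are hypothetical, not taken from this change): two models that resolve to the same Vertex AI project, location and model id hash to the same rate limit group, while a different model id yields a different group.

import java.net.URI;
import java.util.Objects;

// Hypothetical illustration of rateLimitGroupingHash(); all values below are made up.
public final class RateLimitGroupingSketch {
    public static void main(String[] args) {
        String base = "https://us-central1-aiplatform.googleapis.com/v1/projects/demo-project"
            + "/locations/us-central1/publishers/google/models/";

        URI flashA = URI.create(base + "gemini-2.0-flash:streamGenerateContent");
        URI flashB = URI.create(base + "gemini-2.0-flash:streamGenerateContent");
        URI pro = URI.create(base + "gemini-2.0-pro:streamGenerateContent");

        // Equal URIs (same project, location and model) produce equal hashes,
        // so both endpoints would share one rate limit group.
        System.out.println(Objects.hash(flashA) == Objects.hash(flashB)); // true

        // A different model id changes the URI and, hash collisions aside, the group.
        System.out.println(Objects.hash(flashA) == Objects.hash(pro));    // false
    }
}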
@@ -9,11 +9,14 @@

import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiErrorResponseEntity;

import java.util.function.Function;

import static org.elasticsearch.core.Strings.format;

public class GoogleVertexAiResponseHandler extends BaseResponseHandler {
@@ -24,6 +27,15 @@ public GoogleVertexAiResponseHandler(String requestType, ResponseParser parseFun
super(requestType, parseFunction, GoogleVertexAiErrorResponseEntity::fromResponse);
}

public GoogleVertexAiResponseHandler(
String requestType,
ResponseParser parseFunction,
Function<HttpResult, ErrorResponse> errorParseFunction,
boolean canHandleStreamingResponses
) {
super(requestType, parseFunction, errorParseFunction, canHandleStreamingResponses);
}

@Override
protected void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
int statusCode = result.response().getStatusLine().getStatusCode();
@@ -124,9 +124,8 @@ public static Map<String, SettingsConfiguration> get() {
var configurationMap = new HashMap<String, SettingsConfiguration>();
configurationMap.put(
SERVICE_ACCOUNT_JSON,
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK)).setDescription(
"API Key for the provider you're connecting to."
)
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.CHAT_COMPLETION))
.setDescription("API Key for the provider you're connecting to.")
.setLabel("Credentials JSON")
.setRequired(true)
.setSensitive(true)
@@ -29,7 +29,10 @@
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
@@ -38,34 +41,42 @@
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUnifiedChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.EMBEDDING_MAX_BATCH_SIZE;
import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.LOCATION;
import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.PROJECT_ID;
import static org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator.COMPLETION_ERROR_PREFIX;

public class GoogleVertexAiService extends SenderService {

public static final String NAME = "googlevertexai";

private static final String SERVICE_NAME = "Google Vertex AI";
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK);
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
TaskType.TEXT_EMBEDDING,
TaskType.RERANK,
TaskType.CHAT_COMPLETION
);

public static final EnumSet<InputType> VALID_INPUT_TYPE_VALUES = EnumSet.of(
InputType.INGEST,
@@ -76,6 +87,15 @@ public class GoogleVertexAiService extends SenderService {
InputType.INTERNAL_SEARCH
);

private final ResponseHandler COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler(
"Google VertexAI chat completion"
);

@Override
public Set<TaskType> supportedStreamingTasks() {
return EnumSet.of(TaskType.CHAT_COMPLETION);
}

public GoogleVertexAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);
}
@@ -220,7 +240,24 @@ protected void doUnifiedCompletionInfer(
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
throwUnsupportedUnifiedCompletionOperation(NAME);
if (model instanceof GoogleVertexAiChatCompletionModel == false) {
listener.onFailure(createInvalidModelException(model));
return;
}
var chatCompletionModel = (GoogleVertexAiChatCompletionModel) model;
var updatedChatCompletionModel = GoogleVertexAiChatCompletionModel.of(chatCompletionModel, inputs.getRequest());

var manager = new GenericRequestManager<>(
getServiceComponents().threadPool(),
updatedChatCompletionModel,
COMPLETION_HANDLER,
(unifiedChatInput) -> new GoogleVertexAiUnifiedChatCompletionRequest(unifiedChatInput, updatedChatCompletionModel),
UnifiedChatInput.class
);

var errorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX);
var action = new SenderExecutableAction(getSender(), manager, errorMessage);
action.execute(inputs, timeout, listener);
}

@Override
@@ -320,6 +357,17 @@ private static GoogleVertexAiModel createModel(
secretSettings,
context
);

case CHAT_COMPLETION -> new GoogleVertexAiChatCompletionModel(
inferenceEntityId,
taskType,
NAME,
serviceSettings,
taskSettings,
secretSettings,
context
);

default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
};
}
@@ -348,7 +396,7 @@ public static InferenceServiceConfiguration get() {

configurationMap.put(
LOCATION,
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING)).setDescription(
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.CHAT_COMPLETION)).setDescription(
"Please provide the GCP region where the Vertex AI API(s) is enabled. "
+ "For more information, refer to the {geminiVertexAIDocs}."
)