Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/128694.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 128694
summary: "Adding Google VertexAI completion integration"
area: Inference
type: enhancement
issues: [ ]
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,9 @@ public static Map<String, SettingsConfiguration> get() {
var configurationMap = new HashMap<String, SettingsConfiguration>();
configurationMap.put(
SERVICE_ACCOUNT_JSON,
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.CHAT_COMPLETION))
.setDescription("API Key for the provider you're connecting to.")
new SettingsConfiguration.Builder(
EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.CHAT_COMPLETION, TaskType.COMPLETION)
).setDescription("API Key for the provider you're connecting to.")
.setLabel("Credentials JSON")
.setRequired(true)
.setSensitive(true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiCompletionModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUnifiedChatCompletionRequest;
Expand Down Expand Up @@ -75,7 +76,8 @@ public class GoogleVertexAiService extends SenderService {
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
TaskType.TEXT_EMBEDDING,
TaskType.RERANK,
TaskType.CHAT_COMPLETION
TaskType.CHAT_COMPLETION,
TaskType.COMPLETION
);

public static final EnumSet<InputType> VALID_INPUT_TYPE_VALUES = EnumSet.of(
Expand Down Expand Up @@ -220,11 +222,11 @@ protected void doInfer(
return;
}

GoogleVertexAiModel googleVertexAiModel = (GoogleVertexAiModel) model;
var completionModel = (GoogleVertexAiCompletionModel) model;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, we can't rely on model being only a GoogleVertexAiCompletionModel, it could be a completion, rerank, or embedding model. Do we need it as a completion model here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it. My understanding was this method was only called by the completion task, that's why I was downcasting it. Will revert this change


var actionCreator = new GoogleVertexAiActionCreator(getSender(), getServiceComponents());

var action = googleVertexAiModel.accept(actionCreator, taskSettings);
var action = completionModel.accept(actionCreator, taskSettings);
action.execute(inputs, timeout, listener);
}

Expand Down Expand Up @@ -368,6 +370,16 @@ private static GoogleVertexAiModel createModel(
context
);

case COMPLETION -> new GoogleVertexAiCompletionModel(
inferenceEntityId,
taskType,
NAME,
serviceSettings,
taskSettings,
secretSettings,
context
);

default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
};
}
Expand Down Expand Up @@ -396,10 +408,11 @@ public static InferenceServiceConfiguration get() {

configurationMap.put(
LOCATION,
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.CHAT_COMPLETION)).setDescription(
"Please provide the GCP region where the Vertex AI API(s) is enabled. "
+ "For more information, refer to the {geminiVertexAIDocs}."
)
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.COMPLETION))
.setDescription(
"Please provide the GCP region where the Vertex AI API(s) is enabled. "
+ "For more information, refer to the {geminiVertexAIDocs}."
)
.setLabel("GCP Region")
.setRequired(true)
.setSensitive(false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor;
import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiCompletionResponseEntity;

import java.nio.charset.StandardCharsets;
import java.util.Locale;
Expand All @@ -43,10 +43,8 @@ public class GoogleVertexAiUnifiedChatCompletionResponseHandler extends GoogleVe
private static final String ERROR_MESSAGE_FIELD = "message";
private static final String ERROR_STATUS_FIELD = "status";

private static final ResponseParser noopParseFunction = (a, b) -> null;

public GoogleVertexAiUnifiedChatCompletionResponseHandler(String requestType) {
super(requestType, noopParseFunction, GoogleVertexAiErrorResponse::fromResponse, true);
super(requestType, GoogleVertexAiCompletionResponseEntity::fromResponse, GoogleVertexAiErrorResponse::fromResponse, true);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To support streaming and non-streaming for completion I think we'll need a slightly different inheritance hierarchy.

For example take a look at openai: https://github.com/elastic/elasticsearch/blob/main/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiResponseHandler.java#L133-L140

These are the changes I think we need to make:

  • Let's override the parseResults in GoogleVertexAiResponseHandler, it can be identical to GoogleVertexAiUnifiedChatCompletionResponseHandler, except we'll need to return StreamingChatCompletionResults instead of the unified version
    • If there's a better way of refactoring, like potentially moving the buildMidstreamError or some other functionality from the unified response handler up to a class that both completion and chat completion extend that might be better
  • In the GoogleVertexAiActionCreator we'll create a new response handler that leverages GoogleVertexAiCompletionResponseEntity::fromResponse for the non-streaming case

Let me know if you'd rather jump on a call to discuss this

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, one correction: the way we report errors for the "unified" (aka openai) format is different from the "elasticsearch" way of returning errors. So for streaming completion we don't want to follow what we're doing in GoogleVertexAiUnifiedChatCompletionResponseHandler, because that returns errors in the openai format. I would try to follow what we're doing in the OpenAiResponseHandler that I linked. Hopefully we don't need to create a whole new streaming processor though.

We might need to do some refactoring but I would see if you could reuse GoogleVertexAiUnifiedStreamingProcessor for the parsing logic but we'll need to return a different result (specifically StreamingChatCompletionResults).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I managed to get this working as you suggested. The only hacky thing is that I had to add a method to the completion model, updateUri, that is called before making the request to ensure we are calling the right API. Take a look and let me know what you think. (It's still missing unit tests.)

}

@Override
Expand All @@ -63,7 +61,6 @@ public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpR

@Override
protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
assert request.isStreaming() : "Only streaming requests support this format";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's try to keep this for streaming only

var responseStatusCode = result.response().getStatusLine().getStatusCode();
var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode);
var restStatus = toRestStatus(responseStatusCode);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.googlevertexai.completion;

import org.apache.http.client.utils.URIBuilder;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUtils;

import java.net.URI;
import java.net.URISyntaxException;
import java.util.Map;

import static org.elasticsearch.core.Strings.format;

public class GoogleVertexAiCompletionModel extends GoogleVertexAiChatCompletionModel {

public GoogleVertexAiCompletionModel(
String inferenceEntityId,
TaskType taskType,
String service,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
Map<String, Object> secrets,
ConfigurationParseContext context
) {
super(inferenceEntityId, taskType, service, serviceSettings, taskSettings, secrets, context);
try {
var modelServiceSettings = this.getServiceSettings();
this.uri = buildUri(modelServiceSettings.location(), modelServiceSettings.projectId(), modelServiceSettings.modelId());
} catch (URISyntaxException e) {
throw new RuntimeException(e);
}

}

public static URI buildUri(String location, String projectId, String model) throws URISyntaxException {
return new URIBuilder().setScheme("https")
.setHost(format("%s%s", location, GoogleVertexAiUtils.GOOGLE_VERTEX_AI_HOST_SUFFIX))
.setPathSegments(
GoogleVertexAiUtils.V1,
GoogleVertexAiUtils.PROJECTS,
projectId,
GoogleVertexAiUtils.LOCATIONS,
GoogleVertexAiUtils.GLOBAL,
GoogleVertexAiUtils.PUBLISHERS,
GoogleVertexAiUtils.PUBLISHER_GOOGLE,
GoogleVertexAiUtils.MODELS,
format("%s:%s", model, GoogleVertexAiUtils.GENERATE_CONTENT)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does GENERATE_CONTENT mean non-streaming? For other integrations (like openai) we support streaming for chat_completion, streaming and non-streaming for completion.

For openai and other integrations we achieved this through a field in the body of the request we sent that is a boolean to enable/disable streaming. For vertex ai, do we need separate URLs?

If so, I think we'll need to store a reference to both URLs and then determine which to use later in the execution flow (maybe in the request creation logic). I can't think of an easy way to determine whether we're streaming or not while we're building this class without having to put in some hacks. If you can think of one, let me know.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. I though only non streaming responses were returned by completion tasks that's why i was using this api https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.endpoints/generateContent. In vertex AI there are two different API for streaming and non streaming. Let me see what I can do with the refactor you suggested and we can check it

)
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ public final class GoogleVertexAiUtils {

public static final String STREAM_GENERATE_CONTENT = "streamGenerateContent";

public static final String GENERATE_CONTENT = "generateContent";

public static final String QUERY_PARAM_ALT_SSE = "alt=sse";

private GoogleVertexAiUtils() {}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.googlevertexai.response;

import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiUnifiedStreamingProcessor;

import java.io.IOException;
import java.nio.charset.StandardCharsets;

public class GoogleVertexAiCompletionResponseEntity {
/**
* Parses the response from Google Vertex AI's generateContent endpoint
* For a request like:
* <pre>
* <code>
* {
* "inputs": "Please summarize this text: some text"
* }
* </code>
* </pre>
*
* The response is a <a href="https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/GenerateContentResponse">GenerateContentResponse</a> objects that looks like:
*
* <pre>
* <code>
*
* {
* "candidates": [
* {
* "content": {
* "role": "model",
* "parts": [
* {
* "text": "I am sorry, I cannot summarize the text because I do not have access to the text you are referring to."
* }
* ]
* },
* "finishReason": "STOP",
* "avgLogprobs": -0.19326641248620074
* }
* ],
* "usageMetadata": {
* "promptTokenCount": 71,
* "candidatesTokenCount": 23,
* "totalTokenCount": 94,
* "trafficType": "ON_DEMAND",
* "promptTokensDetails": [
* {
* "modality": "TEXT",
* "tokenCount": 71
* }
* ],
* "candidatesTokensDetails": [
* {
* "modality": "TEXT",
* "tokenCount": 23
* }
* ]
* },
* "modelVersion": "gemini-2.0-flash-001",
* "createTime": "2025-05-28T15:08:20.049493Z",
* "responseId": "5CY3aNWCA6mm4_UPr-zduAE"
* }
* </code>
* </pre>
*
* @param request The original request made to the service.
**/
public static InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException {
var responseJson = new String(response.body(), StandardCharsets.UTF_8);

// Response from generateContent has the same shape as streamGenerateContent. We reuse the already implemented
// class to avoid code duplication

StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk;
try (
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
.createParser(XContentParserConfiguration.EMPTY, responseJson)
) {
parser.nextToken();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about we use moveToFirstToken().

chunk = GoogleVertexAiUnifiedStreamingProcessor.GoogleVertexAiChatCompletionChunkParser.parse(parser);
}
var results = chunk.choices().stream().map(choice -> choice.delta().content()).map(ChatCompletionResults.Result::new).toList();

return new ChatCompletionResults(results);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -974,7 +974,7 @@ public void testGetConfiguration() throws Exception {
{
"service": "googlevertexai",
"name": "Google Vertex AI",
"task_types": ["text_embedding", "rerank", "chat_completion"],
"task_types": ["text_embedding", "rerank", "completion", "chat_completion"],
"configurations": {
"service_account_json": {
"description": "API Key for the provider you're connecting to.",
Expand All @@ -983,7 +983,7 @@ public void testGetConfiguration() throws Exception {
"sensitive": true,
"updatable": true,
"type": "str",
"supported_task_types": ["text_embedding", "rerank", "chat_completion"]
"supported_task_types": ["text_embedding", "rerank", "completion", "chat_completion"]
},
"project_id": {
"description": "The GCP Project ID which has Vertex AI API(s) enabled. For more information on the URL, refer to the {geminiVertexAIDocs}.",
Expand All @@ -992,7 +992,7 @@ public void testGetConfiguration() throws Exception {
"sensitive": false,
"updatable": false,
"type": "str",
"supported_task_types": ["text_embedding", "rerank", "chat_completion"]
"supported_task_types": ["text_embedding", "rerank", "completion", "chat_completion"]
},
"location": {
"description": "Please provide the GCP region where the Vertex AI API(s) is enabled. For more information, refer to the {geminiVertexAIDocs}.",
Expand All @@ -1001,7 +1001,7 @@ public void testGetConfiguration() throws Exception {
"sensitive": false,
"updatable": false,
"type": "str",
"supported_task_types": ["text_embedding", "chat_completion"]
"supported_task_types": ["text_embedding", "completion", "chat_completion"]
},
"rate_limit.requests_per_minute": {
"description": "Minimize the number of rate limit errors.",
Expand All @@ -1010,7 +1010,7 @@ public void testGetConfiguration() throws Exception {
"sensitive": false,
"updatable": false,
"type": "int",
"supported_task_types": ["text_embedding", "rerank", "chat_completion"]
"supported_task_types": ["text_embedding", "rerank", "completion", "chat_completion"]
},
"model_id": {
"description": "ID of the LLM you're using.",
Expand All @@ -1019,7 +1019,7 @@ public void testGetConfiguration() throws Exception {
"sensitive": false,
"updatable": false,
"type": "str",
"supported_task_types": ["text_embedding", "rerank", "chat_completion"]
"supported_task_types": ["text_embedding", "rerank", "completion", "chat_completion"]
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.googlevertexai.completion;

import org.elasticsearch.core.Strings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;

import java.net.URI;
import java.net.URISyntaxException;
import java.util.HashMap;
import java.util.Map;

import static org.hamcrest.Matchers.*;

public class GoogleVertexAiCompletionModelTests extends ESTestCase {

    private static final String DEFAULT_PROJECT_ID = "test-project";
    private static final String DEFAULT_LOCATION = "us-central1";
    private static final String DEFAULT_MODEL_ID = "gemini-pro";

    /**
     * The completion model must target the non-streaming {@code generateContent} endpoint,
     * with the configured location in the host and the literal {@code global} location segment
     * in the path.
     */
    public void testCreateModel() throws URISyntaxException {
        var model = createCompletionModel(DEFAULT_PROJECT_ID, DEFAULT_LOCATION, DEFAULT_MODEL_ID);
        URI expectedUri = new URI(
            Strings.format(
                "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:generateContent",
                DEFAULT_LOCATION,
                DEFAULT_PROJECT_ID,
                DEFAULT_MODEL_ID
            )
        );
        assertThat(model.uri(), is(expectedUri));
    }

    // Mirrors how GoogleVertexAiService constructs the model: the COMPLETION task type is what
    // the service's `case COMPLETION ->` branch passes through to this constructor.
    private static GoogleVertexAiCompletionModel createCompletionModel(String projectId, String location, String modelId) {
        return new GoogleVertexAiCompletionModel(
            "google-vertex-ai-completion-test-id",
            TaskType.COMPLETION,
            "google_vertex_ai",
            new HashMap<>(Map.of("project_id", projectId, "location", location, "model_id", modelId)),
            new HashMap<>(),
            new HashMap<>(Map.of("service_account_json", "{}")),
            ConfigurationParseContext.PERSISTENT
        );
    }
}
Loading
Loading