Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/128694.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 128694
summary: "Adding Google VertexAI completion integration"
area: Inference
type: enhancement
issues: [ ]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure if I should keep this class or delete it and use the base class instead.

Also, I am using GoogleVertexAiUnifiedChatCompletionResponseHandler.GoogleVertexAiErrorResponse::fromResponse. I preferred to use it that way and avoid putting that in a common class between GoogleVertexAiUnifiedChatCompletionResponseHandler and GoogleVertexAiChatCompletionResponseHandler to avoid extending the class hierarchy, but let me know if you think otherwise

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure if I should keep this class or delete it and use the base class instead.

Up to you, I'd probably use the base class. If you want to keep this one, then I don't think we need to accept the requestType. I think we can set it in this class directly.

Also, I am using GoogleVertexAiUnifiedChatCompletionResponseHandler.GoogleVertexAiErrorResponse::fromResponse. I preferred to use it that way and avoid putting that in a common class between GoogleVertexAiUnifiedChatCompletionResponseHandler and GoogleVertexAiChatCompletionResponseHandler to avoid extending the class hierarchy, but let me know if you think otherwise

Nice! That looks good.

Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.googlevertexai;

import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiCompletionResponseEntity;

/**
 * Response handler for Google Vertex AI non-streaming completion requests.
 *
 * Delegates all behavior to {@link GoogleVertexAiResponseHandler}, wiring in the
 * completion response entity parser and reusing the error parser declared on
 * {@link GoogleVertexAiUnifiedChatCompletionResponseHandler} so the error format
 * stays consistent between the completion and chat-completion paths without
 * extending the class hierarchy.
 */
public class GoogleVertexAiChatCompletionResponseHandler extends GoogleVertexAiResponseHandler {

    /**
     * @param requestType human-readable request label used in error messages
     */
    public GoogleVertexAiChatCompletionResponseHandler(String requestType) {
        super(
            requestType,
            GoogleVertexAiCompletionResponseEntity::fromResponse,
            GoogleVertexAiUnifiedChatCompletionResponseHandler.GoogleVertexAiErrorResponse::fromResponse,
            true // NOTE(review): presumably enables streaming support on the base handler — confirm against the base class
        );
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,19 @@

package org.elasticsearch.xpack.inference.services.googlevertexai;

import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor;
import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiErrorResponseEntity;

import java.util.concurrent.Flow;
import java.util.function.Function;

import static org.elasticsearch.core.Strings.format;
Expand Down Expand Up @@ -66,4 +71,14 @@ protected void checkForFailureStatusCode(Request request, HttpResult result) thr
// Builds the error message for a 404 from Vertex AI (typically a wrong model/endpoint URL).
private static String resourceNotFoundError(Request request) {
    var uri = request.getURI();
    return format("Resource not found at [%s]", uri);
}

@Override
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
    // Streaming pipeline: raw HTTP chunks -> server-sent events -> parsed completion chunks.
    var sseProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
    var vertexAiProcessor = new GoogleVertexAiStreamingProcessor();

    flow.subscribe(sseProcessor);
    sseProcessor.subscribe(vertexAiProcessor);

    return new StreamingChatCompletionResults(vertexAiProcessor);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,9 @@ public static Map<String, SettingsConfiguration> get() {
var configurationMap = new HashMap<String, SettingsConfiguration>();
configurationMap.put(
SERVICE_ACCOUNT_JSON,
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.CHAT_COMPLETION))
.setDescription("API Key for the provider you're connecting to.")
new SettingsConfiguration.Builder(
EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.CHAT_COMPLETION, TaskType.COMPLETION)
).setDescription("API Key for the provider you're connecting to.")
.setLabel("Credentials JSON")
.setRequired(true)
.setSensitive(true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiCompletionModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUnifiedChatCompletionRequest;
Expand Down Expand Up @@ -75,7 +76,8 @@ public class GoogleVertexAiService extends SenderService {
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
TaskType.TEXT_EMBEDDING,
TaskType.RERANK,
TaskType.CHAT_COMPLETION
TaskType.CHAT_COMPLETION,
TaskType.COMPLETION
);

public static final EnumSet<InputType> VALID_INPUT_TYPE_VALUES = EnumSet.of(
Expand All @@ -93,7 +95,7 @@ public class GoogleVertexAiService extends SenderService {

@Override
public Set<TaskType> supportedStreamingTasks() {
    // Streaming is supported for both the unified (CHAT_COMPLETION) and legacy (COMPLETION) task types.
    // Fix: the scraped diff left both the removed and added return statements in place; only the
    // new-side return (including COMPLETION) is kept.
    return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION);
}

public GoogleVertexAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
Expand Down Expand Up @@ -368,6 +370,16 @@ private static GoogleVertexAiModel createModel(
context
);

case COMPLETION -> new GoogleVertexAiCompletionModel(
inferenceEntityId,
taskType,
NAME,
serviceSettings,
taskSettings,
secretSettings,
context
);

default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
};
}
Expand Down Expand Up @@ -396,10 +408,11 @@ public static InferenceServiceConfiguration get() {

configurationMap.put(
LOCATION,
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.CHAT_COMPLETION)).setDescription(
"Please provide the GCP region where the Vertex AI API(s) is enabled. "
+ "For more information, refer to the {geminiVertexAIDocs}."
)
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.COMPLETION))
.setDescription(
"Please provide the GCP region where the Vertex AI API(s) is enabled. "
+ "For more information, refer to the {geminiVertexAIDocs}."
)
.setLabel("GCP Region")
.setRequired(true)
.setSensitive(false)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.googlevertexai;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;

import java.io.IOException;
import java.util.Deque;
import java.util.Objects;
import java.util.stream.Stream;

import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;

public class GoogleVertexAiStreamingProcessor extends DelegatingProcessor<Deque<ServerSentEvent>, InferenceServiceResults.Result> {

@Override
protected void next(Deque<ServerSentEvent> item) throws Exception {
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
var results = parseEvent(item, GoogleVertexAiStreamingProcessor::parse, parserConfig);

if (results.isEmpty()) {
upstream().request(1);
} else {
downstream().onNext(new StreamingChatCompletionResults.Results(results));
}
}

public static Stream<StreamingChatCompletionResults.Result> parse(XContentParserConfiguration parserConfig, ServerSentEvent event) {
String data = event.data();
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, data)) {
moveToFirstToken(jsonParser);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know this is what OpenAiStreamingProcessor is using but I think we can omit this line and the ensureExpectedToken because the GoogleVertexAiUnifiedStreamingProcessor.GoogleVertexAiChatCompletionChunkParser.parse uses the constructing object parser which should handle this validation.

OpenAiStreamingProcessor doesn't leverage the constructing object parser which is why we need to do validation.

ensureExpectedToken(XContentParser.Token.START_OBJECT, jsonParser.currentToken(), jsonParser);

var chunk = GoogleVertexAiUnifiedStreamingProcessor.GoogleVertexAiChatCompletionChunkParser.parse(jsonParser);

return chunk.choices()
.stream()
.map(choice -> choice.delta())
.filter(Objects::nonNull)
.map(delta -> delta.content())
.filter(content -> content != null && content.isEmpty() == false)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I think we can use Strings.isNullOrEmpty here

.map(StreamingChatCompletionResults.Result::new);

} catch (IOException e) {
throw new ElasticsearchStatusException(
"Failed to parse event from inference provider: {}",
RestStatus.INTERNAL_SERVER_ERROR,
e,
event
);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor;
import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiCompletionResponseEntity;

import java.nio.charset.StandardCharsets;
import java.util.Locale;
Expand All @@ -43,10 +43,8 @@ public class GoogleVertexAiUnifiedChatCompletionResponseHandler extends GoogleVe
private static final String ERROR_MESSAGE_FIELD = "message";
private static final String ERROR_STATUS_FIELD = "status";

private static final ResponseParser noopParseFunction = (a, b) -> null;

/**
 * @param requestType human-readable request label used in error messages
 */
public GoogleVertexAiUnifiedChatCompletionResponseHandler(String requestType) {
    // Fix: the scraped diff retained both the removed super(...) call (using the deleted
    // noopParseFunction) and the added one, plus interleaved review prose; only the new-side
    // call, which parses non-streaming bodies via GoogleVertexAiCompletionResponseEntity, is kept.
    super(requestType, GoogleVertexAiCompletionResponseEntity::fromResponse, GoogleVertexAiErrorResponse::fromResponse, true);
}

@Override
Expand All @@ -64,6 +62,7 @@ public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpR
@Override
protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
assert request.isStreaming() : "Only streaming requests support this format";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's try to keep this for streaming only


var responseStatusCode = result.response().getStatusLine().getStatusCode();
var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode);
var restStatus = toRestStatus(responseStatusCode);
Expand Down Expand Up @@ -111,7 +110,7 @@ private static Exception buildMidStreamError(Request request, String message, Ex
}
}

private static class GoogleVertexAiErrorResponse extends ErrorResponse {
public static class GoogleVertexAiErrorResponse extends ErrorResponse {
private static final Logger logger = LogManager.getLogger(GoogleVertexAiErrorResponse.class);
private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>(
"google_vertex_ai_error_wrapper",
Expand All @@ -138,7 +137,7 @@ private static class GoogleVertexAiErrorResponse extends ErrorResponse {
);
}

static ErrorResponse fromResponse(HttpResult response) {
public static ErrorResponse fromResponse(HttpResult response) {
try (
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
.createParser(XContentParserConfiguration.EMPTY, response.body())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

package org.elasticsearch.xpack.inference.services.googlevertexai.action;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
Expand All @@ -16,14 +18,17 @@
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiEmbeddingsRequestManager;
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiRerankRequestManager;
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiUnifiedChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiCompletionModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUnifiedChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel;

import java.net.URISyntaxException;
import java.util.Map;
import java.util.Objects;

Expand All @@ -36,9 +41,10 @@ public class GoogleVertexAiActionCreator implements GoogleVertexAiActionVisitor

private final ServiceComponents serviceComponents;

static final ResponseHandler COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler(
static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is another area that we need to refactor in the code base. The "unified" code path doesn't flow through the action creator like the other task types. Rerank, text embedding, and completion all originate through doInfer() and typically we call through to here to handle them based on the model class.

Chat completion aka unified flows through doUnifiedCompletionInfer() and doesn't use the action creator. So I think we can remove this.

I think we can also remove GoogleVertexAiCompletionModel and rely on GoogleVertexAiChatCompletionModel. If we receive GoogleVertexAiChatCompletionModel in here it should mean that we're trying to do completion not chat_completion.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! Didn't notice that. I think this is the way to go, it removes a lot of code and it makes it simpler to follow

"Google VertexAI chat completion"
);
static final ResponseHandler CHAT_COMPLETION_HANDLER = new GoogleVertexAiChatCompletionResponseHandler("Google VertexAI completion");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's call this COMPLETION_HANDLER. We use chat/unified interchangeably (we need to clean that up throughout the code base).

static final String USER_ROLE = "user";

public GoogleVertexAiActionCreator(Sender sender, ServiceComponents serviceComponents) {
Expand Down Expand Up @@ -72,11 +78,31 @@ public ExecutableAction create(GoogleVertexAiChatCompletionModel model, Map<Stri
var manager = new GenericRequestManager<>(
serviceComponents.threadPool(),
model,
COMPLETION_HANDLER,
UNIFIED_CHAT_COMPLETION_HANDLER,
inputs -> new GoogleVertexAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model),
ChatCompletionInput.class
);

return new SingleInputSenderExecutableAction(sender, manager, failedToSendRequestErrorMessage, COMPLETION_ERROR_PREFIX);
}

@Override
public ExecutableAction create(GoogleVertexAiCompletionModel model, Map<String, Object> taskSettings) {
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX);

var manager = new GenericRequestManager<>(serviceComponents.threadPool(), model, CHAT_COMPLETION_HANDLER, inputs -> {
try {
model.updateUri(inputs.stream());
} catch (URISyntaxException e) {
throw new ElasticsearchStatusException(
"Error constructing URI for Google VertexAI completion",
RestStatus.INTERNAL_SERVER_ERROR,
e
);
}
return new GoogleVertexAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of calling model.updateUri we can just pass in the inputs and model like openai does. Then in the constructor of GoogleVertexAiUnifiedChatCompletionRequest we can check the inputs.stream() flag to determine if we're streaming and retrieve the appropriate URI.

}, ChatCompletionInput.class);

return new SingleInputSenderExecutableAction(sender, manager, failedToSendRequestErrorMessage, COMPLETION_ERROR_PREFIX);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiCompletionModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel;

Expand All @@ -21,4 +22,6 @@ public interface GoogleVertexAiActionVisitor {
ExecutableAction create(GoogleVertexAiRerankModel model, Map<String, Object> taskSettings);

ExecutableAction create(GoogleVertexAiChatCompletionModel model, Map<String, Object> taskSettings);

ExecutableAction create(GoogleVertexAiCompletionModel model, Map<String, Object> taskSettings);
}
Loading
Loading