Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure if I should keep this class or delete it and use the base class instead.

Also, I am using GoogleVertexAiUnifiedChatCompletionResponseHandler.GoogleVertexAiErrorResponse::fromResponse. I preferred it this way, rather than moving that method into a class shared by GoogleVertexAiUnifiedChatCompletionResponseHandler and GoogleVertexAiChatCompletionResponseHandler, to avoid extending the class hierarchy — but let me know if you think otherwise.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure if I should keep this class or delete it and use the base class instead.

Up to you — I'd probably use the base class. If you want to keep this one, I don't think we need to accept the requestType; we can set it in this class directly.

Also, I am using GoogleVertexAiUnifiedChatCompletionResponseHandler.GoogleVertexAiErrorResponse::fromResponse. I preferred to use it that way and avoid putting that in a common class between GoogleVertexAiUnifiedChatCompletionResponseHandler and GoogleVertexAiChatCompletionResponseHandler to avoid extending the class hierarchy, but let me know if you think otherwise

Nice! That looks good.

Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.googlevertexai;

import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiCompletionResponseEntity;

/**
 * Response handler for Google Vertex AI completion (non-unified) requests.
 *
 * Delegates response parsing to {@code GoogleVertexAiCompletionResponseEntity} and reuses the
 * error parsing from
 * {@code GoogleVertexAiUnifiedChatCompletionResponseHandler.GoogleVertexAiErrorResponse} so the
 * error format stays consistent with the unified chat-completion path without introducing a
 * shared intermediate class.
 */
public class GoogleVertexAiChatCompletionResponseHandler extends GoogleVertexAiResponseHandler {

    /**
     * @param requestType a human-readable label for this request type, used in error messages
     */
    public GoogleVertexAiChatCompletionResponseHandler(String requestType) {
        super(
            requestType,
            GoogleVertexAiCompletionResponseEntity::fromResponse,
            GoogleVertexAiUnifiedChatCompletionResponseHandler.GoogleVertexAiErrorResponse::fromResponse,
            // NOTE(review): presumably flags this handler as supporting streaming responses —
            // TODO confirm against the GoogleVertexAiResponseHandler constructor.
            true
        );
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,19 @@

package org.elasticsearch.xpack.inference.services.googlevertexai;

import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor;
import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiErrorResponseEntity;

import java.util.concurrent.Flow;
import java.util.function.Function;

import static org.elasticsearch.core.Strings.format;
Expand Down Expand Up @@ -66,4 +71,14 @@ protected void checkForFailureStatusCode(Request request, HttpResult result) thr
// Builds the message reported when the remote endpoint returns a not-found status for the request URI.
private static String resourceNotFoundError(Request request) {
    return "Resource not found at [" + request.getURI() + "]";
}

/**
 * Builds the streaming result for a chat-style response.
 *
 * Wires a processing pipeline: raw HTTP chunks from {@code flow} are first split into
 * server-sent events, which are then parsed into chat-completion chunks by the
 * Google Vertex AI streaming processor. The subscription order below is what connects
 * the stages, so it must not be reordered.
 *
 * @param request the originating request (unused here, required by the interface)
 * @param flow    publisher of raw HTTP result chunks from the provider
 * @return streaming chat completion results backed by the processor chain
 */
@Override
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
    var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
    var googleVertexAiProcessor = new GoogleVertexAiStreamingProcessor();

    // Stage 1: HTTP chunks -> server-sent events. Stage 2: events -> completion chunks.
    flow.subscribe(serverSentEventProcessor);
    serverSentEventProcessor.subscribe(googleVertexAiProcessor);
    return new StreamingChatCompletionResults(googleVertexAiProcessor);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ public class GoogleVertexAiService extends SenderService {

/**
 * Task types for which this service supports streaming responses.
 */
@Override
public Set<TaskType> supportedStreamingTasks() {
    return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION);
}

public GoogleVertexAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
Expand Down Expand Up @@ -222,11 +222,11 @@ protected void doInfer(
return;
}

var completionModel = (GoogleVertexAiCompletionModel) model;
GoogleVertexAiModel googleVertexAiModel = (GoogleVertexAiModel) model;

var actionCreator = new GoogleVertexAiActionCreator(getSender(), getServiceComponents());

var action = completionModel.accept(actionCreator, taskSettings);
var action = googleVertexAiModel.accept(actionCreator, taskSettings);
action.execute(inputs, timeout, listener);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.googlevertexai;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;

import java.io.IOException;
import java.util.Deque;
import java.util.Objects;
import java.util.stream.Stream;

import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;

public class GoogleVertexAiStreamingProcessor extends DelegatingProcessor<Deque<ServerSentEvent>, InferenceServiceResults.Result> {

/**
 * Handles one batch of server-sent events: parses them into streaming results and either
 * emits them downstream or, when the batch produced nothing usable, requests more input.
 */
@Override
protected void next(Deque<ServerSentEvent> events) throws Exception {
    var config = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
    var parsed = parseEvent(events, GoogleVertexAiStreamingProcessor::parse, config);

    if (parsed.isEmpty() == false) {
        downstream().onNext(new StreamingChatCompletionResults.Results(parsed));
    } else {
        // Nothing usable in this batch; pull the next one from upstream.
        upstream().request(1);
    }
}

public static Stream<StreamingChatCompletionResults.Result> parse(XContentParserConfiguration parserConfig, ServerSentEvent event) {
String data = event.data();
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, data)) {
moveToFirstToken(jsonParser);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know this is what OpenAiStreamingProcessor is using but I think we can omit this line and the ensureExpectedToken because the GoogleVertexAiUnifiedStreamingProcessor.GoogleVertexAiChatCompletionChunkParser.parse uses the constructing object parser which should handle this validation.

OpenAiStreamingProcessor doesn't leverage the constructing object parser which is why we need to do validation.

ensureExpectedToken(XContentParser.Token.START_OBJECT, jsonParser.currentToken(), jsonParser);

var chunk = GoogleVertexAiUnifiedStreamingProcessor.GoogleVertexAiChatCompletionChunkParser.parse(jsonParser);

return chunk.choices()
.stream()
.map(choice -> choice.delta())
.filter(Objects::nonNull)
.map(delta -> delta.content())
.filter(content -> content != null && content.isEmpty() == false)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I think we can use Strings.isNullOrEmpty here

.map(StreamingChatCompletionResults.Result::new);

} catch (IOException e) {
throw new ElasticsearchStatusException(
"Failed to parse event from inference provider: {}",
RestStatus.INTERNAL_SERVER_ERROR,
e,
event
);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpR

@Override
protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
assert request.isStreaming() : "Only streaming requests support this format";

var responseStatusCode = result.response().getStatusLine().getStatusCode();
var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode);
var restStatus = toRestStatus(responseStatusCode);
Expand Down Expand Up @@ -108,7 +110,7 @@ private static Exception buildMidStreamError(Request request, String message, Ex
}
}

private static class GoogleVertexAiErrorResponse extends ErrorResponse {
public static class GoogleVertexAiErrorResponse extends ErrorResponse {
private static final Logger logger = LogManager.getLogger(GoogleVertexAiErrorResponse.class);
private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>(
"google_vertex_ai_error_wrapper",
Expand All @@ -135,7 +137,7 @@ private static class GoogleVertexAiErrorResponse extends ErrorResponse {
);
}

static ErrorResponse fromResponse(HttpResult response) {
public static ErrorResponse fromResponse(HttpResult response) {
try (
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
.createParser(XContentParserConfiguration.EMPTY, response.body())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

package org.elasticsearch.xpack.inference.services.googlevertexai.action;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
Expand All @@ -16,14 +18,17 @@
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiEmbeddingsRequestManager;
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiRerankRequestManager;
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiUnifiedChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiCompletionModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUnifiedChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel;

import java.net.URISyntaxException;
import java.util.Map;
import java.util.Objects;

Expand All @@ -36,9 +41,10 @@ public class GoogleVertexAiActionCreator implements GoogleVertexAiActionVisitor

private final ServiceComponents serviceComponents;

static final ResponseHandler COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler(
static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is another area that we need to refactor in the code base. The "unified" code path doesn't flow through the action creator like the other task types. Rerank, text embedding, and completion all originate through doInfer() and typically we call through to here to handle them based on the model class.

Chat completion aka unified flows through doUnifiedCompletionInfer() and doesn't use the action creator. So I think we can remove this.

I think we can also remove GoogleVertexAiCompletionModel and rely on GoogleVertexAiChatCompletionModel. If we receive GoogleVertexAiChatCompletionModel in here it should mean that we're trying to do completion not chat_completion.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! Didn't notice that. I think this is the way to go, it removes a lot of code and it makes it simpler to follow

"Google VertexAI chat completion"
);
static final ResponseHandler CHAT_COMPLETION_HANDLER = new GoogleVertexAiChatCompletionResponseHandler("Google VertexAI completion");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's call this COMPLETION_HANDLER. We use chat/unified interchangeably (we need to clean that up throughout the code base).

static final String USER_ROLE = "user";

public GoogleVertexAiActionCreator(Sender sender, ServiceComponents serviceComponents) {
Expand Down Expand Up @@ -72,11 +78,31 @@ public ExecutableAction create(GoogleVertexAiChatCompletionModel model, Map<Stri
var manager = new GenericRequestManager<>(
serviceComponents.threadPool(),
model,
COMPLETION_HANDLER,
UNIFIED_CHAT_COMPLETION_HANDLER,
inputs -> new GoogleVertexAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model),
ChatCompletionInput.class
);

return new SingleInputSenderExecutableAction(sender, manager, failedToSendRequestErrorMessage, COMPLETION_ERROR_PREFIX);
}

@Override
public ExecutableAction create(GoogleVertexAiCompletionModel model, Map<String, Object> taskSettings) {
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX);

var manager = new GenericRequestManager<>(serviceComponents.threadPool(), model, CHAT_COMPLETION_HANDLER, inputs -> {
try {
model.updateUri(inputs.stream());
} catch (URISyntaxException e) {
throw new ElasticsearchStatusException(
"Error constructing URI for Google VertexAI completion",
RestStatus.INTERNAL_SERVER_ERROR,
e
);
}
return new GoogleVertexAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of calling model.updateUri we can just pass in the inputs and model like openai does. Then in the constructor of GoogleVertexAiUnifiedChatCompletionRequest we can check the inputs.stream() flag to determine if we're streaming and retrieve the appropriate URI.

}, ChatCompletionInput.class);

return new SingleInputSenderExecutableAction(sender, manager, failedToSendRequestErrorMessage, COMPLETION_ERROR_PREFIX);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiCompletionModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel;

Expand All @@ -21,4 +22,6 @@ public interface GoogleVertexAiActionVisitor {
ExecutableAction create(GoogleVertexAiRerankModel model, Map<String, Object> taskSettings);

ExecutableAction create(GoogleVertexAiChatCompletionModel model, Map<String, Object> taskSettings);

ExecutableAction create(GoogleVertexAiCompletionModel model, Map<String, Object> taskSettings);
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

import org.apache.http.client.utils.URIBuilder;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionVisitor;
import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUtils;

import java.net.URI;
Expand Down Expand Up @@ -39,6 +41,25 @@ public GoogleVertexAiCompletionModel(

}

public void updateUri(boolean isStream) throws URISyntaxException {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of making the uri mutable here, how about we expose two methods, one to retrieve the streaming URI and one to retrieve the non-streaming URI? We'll need to change how the base model (GoogleVertexAiModel) handles the URI and the rate limiting.

Does google allow streaming text embeddings?

For the changes to GoogleVertexAiModel how about we rename uri and getUri to nonStreamingUri and getNonStreamingUri.

Then in GoogleVertexAiChatCompletionModel we can expose a new method that returns the streaming URI.

For the rate limiting maybe we could make it an abstract method in the GoogleVertexAiModel and have the child classes handle it. I think it'll be nearly the same for all of them except maybe rerank which I think doesn't use location. That way we don't need to rely on the URI for the rate limiting but the specific pieces (except streaming vs non-streaming) that would build the URI.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it makes sense.

Does google allow streaming text embeddings?

I think so but it's not as straightforward as with the generate content API. The streaming predict API https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.publishers.models/serverStreamingPredict?hl=en returns a different object than the predict API

For the rate limiting maybe we could make it an abstract method in the GoogleVertexAiModel and have the child classes handle it

Sounds good!

var location = getServiceSettings().location();
var projectId = getServiceSettings().projectId();
var model = getServiceSettings().modelId();

// Google VertexAI generates streaming response using another API. We call this
// method before making the request to be sure we are calling the right API
if (isStream) {
this.uri = GoogleVertexAiChatCompletionModel.buildUri(location, projectId, model);
} else {
this.uri = GoogleVertexAiCompletionModel.buildUri(location, projectId, model);
}
}

/**
 * Visitor entry point: dispatches to the visitor's overload for this completion model,
 * which builds the executable action for it.
 */
@Override
public ExecutableAction accept(GoogleVertexAiActionVisitor visitor, Map<String, Object> taskSettings) {
    return visitor.create(this, taskSettings);
}

public static URI buildUri(String location, String projectId, String model) throws URISyntaxException {
return new URIBuilder().setScheme("https")
.setHost(format("%s%s", location, GoogleVertexAiUtils.GOOGLE_VERTEX_AI_HOST_SUFFIX))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import java.io.IOException;
import java.nio.charset.StandardCharsets;

import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;

public class GoogleVertexAiCompletionResponseEntity {
/**
* Parses the response from Google Vertex AI's generateContent endpoint
Expand Down Expand Up @@ -91,7 +93,7 @@ public static InferenceServiceResults fromResponse(Request request, HttpResult r
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
.createParser(XContentParserConfiguration.EMPTY, responseJson)
) {
parser.nextToken();
moveToFirstToken(parser);
chunk = GoogleVertexAiUnifiedStreamingProcessor.GoogleVertexAiChatCompletionChunkParser.parse(parser);
}
var results = chunk.choices().stream().map(choice -> choice.delta().content()).map(ChatCompletionResults.Result::new).toList();
Expand Down
Loading
Loading