Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure if I should keep this class or delete it and use the base class instead.

Also, I am using GoogleVertexAiUnifiedChatCompletionResponseHandler.GoogleVertexAiErrorResponse::fromResponse. I preferred to use it that way and avoid putting that in a common class between GoogleVertexAiUnifiedChatCompletionResponseHandler and GoogleVertexAiChatCompletionResponseHandler to avoid extending the class hierarchy, but let me know if you think otherwise

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure if I should keep this class or delete it and use the base class instead.

Up to you, I'd probably use the base class. If you want to keep this one, then I don't think we need to accept the requestType. I think we can set it in this class directly.

Also, I am using GoogleVertexAiUnifiedChatCompletionResponseHandler.GoogleVertexAiErrorResponse::fromResponse. I preferred to use it that way and avoid putting that in a common class between GoogleVertexAiUnifiedChatCompletionResponseHandler and GoogleVertexAiChatCompletionResponseHandler to avoid extending the class hierarchy, but let me know if you think otherwise

Nice! That looks good.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public abstract class GoogleVertexAiModel extends RateLimitGroupingModel {

private final GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings;

protected URI uri;
protected URI nonStreamingUri;

public GoogleVertexAiModel(
ModelConfigurations configurations,
Expand All @@ -39,14 +39,14 @@ public GoogleVertexAiModel(
public GoogleVertexAiModel(GoogleVertexAiModel model, ServiceSettings serviceSettings) {
super(model, serviceSettings);

uri = model.uri();
nonStreamingUri = model.nonStreamingUri();
rateLimitServiceSettings = model.rateLimitServiceSettings();
}

public GoogleVertexAiModel(GoogleVertexAiModel model, TaskSettings taskSettings) {
super(model, taskSettings);

uri = model.uri();
nonStreamingUri = model.nonStreamingUri();
rateLimitServiceSettings = model.rateLimitServiceSettings();
}

Expand All @@ -56,17 +56,8 @@ public GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings() {
return rateLimitServiceSettings;
}

public URI uri() {
return uri;
}

@Override
public int rateLimitGroupingHash() {
// In VertexAI rate limiting is scoped to the project, region and model. URI already has this information so we are using that.
// API Key does not affect the quota
// https://ai.google.dev/gemini-api/docs/rate-limits
// https://cloud.google.com/vertex-ai/docs/quotas
return Objects.hash(uri);
/**
 * Returns the URI used for non-streaming requests against the Google Vertex AI API.
 */
public URI nonStreamingUri() {
    return this.nonStreamingUri;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiCompletionModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUnifiedChatCompletionRequest;
Expand Down Expand Up @@ -89,7 +88,7 @@ public class GoogleVertexAiService extends SenderService {
InputType.INTERNAL_SEARCH
);

private final ResponseHandler COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler(
public static final ResponseHandler COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler(
"Google VertexAI chat completion"
);

Expand Down Expand Up @@ -360,17 +359,7 @@ private static GoogleVertexAiModel createModel(
context
);

case CHAT_COMPLETION -> new GoogleVertexAiChatCompletionModel(
inferenceEntityId,
taskType,
NAME,
serviceSettings,
taskSettings,
secretSettings,
context
);

case COMPLETION -> new GoogleVertexAiCompletionModel(
case CHAT_COMPLETION, COMPLETION -> new GoogleVertexAiChatCompletionModel(
inferenceEntityId,
taskType,
NAME,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.inference.services.googlevertexai;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.rest.RestStatus;
Expand All @@ -24,9 +25,6 @@
import java.util.Objects;
import java.util.stream.Stream;

import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;

public class GoogleVertexAiStreamingProcessor extends DelegatingProcessor<Deque<ServerSentEvent>, InferenceServiceResults.Result> {

@Override
Expand All @@ -44,17 +42,14 @@ protected void next(Deque<ServerSentEvent> item) throws Exception {
public static Stream<StreamingChatCompletionResults.Result> parse(XContentParserConfiguration parserConfig, ServerSentEvent event) {
String data = event.data();
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, data)) {
moveToFirstToken(jsonParser);
ensureExpectedToken(XContentParser.Token.START_OBJECT, jsonParser.currentToken(), jsonParser);

var chunk = GoogleVertexAiUnifiedStreamingProcessor.GoogleVertexAiChatCompletionChunkParser.parse(jsonParser);

return chunk.choices()
.stream()
.map(choice -> choice.delta())
.filter(Objects::nonNull)
.map(delta -> delta.content())
.filter(content -> content != null && content.isEmpty() == false)
.filter(content -> Strings.isNullOrEmpty(content) == false)
.map(StreamingChatCompletionResults.Result::new);

} catch (IOException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

package org.elasticsearch.xpack.inference.services.googlevertexai.action;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
Expand All @@ -18,17 +16,16 @@
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiEmbeddingsRequestManager;
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiRerankRequestManager;
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiResponseHandler;
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiUnifiedChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiCompletionModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUnifiedChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiCompletionResponseEntity;

import java.net.URISyntaxException;
import java.util.Map;
import java.util.Objects;

Expand All @@ -41,10 +38,13 @@ public class GoogleVertexAiActionCreator implements GoogleVertexAiActionVisitor

private final ServiceComponents serviceComponents;

static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler(
"Google VertexAI chat completion"
static final ResponseHandler CHAT_COMPLETION_HANDLER = new GoogleVertexAiResponseHandler(
"Google VertexAI completion",
GoogleVertexAiCompletionResponseEntity::fromResponse,
GoogleVertexAiUnifiedChatCompletionResponseHandler.GoogleVertexAiErrorResponse::fromResponse,
true
);
static final ResponseHandler CHAT_COMPLETION_HANDLER = new GoogleVertexAiChatCompletionResponseHandler("Google VertexAI completion");

static final String USER_ROLE = "user";

public GoogleVertexAiActionCreator(Sender sender, ServiceComponents serviceComponents) {
Expand Down Expand Up @@ -73,36 +73,16 @@ public ExecutableAction create(GoogleVertexAiRerankModel model, Map<String, Obje

/**
 * Builds the executable chat-completion action for the given model.
 * Wires a generic request manager that turns chat inputs into a unified
 * chat completion request and hands it to a single-input sender action.
 */
@Override
public ExecutableAction create(GoogleVertexAiChatCompletionModel model, Map<String, Object> taskSettings) {
    var requestManager = new GenericRequestManager<>(
        serviceComponents.threadPool(),
        model,
        CHAT_COMPLETION_HANDLER,
        inputs -> new GoogleVertexAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model),
        ChatCompletionInput.class
    );

    var errorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX);
    return new SingleInputSenderExecutableAction(sender, requestManager, errorMessage, COMPLETION_ERROR_PREFIX);
}

@Override
public ExecutableAction create(GoogleVertexAiCompletionModel model, Map<String, Object> taskSettings) {
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX);

var manager = new GenericRequestManager<>(serviceComponents.threadPool(), model, CHAT_COMPLETION_HANDLER, inputs -> {
try {
model.updateUri(inputs.stream());
} catch (URISyntaxException e) {
throw new ElasticsearchStatusException(
"Error constructing URI for Google VertexAI completion",
RestStatus.INTERNAL_SERVER_ERROR,
e
);
}
return new GoogleVertexAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model);
}, ChatCompletionInput.class);

return new SingleInputSenderExecutableAction(sender, manager, failedToSendRequestErrorMessage, COMPLETION_ERROR_PREFIX);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiCompletionModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel;

Expand All @@ -23,5 +22,4 @@ public interface GoogleVertexAiActionVisitor {

ExecutableAction create(GoogleVertexAiChatCompletionModel model, Map<String, Object> taskSettings);

ExecutableAction create(GoogleVertexAiCompletionModel model, Map<String, Object> taskSettings);
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
import static org.elasticsearch.core.Strings.format;

public class GoogleVertexAiChatCompletionModel extends GoogleVertexAiModel {

private final URI streamingURI;

public GoogleVertexAiChatCompletionModel(
String inferenceEntityId,
TaskType taskType,
Expand Down Expand Up @@ -63,7 +66,8 @@ public GoogleVertexAiChatCompletionModel(
serviceSettings
);
try {
this.uri = buildUri(serviceSettings.location(), serviceSettings.projectId(), serviceSettings.modelId());
this.streamingURI = buildUriStreaming(serviceSettings.location(), serviceSettings.projectId(), serviceSettings.modelId());
this.nonStreamingUri = buildUriNonStreaming(serviceSettings.location(), serviceSettings.projectId(), serviceSettings.modelId());
} catch (URISyntaxException e) {
throw new RuntimeException(e);
}
Expand Down Expand Up @@ -114,7 +118,28 @@ public GoogleVertexAiSecretSettings getSecretSettings() {
return (GoogleVertexAiSecretSettings) super.getSecretSettings();
}

public static URI buildUri(String location, String projectId, String model) throws URISyntaxException {
/**
 * Returns the URI used for streaming (server-sent events) requests.
 */
public URI streamingURI() {
    return streamingURI;
}

/**
 * Builds the non-streaming {@code generateContent} endpoint URI for the given
 * location, project and model.
 *
 * @param location  Google Cloud region used as the host prefix
 * @param projectId Google Cloud project identifier
 * @param model     Vertex AI model identifier
 * @return the fully assembled endpoint URI
 * @throws URISyntaxException if the assembled URI is invalid
 */
public static URI buildUriNonStreaming(String location, String projectId, String model) throws URISyntaxException {
    // NOTE(review): the host is prefixed with {location} while the path uses the GLOBAL
    // location segment — confirm this matches the intended Vertex AI global endpoint.
    var host = format("%s%s", location, GoogleVertexAiUtils.GOOGLE_VERTEX_AI_HOST_SUFFIX);
    var modelAndAction = format("%s:%s", model, GoogleVertexAiUtils.GENERATE_CONTENT);
    return new URIBuilder().setScheme("https")
        .setHost(host)
        .setPathSegments(
            GoogleVertexAiUtils.V1,
            GoogleVertexAiUtils.PROJECTS,
            projectId,
            GoogleVertexAiUtils.LOCATIONS,
            GoogleVertexAiUtils.GLOBAL,
            GoogleVertexAiUtils.PUBLISHERS,
            GoogleVertexAiUtils.PUBLISHER_GOOGLE,
            GoogleVertexAiUtils.MODELS,
            modelAndAction
        )
        .build();
}

public static URI buildUriStreaming(String location, String projectId, String model) throws URISyntaxException {
return new URIBuilder().setScheme("https")
.setHost(format("%s%s", location, GoogleVertexAiUtils.GOOGLE_VERTEX_AI_HOST_SUFFIX))
.setPathSegments(
Expand All @@ -131,4 +156,25 @@ public static URI buildUri(String location, String projectId, String model) thro
.setCustomQuery(GoogleVertexAiUtils.QUERY_PARAM_ALT_SSE)
.build();
}

/**
 * Groups rate limiting by project, region and model.
 * Vertex AI scopes quotas to project, region, model and endpoint; the API key
 * does not affect the quota.
 * See https://ai.google.dev/gemini-api/docs/rate-limits
 * and https://cloud.google.com/vertex-ai/docs/quotas
 */
@Override
public int rateLimitGroupingHash() {
    // We cannot know ahead of time whether the streaming or the non-streaming endpoint
    // will be used, so we take a conservative approach and count both endpoints
    // against the same rate limit bucket.
    var settings = getServiceSettings();
    return Objects.hash(
        settings.projectId(),
        settings.location(),
        settings.modelId(),
        GoogleVertexAiUtils.GENERATE_CONTENT,
        GoogleVertexAiUtils.STREAM_GENERATE_CONTENT
    );
}
}

This file was deleted.

Loading
Loading