Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
af60919
Integrate Google Model Garden providers for processing chat completio…
Jan-Kazlouski-elastic Sep 30, 2025
129f74e
Add changelog
Jan-Kazlouski-elastic Sep 30, 2025
52fbe36
[CI] Auto commit changes from spotless
Sep 30, 2025
f0e382a
Merge remote-tracking branch 'origin/main' into feature/google-model-…
Jan-Kazlouski-elastic Oct 1, 2025
4a2abaf
Merge remote-tracking branch 'origin/main' into feature/google-model-…
Jan-Kazlouski-elastic Oct 2, 2025
05a4c4f
Merge remote-tracking branch 'origin/feature/google-model-garden-open…
Jan-Kazlouski-elastic Oct 2, 2025
8a47467
Merge branch 'main' into feature/google-model-garden-openai-providers…
Jan-Kazlouski-elastic Oct 6, 2025
b75d352
Merge remote-tracking branch 'origin/main' into feature/google-model-…
Jan-Kazlouski-elastic Oct 8, 2025
41224bd
Move model_id null check to UnifiedCompletionRequest, fix Javadoc, fi…
Jan-Kazlouski-elastic Oct 8, 2025
8cd7eb1
Refactor GoogleVertexAiUnifiedChatCompletionActionTests to simplify m…
Jan-Kazlouski-elastic Oct 8, 2025
6c05a61
Refactor GoogleModelGardenProvider to handle response handlers and re…
Jan-Kazlouski-elastic Oct 8, 2025
2d8ced7
Merge branch 'main' into feature/google-model-garden-openai-providers…
Jan-Kazlouski-elastic Oct 8, 2025
ac65a3a
Fix typo in GoogleModelGardenProvider Google Model Garden AI21 chat c…
Jan-Kazlouski-elastic Oct 9, 2025
1824cb8
Merge remote-tracking branch 'origin/feature/google-model-garden-open…
Jan-Kazlouski-elastic Oct 9, 2025
3353b85
Refactor GoogleVertexAiUnifiedChatCompletionActionTests
Jan-Kazlouski-elastic Oct 9, 2025
767f4d3
Refactor GoogleModelGardenProvider
Jan-Kazlouski-elastic Oct 9, 2025
8b25612
Add Nullable annotation
Jan-Kazlouski-elastic Oct 9, 2025
f66fb24
Merge branch 'main' into feature/google-model-garden-openai-providers…
Jan-Kazlouski-elastic Oct 9, 2025
35ec9e8
Merge branch 'main' into feature/google-model-garden-openai-providers…
Jan-Kazlouski-elastic Oct 9, 2025
d874fc9
[CI] Auto commit changes from spotless
Oct 9, 2025
f3a947e
Fix Typo
Jan-Kazlouski-elastic Oct 9, 2025
ace886c
Merge branch 'main' into feature/google-model-garden-openai-providers…
Jan-Kazlouski-elastic Oct 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -89,51 +89,34 @@ public record UnifiedCompletionRequest(

/**
* Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
* - Key: {@link #MODEL_FIELD}, Value: modelId
* - Key: {@link #MODEL_FIELD}, Value: modelId, if modelId is not null
* - Key: {@link #MAX_TOKENS_FIELD}, Value: {@link #maxCompletionTokens()}
*/
public static Params withMaxTokens(String modelId, Params params) {
return new DelegatingMapParams(
Map.ofEntries(Map.entry(MODEL_ID_PARAM, modelId), Map.entry(MAX_TOKENS_PARAM, MAX_TOKENS_FIELD)),
params
);
public static Params withMaxTokens(@Nullable String modelId, Params params) {
Map<String, String> entries = modelId != null
? Map.ofEntries(Map.entry(MODEL_ID_PARAM, modelId), Map.entry(MAX_TOKENS_PARAM, MAX_TOKENS_FIELD))
: Map.ofEntries(Map.entry(MAX_TOKENS_PARAM, MAX_TOKENS_FIELD));
return new DelegatingMapParams(entries, params);
}

/**
* Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
* - Key: {@link #MODEL_FIELD}, Value: modelId, if modelId is not null
* - Key: {@link #MAX_TOKENS_FIELD}, Value: {@link #maxCompletionTokens()}
*/
public static Params withMaxTokens(Params params) {
    // Single override: MAX_TOKENS_PARAM -> MAX_TOKENS_FIELD; params is handed to DelegatingMapParams unchanged.
    Map<String, String> maxTokensOverride = Map.of(MAX_TOKENS_PARAM, MAX_TOKENS_FIELD);
    return new DelegatingMapParams(maxTokensOverride, params);
}

/**
* Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
* - Key: {@link #MODEL_FIELD}, Value: modelId
* - Key: {@link #MAX_TOKENS_FIELD}, Value: {@link #MAX_TOKENS_FIELD}
* - Key: {@link #INCLUDE_STREAM_OPTIONS_PARAM}, Value: "false"
*/
public static Params withMaxTokensAndSkipStreamOptionsField(String modelId, Params params) {
return new DelegatingMapParams(
Map.ofEntries(
public static Params withMaxTokensAndSkipStreamOptionsField(@Nullable String modelId, Params params) {
Map<String, String> entries = modelId != null
? Map.ofEntries(
Map.entry(MODEL_ID_PARAM, modelId),
Map.entry(MAX_TOKENS_PARAM, MAX_TOKENS_FIELD),
Map.entry(INCLUDE_STREAM_OPTIONS_PARAM, Boolean.FALSE.toString())
),
params
);
}

/**
* Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
* - Key: {@link #MAX_TOKENS_FIELD}, Value: {@link #MAX_TOKENS_FIELD}
* - Key: {@link #INCLUDE_STREAM_OPTIONS_PARAM}, Value: "false"
*/
public static Params withMaxTokensAndSkipStreamOptionsField(Params params) {
return new DelegatingMapParams(
Map.ofEntries(Map.entry(MAX_TOKENS_PARAM, MAX_TOKENS_FIELD), Map.entry(INCLUDE_STREAM_OPTIONS_PARAM, Boolean.FALSE.toString())),
params
);
)
: Map.ofEntries(
Map.entry(MAX_TOKENS_PARAM, MAX_TOKENS_FIELD),
Map.entry(INCLUDE_STREAM_OPTIONS_PARAM, Boolean.FALSE.toString())
);
return new DelegatingMapParams(entries, params);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,7 @@ public Ai21ChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, @Nulla
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (modelId != null) {
// Some models require the model ID to be specified in the request body
unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxTokens(modelId, params));
} else {
// Some models do not require the model ID to be specified in the request body
unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxTokens(params));
}
unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxTokens(modelId, params));
builder.endObject();
return builder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,29 @@

package org.elasticsearch.xpack.inference.services.googlevertexai;

import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.ai21.completion.Ai21ChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.ai21.request.Ai21ChatCompletionRequestEntity;
import org.elasticsearch.xpack.inference.services.anthropic.AnthropicChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.anthropic.AnthropicResponseHandler;
import org.elasticsearch.xpack.inference.services.anthropic.response.AnthropicChatCompletionResponseEntity;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionTaskSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.request.completion.GoogleModelGardenAnthropicChatCompletionRequestEntity;
import org.elasticsearch.xpack.inference.services.googlevertexai.request.completion.GoogleVertexAiUnifiedChatCompletionRequestEntity;
import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiCompletionResponseEntity;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequestEntity;
import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.llama.completion.LlamaCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.llama.request.completion.LlamaChatCompletionRequestEntity;
import org.elasticsearch.xpack.inference.services.mistral.MistralUnifiedChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.mistral.request.completion.MistralChatCompletionRequestEntity;
import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;

import java.util.Locale;

/**
Expand All @@ -20,7 +43,120 @@ public enum GoogleModelGardenProvider {
MISTRAL,
AI21;

public static final String NAME = "google_model_garden_provider";
// ---------------------------------------------------------------------------------------------
// Response handlers returned by getCompletionResponseHandler(), one constant per provider.
// The first constructor argument of each handler is a human-readable request-type label.
// ---------------------------------------------------------------------------------------------

// GOOGLE provider.
// NOTE(review): the trailing boolean's meaning is not visible here — confirm against the
// GoogleVertexAiResponseHandler constructor.
private static final ResponseHandler GOOGLE_VERTEX_AI_COMPLETION_HANDLER = new GoogleVertexAiResponseHandler(
"Google Vertex AI completion",
GoogleVertexAiCompletionResponseEntity::fromResponse,
GoogleVertexAiUnifiedChatCompletionResponseHandler.GoogleVertexAiErrorResponse::fromResponse,
true
);

// ANTHROPIC provider.
// NOTE(review): the trailing boolean's meaning is not visible here — confirm against the
// AnthropicResponseHandler constructor.
private static final ResponseHandler GOOGLE_MODEL_GARDEN_ANTHROPIC_COMPLETION_HANDLER = new AnthropicResponseHandler(
"Google Model Garden Anthropic completion",
AnthropicChatCompletionResponseEntity::fromResponse,
true
);

// META provider: Llama handler parsing the body with the OpenAI chat completion response entity.
private static final ResponseHandler GOOGLE_MODEL_GARDEN_META_COMPLETION_HANDLER = new LlamaCompletionResponseHandler(
"Google Model Garden Meta completion",
OpenAiChatCompletionResponseEntity::fromResponse
);

// HUGGING_FACE provider: OpenAI handler, two-argument constructor (no explicit error parser passed).
private static final ResponseHandler GOOGLE_MODEL_GARDEN_HUGGING_FACE_COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler(
"Google Model Garden Hugging Face completion",
OpenAiChatCompletionResponseEntity::fromResponse
);

// MISTRAL provider: OpenAI handler with ErrorResponse::fromResponse passed explicitly as the error parser.
private static final ResponseHandler GOOGLE_MODEL_GARDEN_MISTRAL_COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler(
"Google Model Garden Mistral completion",
OpenAiChatCompletionResponseEntity::fromResponse,
ErrorResponse::fromResponse
);

// AI21 provider: same construction as the Mistral completion handler, different label.
private static final ResponseHandler GOOGLE_MODEL_GARDEN_AI21_COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler(
"Google Model Garden AI21 completion",
OpenAiChatCompletionResponseEntity::fromResponse,
ErrorResponse::fromResponse
);

// ---------------------------------------------------------------------------------------------
// Response handlers returned by getChatCompletionResponseHandler(), one constant per provider.
// NOTE(review): the labels below mix "chat completion" and "chat completions" — confirm whether
// the plural form on the Mistral label is intentional.
// ---------------------------------------------------------------------------------------------

// GOOGLE provider.
private static final ResponseHandler GOOGLE_VERTEX_AI_CHAT_COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler(
"Google Vertex AI chat completion"
);

// ANTHROPIC provider.
private static final ResponseHandler ANTHROPIC_CHAT_COMPLETION_HANDLER = new AnthropicChatCompletionResponseHandler(
"Google Model Garden Anthropic chat completion"
);

// META provider.
private static final ResponseHandler META_CHAT_COMPLETION_HANDLER = new LlamaChatCompletionResponseHandler(
"Google Model Garden Meta chat completion",
OpenAiChatCompletionResponseEntity::fromResponse
);

// HUGGING_FACE provider.
private static final ResponseHandler HUGGING_FACE_CHAT_COMPLETION_HANDLER = new HuggingFaceChatCompletionResponseHandler(
"Google Model Garden Hugging Face chat completion",
OpenAiChatCompletionResponseEntity::fromResponse
);

// MISTRAL provider.
private static final ResponseHandler MISTRAL_CHAT_COMPLETION_HANDLER = new MistralUnifiedChatCompletionResponseHandler(
"Google Model Garden Mistral chat completions",
OpenAiChatCompletionResponseEntity::fromResponse
);

private static final ResponseHandler AI21_CHAT_COMPLETION_HANDLER = new Ai21ChatCompletionResponseHandler(
"Google Model Garden Ai21 chat completions",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor nitpick, only fix this if there are other changes required, but this should be AI21, with an uppercase i

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but this should be AI21, with an uppercase i
Good catch. Confused it with Ai2 which has Ai with lowercase i. Fixed now.

OpenAiChatCompletionResponseEntity::fromResponse
);

/**
* Gets the completion response handler for the model garden provider.
* @return the ResponseHandler associated with the provider
*/
public ResponseHandler getCompletionResponseHandler() {
return switch (this) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's remove the switches. I think we can do this by having a package private constructor. Something like:

@FunctionalInterface
private interface RequestEntityCreator {
  ToXContentObject create(UnifiedChatInput unifiedChatInput,
        String modelId,
        GoogleVertexAiChatCompletionTaskSettings taskSettings);
}

private final ResponseHandler completionResponseHandler;
private final ResponseHandler chatCompletionResponseHandler;
private final RequestEntityCreator entityCreator;

public enum GoogleModelGardenProvider {
    GOOGLE(GOOGLE_VERTEX_AI_COMPLETION_HANDLER, GOOGLE_VERTEX_AI_CHAT_COMPLETION_HANDLER, (unifiedChatInput, modelId, taskSettings) -> new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput, taskSettings.thinkingConfig())),
...
}

GoogleModelGardenProvider(ResponseHandler a, ResponseHandler b, RequestEntityCreator entityCreator) { ... }

Then the methods in here just return/call the appropriate methods and we won't need the switches anymore.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Had to move the handlers into nested classes because you're not allowed to read the value of a field before its definition, and you cannot declare any fields before the actual enum constants.

case GOOGLE -> GOOGLE_VERTEX_AI_COMPLETION_HANDLER;
case ANTHROPIC -> GOOGLE_MODEL_GARDEN_ANTHROPIC_COMPLETION_HANDLER;
case META -> GOOGLE_MODEL_GARDEN_META_COMPLETION_HANDLER;
case HUGGING_FACE -> GOOGLE_MODEL_GARDEN_HUGGING_FACE_COMPLETION_HANDLER;
case MISTRAL -> GOOGLE_MODEL_GARDEN_MISTRAL_COMPLETION_HANDLER;
case AI21 -> GOOGLE_MODEL_GARDEN_AI21_COMPLETION_HANDLER;
};
}

/**
 * Looks up the chat completion response handler that belongs to this model garden provider.
 * Each enum constant maps to exactly one of the static chat completion handler constants.
 *
 * @return the chat completion {@link ResponseHandler} mapped to this provider
 */
public ResponseHandler getChatCompletionResponseHandler() {
    final ResponseHandler chatCompletionHandler = switch (this) {
        case AI21 -> AI21_CHAT_COMPLETION_HANDLER;
        case MISTRAL -> MISTRAL_CHAT_COMPLETION_HANDLER;
        case HUGGING_FACE -> HUGGING_FACE_CHAT_COMPLETION_HANDLER;
        case META -> META_CHAT_COMPLETION_HANDLER;
        case ANTHROPIC -> ANTHROPIC_CHAT_COMPLETION_HANDLER;
        case GOOGLE -> GOOGLE_VERTEX_AI_CHAT_COMPLETION_HANDLER;
    };
    return chatCompletionHandler;
}

/**
 * Builds the provider-specific request entity for a chat completion request.
 * <p>
 * Which arguments are consumed depends on the provider:
 * GOOGLE uses {@code taskSettings.thinkingConfig()}, ANTHROPIC uses {@code taskSettings} as a whole,
 * and META, HUGGING_FACE, MISTRAL and AI21 use {@code modelId}. All providers use {@code unifiedChatInput}.
 *
 * @param unifiedChatInput the unified chat input passed to every provider's request entity
 * @param modelId the model ID; only passed on for META, HUGGING_FACE, MISTRAL and AI21
 * @param taskSettings the Google Vertex AI chat completion task settings; only read for GOOGLE and ANTHROPIC
 * @return a {@link ToXContentObject} holding the request entity for this provider
 */
public ToXContentObject createRequestEntity(
    UnifiedChatInput unifiedChatInput,
    String modelId,
    GoogleVertexAiChatCompletionTaskSettings taskSettings
) {
    final ToXContentObject requestEntity = switch (this) {
        // Providers whose entities are configured from the task settings.
        case GOOGLE -> new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput, taskSettings.thinkingConfig());
        case ANTHROPIC -> new GoogleModelGardenAnthropicChatCompletionRequestEntity(unifiedChatInput, taskSettings);
        // Providers whose entities are configured from the model ID.
        case AI21 -> new Ai21ChatCompletionRequestEntity(unifiedChatInput, modelId);
        case MISTRAL -> new MistralChatCompletionRequestEntity(unifiedChatInput, modelId);
        case HUGGING_FACE -> new HuggingFaceUnifiedChatCompletionRequestEntity(unifiedChatInput, modelId);
        case META -> new LlamaChatCompletionRequestEntity(unifiedChatInput, modelId);
    };
    return requestEntity;
}

public static GoogleModelGardenProvider fromString(String name) {
return valueOf(name.trim().toUpperCase(Locale.ROOT));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
Expand All @@ -43,18 +42,12 @@
import org.elasticsearch.xpack.inference.services.SenderService;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.ai21.completion.Ai21ChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.anthropic.AnthropicChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.request.completion.GoogleVertexAiUnifiedChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.mistral.MistralUnifiedChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

import java.util.EnumSet;
Expand Down Expand Up @@ -97,34 +90,6 @@ public class GoogleVertexAiService extends SenderService implements RerankingInf
InputType.INTERNAL_SEARCH
);

public static final ResponseHandler GOOGLE_VERTEX_AI_CHAT_COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler(
"Google Vertex AI chat completion"
);

private static final ResponseHandler ANTHROPIC_CHAT_COMPLETION_HANDLER = new AnthropicChatCompletionResponseHandler(
"Google Model Garden Anthropic chat completion"
);

private static final ResponseHandler META_CHAT_COMPLETION_HANDLER = new LlamaChatCompletionResponseHandler(
"Google Model Garden Meta chat completion",
OpenAiChatCompletionResponseEntity::fromResponse
);

private static final ResponseHandler HUGGING_FACE_CHAT_COMPLETION_HANDLER = new HuggingFaceChatCompletionResponseHandler(
"Google Model Garden Hugging Face chat completion",
OpenAiChatCompletionResponseEntity::fromResponse
);

private static final ResponseHandler MISTRAL_CHAT_COMPLETION_HANDLER = new MistralUnifiedChatCompletionResponseHandler(
"Google Model Garden Mistral chat completions",
OpenAiChatCompletionResponseEntity::fromResponse
);

private static final ResponseHandler AI21_CHAT_COMPLETION_HANDLER = new Ai21ChatCompletionResponseHandler(
"Google Model Garden Ai21 chat completions",
OpenAiChatCompletionResponseEntity::fromResponse
);

@Override
public Set<TaskType> supportedStreamingTasks() {
return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION);
Expand Down Expand Up @@ -290,51 +255,16 @@ protected void doUnifiedCompletionInfer(
}
}

/**
* Create the request manager based on the provider specified in the model's service settings.
* @param model The GoogleVertexAiChatCompletionModel containing the provider information.
* @return A GenericRequestManager configured with the appropriate response handler.
*/
private GenericRequestManager<UnifiedChatInput> createRequestManager(GoogleVertexAiChatCompletionModel model) {
switch (model.getServiceSettings().provider()) {
case GOOGLE -> {
return createRequestManagerWithHandler(model, GOOGLE_VERTEX_AI_CHAT_COMPLETION_HANDLER);
}
case ANTHROPIC -> {
return createRequestManagerWithHandler(model, ANTHROPIC_CHAT_COMPLETION_HANDLER);
}
case META -> {
return createRequestManagerWithHandler(model, META_CHAT_COMPLETION_HANDLER);
}
case HUGGING_FACE -> {
return createRequestManagerWithHandler(model, HUGGING_FACE_CHAT_COMPLETION_HANDLER);
}
case MISTRAL -> {
return createRequestManagerWithHandler(model, MISTRAL_CHAT_COMPLETION_HANDLER);
}
case AI21 -> {
return createRequestManagerWithHandler(model, AI21_CHAT_COMPLETION_HANDLER);
}
case null, default -> throw new ElasticsearchException(
"Unsupported Google Model Garden provider: " + model.getServiceSettings().provider()
);
}
}

/**
* Helper method to create a GenericRequestManager with a specified response handler.
* @param model The GoogleVertexAiChatCompletionModel to be used for requests.
* @param responseHandler The ResponseHandler to process the responses.
* @return A GenericRequestManager configured with the provided response handler.
*/
private GenericRequestManager<UnifiedChatInput> createRequestManagerWithHandler(
GoogleVertexAiChatCompletionModel model,
ResponseHandler responseHandler
) {
private GenericRequestManager<UnifiedChatInput> createRequestManager(GoogleVertexAiChatCompletionModel model) {
return new GenericRequestManager<>(
getServiceComponents().threadPool(),
model,
responseHandler,
model.getServiceSettings().provider().getChatCompletionResponseHandler(),
unifiedChatInput -> new GoogleVertexAiUnifiedChatCompletionRequest(unifiedChatInput, model),
UnifiedChatInput.class
);
Expand Down
Loading