Merged

Changes from 6 commits
22 commits
af60919
Integrate Google Model Garden providers for processing chat completio…
Jan-Kazlouski-elastic Sep 30, 2025
129f74e
Add changelog
Jan-Kazlouski-elastic Sep 30, 2025
52fbe36
[CI] Auto commit changes from spotless
Sep 30, 2025
f0e382a
Merge remote-tracking branch 'origin/main' into feature/google-model-…
Jan-Kazlouski-elastic Oct 1, 2025
4a2abaf
Merge remote-tracking branch 'origin/main' into feature/google-model-…
Jan-Kazlouski-elastic Oct 2, 2025
05a4c4f
Merge remote-tracking branch 'origin/feature/google-model-garden-open…
Jan-Kazlouski-elastic Oct 2, 2025
8a47467
Merge branch 'main' into feature/google-model-garden-openai-providers…
Jan-Kazlouski-elastic Oct 6, 2025
b75d352
Merge remote-tracking branch 'origin/main' into feature/google-model-…
Jan-Kazlouski-elastic Oct 8, 2025
41224bd
Move model_id null check to UnifiedCompletionRequest, fix Javadoc, fi…
Jan-Kazlouski-elastic Oct 8, 2025
8cd7eb1
Refactor GoogleVertexAiUnifiedChatCompletionActionTests to simplify m…
Jan-Kazlouski-elastic Oct 8, 2025
6c05a61
Refactor GoogleModelGardenProvider to handle response handlers and re…
Jan-Kazlouski-elastic Oct 8, 2025
2d8ced7
Merge branch 'main' into feature/google-model-garden-openai-providers…
Jan-Kazlouski-elastic Oct 8, 2025
ac65a3a
Fix typo in GoogleModelGardenProvider Google Model Garden AI21 chat c…
Jan-Kazlouski-elastic Oct 9, 2025
1824cb8
Merge remote-tracking branch 'origin/feature/google-model-garden-open…
Jan-Kazlouski-elastic Oct 9, 2025
3353b85
Refactor GoogleVertexAiUnifiedChatCompletionActionTests
Jan-Kazlouski-elastic Oct 9, 2025
767f4d3
Refactor GoogleModelGardenProvider
Jan-Kazlouski-elastic Oct 9, 2025
8b25612
Add Nullable annotation
Jan-Kazlouski-elastic Oct 9, 2025
f66fb24
Merge branch 'main' into feature/google-model-garden-openai-providers…
Jan-Kazlouski-elastic Oct 9, 2025
35ec9e8
Merge branch 'main' into feature/google-model-garden-openai-providers…
Jan-Kazlouski-elastic Oct 9, 2025
d874fc9
[CI] Auto commit changes from spotless
Oct 9, 2025
f3a947e
Fix Typo
Jan-Kazlouski-elastic Oct 9, 2025
ace886c
Merge branch 'main' into feature/google-model-garden-openai-providers…
Jan-Kazlouski-elastic Oct 10, 2025
5 changes: 5 additions & 0 deletions docs/changelog/135701.yaml
@@ -0,0 +1,5 @@
pr: 135701
summary: Add Google Model Garden's Meta, Mistral, Hugging Face and Ai21 providers support to Inference Plugin
area: Machine Learning
type: enhancement
issues: []
@@ -99,6 +99,14 @@ public static Params withMaxTokens(String modelId, Params params) {
);
}

/**
* Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
* - Key: {@link #MAX_TOKENS_FIELD}, Value: {@link #maxCompletionTokens()}
*/
public static Params withMaxTokens(Params params) {
return new DelegatingMapParams(Map.ofEntries(Map.entry(MAX_TOKENS_PARAM, MAX_TOKENS_FIELD)), params);
}

/**
* Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
* - Key: {@link #MODEL_FIELD}, Value: modelId
@@ -116,6 +124,18 @@ public static Params withMaxTokensAndSkipStreamOptionsField(String modelId, Para
);
}

/**
* Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
* - Key: {@link #MAX_TOKENS_FIELD}, Value: {@link #MAX_TOKENS_FIELD}

Contributor:
This should be Value: {@link #maxCompletionTokens()}. The existing Javadoc on line 113 in this file has a similar mistake.

Contributor Author:
Fixed.

* - Key: {@link #INCLUDE_STREAM_OPTIONS_PARAM}, Value: "false"
*/
public static Params withMaxTokensAndSkipStreamOptionsField(Params params) {
return new DelegatingMapParams(
Map.ofEntries(Map.entry(MAX_TOKENS_PARAM, MAX_TOKENS_FIELD), Map.entry(INCLUDE_STREAM_OPTIONS_PARAM, Boolean.FALSE.toString())),
params
);
}
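
For context, a minimal sketch of how a serializer could consult the params produced by the two new overloads above. This is an illustration only, not code from this PR; it assumes the param and field constants referenced in the Javadoc are visible at the call site, and the helper method itself is hypothetical.

// Hypothetical helper, not part of this PR: shows how the params built above could be consumed.
static void writeOptionalFields(XContentBuilder builder, ToXContent.Params params, int maxCompletionTokens) throws IOException {
    // withMaxTokens(params) maps MAX_TOKENS_PARAM to the field name (MAX_TOKENS_FIELD)
    // under which the max-completion-tokens value should be written.
    String maxTokensFieldName = params.param(MAX_TOKENS_PARAM, null);
    if (maxTokensFieldName != null) {
        builder.field(maxTokensFieldName, maxCompletionTokens);
    }
    // withMaxTokensAndSkipStreamOptionsField(params) additionally sets INCLUDE_STREAM_OPTIONS_PARAM
    // to "false", telling the serializer to omit the stream options object.
    if (params.paramAsBoolean(INCLUDE_STREAM_OPTIONS_PARAM, true) == false) {
        return; // skip stream options entirely
    }
    // ...otherwise the real request entity would emit its stream options here.
}
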

/**
* Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
* - Key: {@link #MODEL_FIELD}, Value: modelId
@@ -43,7 +43,8 @@ public HttpRequest createHttpRequest() {
HttpPost httpPost = new HttpPost(model.uri());

ByteArrayEntity byteEntity = new ByteArrayEntity(
Strings.toString(new Ai21ChatCompletionRequestEntity(chatInput, model.getServiceSettings().modelId()))
.getBytes(StandardCharsets.UTF_8)
);
httpPost.setEntity(byteEntity);

@@ -7,34 +7,39 @@

package org.elasticsearch.xpack.inference.services.ai21.request;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity;
import org.elasticsearch.xpack.inference.services.ai21.completion.Ai21ChatCompletionModel;

import java.io.IOException;
import java.util.Objects;

/**
* Ai21ChatCompletionRequestEntity is responsible for creating the request entity for Ai21 chat completion.
* It implements ToXContentObject to allow serialization to XContent format.
*/
public class Ai21ChatCompletionRequestEntity implements ToXContentObject {

private final String modelId;
private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;

public Ai21ChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, @Nullable String modelId) {
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput);
this.modelId = modelId;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxTokens(model.getServiceSettings().modelId(), params));
if (modelId != null) {
// Some models require the model ID to be specified in the request body
unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxTokens(modelId, params));
} else {
// Some models do not require the model ID to be specified in the request body
unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxTokens(params));
}
builder.endObject();
return builder;
}
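A minimal usage sketch of the entity above, assuming an existing UnifiedChatInput built from a chat completion request. It is not part of this PR: the helper method and the "jamba-large" model id are made up for illustration, and the JSON shapes in the comments are approximate.

// Hypothetical helper illustrating the two serialization paths of Ai21ChatCompletionRequestEntity.
// Strings.toString is the same helper used by Ai21ChatCompletionRequest above.
static String[] serializeBothWays(UnifiedChatInput chatInput) {
    // Non-null model id: the body carries a "model" field,
    // roughly {"messages": [...], "model": "jamba-large", ...}
    String withModelId = Strings.toString(new Ai21ChatCompletionRequestEntity(chatInput, "jamba-large"));

    // Null model id: the "model" field is omitted from the body,
    // roughly {"messages": [...], ...}
    String withoutModelId = Strings.toString(new Ai21ChatCompletionRequestEntity(chatInput, null));

    return new String[] { withModelId, withoutModelId };
}
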
@@ -14,7 +14,11 @@
*/
public enum GoogleModelGardenProvider {
GOOGLE,
ANTHROPIC,
META,
HUGGING_FACE,
MISTRAL,
AI21;

public static final String NAME = "google_model_garden_provider";

@@ -43,13 +43,18 @@
import org.elasticsearch.xpack.inference.services.SenderService;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.ai21.completion.Ai21ChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.anthropic.AnthropicChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.request.completion.GoogleVertexAiUnifiedChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.mistral.MistralUnifiedChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

import java.util.EnumSet;
@@ -96,10 +101,30 @@ public class GoogleVertexAiService extends SenderService implements RerankingInf
"Google Vertex AI chat completion"
);

private static final ResponseHandler ANTHROPIC_CHAT_COMPLETION_HANDLER = new AnthropicChatCompletionResponseHandler(
"Google Model Garden Anthropic chat completion"
);

private static final ResponseHandler META_CHAT_COMPLETION_HANDLER = new LlamaChatCompletionResponseHandler(
"Google Model Garden Meta chat completion",
OpenAiChatCompletionResponseEntity::fromResponse
);

private static final ResponseHandler HUGGING_FACE_CHAT_COMPLETION_HANDLER = new HuggingFaceChatCompletionResponseHandler(
"Google Model Garden Hugging Face chat completion",
OpenAiChatCompletionResponseEntity::fromResponse
);

private static final ResponseHandler MISTRAL_CHAT_COMPLETION_HANDLER = new MistralUnifiedChatCompletionResponseHandler(
"Google Model Garden Mistral chat completions",
OpenAiChatCompletionResponseEntity::fromResponse
);

private static final ResponseHandler AI21_CHAT_COMPLETION_HANDLER = new Ai21ChatCompletionResponseHandler(
"Google Model Garden Ai21 chat completions",
OpenAiChatCompletionResponseEntity::fromResponse
);

@Override
public Set<TaskType> supportedStreamingTasks() {
return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION);
@@ -265,20 +290,43 @@ protected void doUnifiedCompletionInfer(
}
}

/**
* Create the request manager based on the provider specified in the model's service settings.
* @param model The GoogleVertexAiChatCompletionModel containing the provider information.
* @return A GenericRequestManager configured with the appropriate response handler.
*/
private GenericRequestManager<UnifiedChatInput> createRequestManager(GoogleVertexAiChatCompletionModel model) {
switch (model.getServiceSettings().provider()) {
case GOOGLE -> {
return createRequestManagerWithHandler(model, GOOGLE_VERTEX_AI_CHAT_COMPLETION_HANDLER);
}
case ANTHROPIC -> {
return createRequestManagerWithHandler(model, ANTHROPIC_CHAT_COMPLETION_HANDLER);
}
case META -> {
return createRequestManagerWithHandler(model, META_CHAT_COMPLETION_HANDLER);
}
case HUGGING_FACE -> {
return createRequestManagerWithHandler(model, HUGGING_FACE_CHAT_COMPLETION_HANDLER);
}
case MISTRAL -> {
return createRequestManagerWithHandler(model, MISTRAL_CHAT_COMPLETION_HANDLER);
}
case AI21 -> {
return createRequestManagerWithHandler(model, AI21_CHAT_COMPLETION_HANDLER);
}
case null, default -> throw new ElasticsearchException(
"Unsupported Google Model Garden provider: " + model.getServiceSettings().provider()
);
}
}

/**
* Helper method to create a GenericRequestManager with a specified response handler.
* @param model The GoogleVertexAiChatCompletionModel to be used for requests.
* @param responseHandler The ResponseHandler to process the responses.
* @return A GenericRequestManager configured with the provided response handler.
*/
private GenericRequestManager<UnifiedChatInput> createRequestManagerWithHandler(
GoogleVertexAiChatCompletionModel model,
ResponseHandler responseHandler
@@ -11,6 +11,7 @@
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
@@ -28,6 +29,9 @@
import org.elasticsearch.xpack.inference.services.googlevertexai.request.completion.GoogleVertexAiUnifiedChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiCompletionResponseEntity;
import org.elasticsearch.xpack.inference.services.llama.completion.LlamaCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;

import java.util.Map;
import java.util.Objects;
@@ -54,6 +58,28 @@ public class GoogleVertexAiActionCreator implements GoogleVertexAiActionVisitor
true
);

static final ResponseHandler GOOGLE_MODEL_GARDEN_META_COMPLETION_HANDLER = new LlamaCompletionResponseHandler(
"Google Model Garden Meta completion",
OpenAiChatCompletionResponseEntity::fromResponse
);

static final ResponseHandler GOOGLE_MODEL_GARDEN_HUGGING_FACE_COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler(
"Google Model Garden Hugging Face completion",
OpenAiChatCompletionResponseEntity::fromResponse
);

static final ResponseHandler GOOGLE_MODEL_GARDEN_MISTRAL_COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler(
"Google Model Garden Mistral completion",
OpenAiChatCompletionResponseEntity::fromResponse,
ErrorResponse::fromResponse
);

static final ResponseHandler GOOGLE_MODEL_GARDEN_AI21_COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler(
"Google Model Garden AI21 completion",
OpenAiChatCompletionResponseEntity::fromResponse,
ErrorResponse::fromResponse
);

static final String USER_ROLE = "user";

public GoogleVertexAiActionCreator(Sender sender, ServiceComponents serviceComponents) {
@@ -91,11 +117,23 @@ public ExecutableAction create(GoogleVertexAiChatCompletionModel model, Map<Stri

private GenericRequestManager<ChatCompletionInput> createRequestManager(GoogleVertexAiChatCompletionModel model) {
switch (model.getServiceSettings().provider()) {
case GOOGLE -> {
Contributor:
How about we try to reduce these switch statements to a map? I'm thinking we could create a Provider class or something that has a static map of provider enums to an internal class containing functions to construct the three different things we need these switch cases for.

Ideally we wouldn't need to pass an instance of the Provider class around. If possible, we could do something like Provider.createCompletionRequestManager(provider, model).

That would grab the internal class and then call createRequestManagerWithHandler(model, handler) with the static handler that is appropriate for that provider.

Provider would have a function for each of the switch statements that we need.

Contributor Author (Jan-Kazlouski-elastic, Oct 8, 2025):
Thanks. I had ideas about moving this logic elsewhere but initially decided not to.
Now I've come up with the idea of moving the logic of getting response handlers and creating request entities into the enum itself.
Of course, I'm not a big fan of moving too much logic into enums, but having a Provider class when we already have the GoogleModelGardenProvider enum alongside it is not ideal either, IMHO. And that way, when a new provider is added, we can ideally localize the changes to a single file. To me it looks good. Please let me know what you think of it @jonathan-buttner

return createRequestManagerWithHandler(model, GOOGLE_VERTEX_AI_COMPLETION_HANDLER);
}
case ANTHROPIC -> {
return createRequestManagerWithHandler(model, GOOGLE_MODEL_GARDEN_ANTHROPIC_COMPLETION_HANDLER);
}
case META -> {
return createRequestManagerWithHandler(model, GOOGLE_MODEL_GARDEN_META_COMPLETION_HANDLER);
}
case HUGGING_FACE -> {
return createRequestManagerWithHandler(model, GOOGLE_MODEL_GARDEN_HUGGING_FACE_COMPLETION_HANDLER);
}
case MISTRAL -> {
return createRequestManagerWithHandler(model, GOOGLE_MODEL_GARDEN_MISTRAL_COMPLETION_HANDLER);
}
case AI21 -> {
return createRequestManagerWithHandler(model, GOOGLE_MODEL_GARDEN_AI21_COMPLETION_HANDLER);
}
case null, default -> throw new ElasticsearchException(
"Unsupported Google Model Garden provider: " + model.getServiceSettings().provider()
@@ -104,14 +142,14 @@ private GenericRequestManager<ChatCompletionInput> createRequestManager(GoogleVe
}

private GenericRequestManager<ChatCompletionInput> createRequestManagerWithHandler(
GoogleVertexAiChatCompletionModel model,
ResponseHandler responseHandler
) {
return new GenericRequestManager<>(
serviceComponents.threadPool(),
model,
responseHandler,
inputs -> new GoogleVertexAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model),
ChatCompletionInput.class
);
}
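Following up on the review discussion earlier in this file about collapsing the per-provider switch: a minimal sketch of the reviewer's map-based idea, wired to the handler constants this class already defines. This is illustrative only, not the merged code; the commit titled "Refactor GoogleModelGardenProvider to handle response handlers and re…" suggests the dispatch was later moved into the enum itself, as the author proposed. The java.util.Map import is assumed.

// Illustrative sketch only (reviewer's suggestion, not the merged implementation).
private static final Map<GoogleModelGardenProvider, ResponseHandler> COMPLETION_HANDLERS = Map.of(
    GoogleModelGardenProvider.GOOGLE, GOOGLE_VERTEX_AI_COMPLETION_HANDLER,
    GoogleModelGardenProvider.ANTHROPIC, GOOGLE_MODEL_GARDEN_ANTHROPIC_COMPLETION_HANDLER,
    GoogleModelGardenProvider.META, GOOGLE_MODEL_GARDEN_META_COMPLETION_HANDLER,
    GoogleModelGardenProvider.HUGGING_FACE, GOOGLE_MODEL_GARDEN_HUGGING_FACE_COMPLETION_HANDLER,
    GoogleModelGardenProvider.MISTRAL, GOOGLE_MODEL_GARDEN_MISTRAL_COMPLETION_HANDLER,
    GoogleModelGardenProvider.AI21, GOOGLE_MODEL_GARDEN_AI21_COMPLETION_HANDLER
);

private GenericRequestManager<ChatCompletionInput> createRequestManager(GoogleVertexAiChatCompletionModel model) {
    var provider = model.getServiceSettings().provider();
    // Map.of maps do not accept null keys, so handle a missing provider before the lookup.
    var handler = provider == null ? null : COMPLETION_HANDLERS.get(provider);
    if (handler == null) {
        throw new ElasticsearchException("Unsupported Google Model Garden provider: " + provider);
    }
    return createRequestManagerWithHandler(model, handler);
}
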
@@ -17,8 +17,12 @@
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.ai21.request.Ai21ChatCompletionRequestEntity;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiRequest;
import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequestEntity;
import org.elasticsearch.xpack.inference.services.llama.request.completion.LlamaChatCompletionRequestEntity;
import org.elasticsearch.xpack.inference.services.mistral.request.completion.MistralChatCompletionRequestEntity;

import java.net.URI;
import java.nio.charset.StandardCharsets;
@@ -53,19 +57,41 @@ public HttpRequest createHttpRequest() {
}

private ToXContentObject createRequestEntity() {
final var modelId = extractModelId();
switch (model.getServiceSettings().provider()) {
case GOOGLE -> {
return new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getTaskSettings().thinkingConfig());
}
case ANTHROPIC -> {
return new GoogleModelGardenAnthropicChatCompletionRequestEntity(unifiedChatInput, model.getTaskSettings());
}
case META -> {
return new LlamaChatCompletionRequestEntity(unifiedChatInput, modelId);
}
case HUGGING_FACE -> {
return new HuggingFaceUnifiedChatCompletionRequestEntity(unifiedChatInput, modelId);
}
case MISTRAL -> {
return new MistralChatCompletionRequestEntity(unifiedChatInput, modelId);
}
case AI21 -> {
return new Ai21ChatCompletionRequestEntity(unifiedChatInput, modelId);
}
case null, default -> throw new ElasticsearchException(
"Unsupported Google Model Garden provider: " + model.getServiceSettings().provider()
);
}
}

/**
* Extracts the model ID to be used for the request. If the request contains a model ID, it is preferred.
* Otherwise, the model ID from the configuration is used.
* @return the model ID to be used for the request
*/
private String extractModelId() {
return unifiedChatInput.getRequest().model() != null ? unifiedChatInput.getRequest().model() : model.getServiceSettings().modelId();
}

public void decorateWithAuth(HttpPost httpPost) {
GoogleVertexAiRequest.decorateWithBearerToken(httpPost, model.getSecretSettings());
}
@@ -50,7 +50,8 @@ public HttpRequest createHttpRequest() {
HttpPost httpPost = new HttpPost(getURI());

ByteArrayEntity byteEntity = new ByteArrayEntity(
Strings.toString(new HuggingFaceUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId()))
.getBytes(StandardCharsets.UTF_8)
);
httpPost.setEntity(byteEntity);
