Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
af60919
Integrate Google Model Garden providers for processing chat completio…
Jan-Kazlouski-elastic Sep 30, 2025
129f74e
Add changelog
Jan-Kazlouski-elastic Sep 30, 2025
52fbe36
[CI] Auto commit changes from spotless
Sep 30, 2025
f0e382a
Merge remote-tracking branch 'origin/main' into feature/google-model-…
Jan-Kazlouski-elastic Oct 1, 2025
4a2abaf
Merge remote-tracking branch 'origin/main' into feature/google-model-…
Jan-Kazlouski-elastic Oct 2, 2025
05a4c4f
Merge remote-tracking branch 'origin/feature/google-model-garden-open…
Jan-Kazlouski-elastic Oct 2, 2025
8a47467
Merge branch 'main' into feature/google-model-garden-openai-providers…
Jan-Kazlouski-elastic Oct 6, 2025
b75d352
Merge remote-tracking branch 'origin/main' into feature/google-model-…
Jan-Kazlouski-elastic Oct 8, 2025
41224bd
Move model_id null check to UnifiedCompletionRequest, fix Javadoc, fi…
Jan-Kazlouski-elastic Oct 8, 2025
8cd7eb1
Refactor GoogleVertexAiUnifiedChatCompletionActionTests to simplify m…
Jan-Kazlouski-elastic Oct 8, 2025
6c05a61
Refactor GoogleModelGardenProvider to handle response handlers and re…
Jan-Kazlouski-elastic Oct 8, 2025
2d8ced7
Merge branch 'main' into feature/google-model-garden-openai-providers…
Jan-Kazlouski-elastic Oct 8, 2025
ac65a3a
Fix typo in GoogleModelGardenProvider Google Model Garden AI21 chat c…
Jan-Kazlouski-elastic Oct 9, 2025
1824cb8
Merge remote-tracking branch 'origin/feature/google-model-garden-open…
Jan-Kazlouski-elastic Oct 9, 2025
3353b85
Refactor GoogleVertexAiUnifiedChatCompletionActionTests
Jan-Kazlouski-elastic Oct 9, 2025
767f4d3
Refactor GoogleModelGardenProvider
Jan-Kazlouski-elastic Oct 9, 2025
8b25612
Add Nullable annotation
Jan-Kazlouski-elastic Oct 9, 2025
f66fb24
Merge branch 'main' into feature/google-model-garden-openai-providers…
Jan-Kazlouski-elastic Oct 9, 2025
35ec9e8
Merge branch 'main' into feature/google-model-garden-openai-providers…
Jan-Kazlouski-elastic Oct 9, 2025
d874fc9
[CI] Auto commit changes from spotless
Oct 9, 2025
f3a947e
Fix Typo
Jan-Kazlouski-elastic Oct 9, 2025
ace886c
Merge branch 'main' into feature/google-model-garden-openai-providers…
Jan-Kazlouski-elastic Oct 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/135701.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 135701
summary: Add support for Google Model Garden's Meta, Mistral, Hugging Face and AI21 providers to the Inference Plugin
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -89,31 +89,34 @@ public record UnifiedCompletionRequest(

/**
* Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
* - Key: {@link #MODEL_FIELD}, Value: modelId
* - Key: {@link #MODEL_FIELD}, Value: modelId, if modelId is not null
* - Key: {@link #MAX_TOKENS_FIELD}, Value: {@link #maxCompletionTokens()}
*/
public static Params withMaxTokens(String modelId, Params params) {
return new DelegatingMapParams(
Map.ofEntries(Map.entry(MODEL_ID_PARAM, modelId), Map.entry(MAX_TOKENS_PARAM, MAX_TOKENS_FIELD)),
params
);
public static Params withMaxTokens(@Nullable String modelId, Params params) {
    // MAX_TOKENS_PARAM is always set. MODEL_ID_PARAM is only added when a model id was
    // supplied, because the immutable Map factories reject null values.
    final Map<String, String> overrides;
    if (modelId == null) {
        overrides = Map.of(MAX_TOKENS_PARAM, MAX_TOKENS_FIELD);
    } else {
        overrides = Map.of(MODEL_ID_PARAM, modelId, MAX_TOKENS_PARAM, MAX_TOKENS_FIELD);
    }
    // The overrides are layered on top of the caller-supplied params.
    return new DelegatingMapParams(overrides, params);
}

/**
* Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
* - Key: {@link #MODEL_FIELD}, Value: modelId
* - Key: {@link #MAX_TOKENS_FIELD}, Value: {@link #MAX_TOKENS_FIELD}
* - Key: {@link #MODEL_FIELD}, Value: modelId, if modelId is not null
* - Key: {@link #MAX_TOKENS_FIELD}, Value: {@link #maxCompletionTokens()}
* - Key: {@link #INCLUDE_STREAM_OPTIONS_PARAM}, Value: "false"
*/
public static Params withMaxTokensAndSkipStreamOptionsField(String modelId, Params params) {
return new DelegatingMapParams(
Map.ofEntries(
public static Params withMaxTokensAndSkipStreamOptionsField(@Nullable String modelId, Params params) {
Map<String, String> entries = modelId != null
? Map.ofEntries(
Map.entry(MODEL_ID_PARAM, modelId),
Map.entry(MAX_TOKENS_PARAM, MAX_TOKENS_FIELD),
Map.entry(INCLUDE_STREAM_OPTIONS_PARAM, Boolean.FALSE.toString())
),
params
);
)
: Map.ofEntries(
Map.entry(MAX_TOKENS_PARAM, MAX_TOKENS_FIELD),
Map.entry(INCLUDE_STREAM_OPTIONS_PARAM, Boolean.FALSE.toString())
);
return new DelegatingMapParams(entries, params);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ public HttpRequest createHttpRequest() {
HttpPost httpPost = new HttpPost(model.uri());

ByteArrayEntity byteEntity = new ByteArrayEntity(
Strings.toString(new Ai21ChatCompletionRequestEntity(chatInput, model)).getBytes(StandardCharsets.UTF_8)
Strings.toString(new Ai21ChatCompletionRequestEntity(chatInput, model.getServiceSettings().modelId()))
.getBytes(StandardCharsets.UTF_8)
);
httpPost.setEntity(byteEntity);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,34 +7,33 @@

package org.elasticsearch.xpack.inference.services.ai21.request;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity;
import org.elasticsearch.xpack.inference.services.ai21.completion.Ai21ChatCompletionModel;

import java.io.IOException;
import java.util.Objects;

/**
* Ai21ChatCompletionRequestEntity is responsible for creating the request entity for Ai21 chat completion.
* It implements ToXContentObject to allow serialization to XContent format.
*/
public class Ai21ChatCompletionRequestEntity implements ToXContentObject {

private final Ai21ChatCompletionModel model;
private final String modelId;
private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;

public Ai21ChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, Ai21ChatCompletionModel model) {
public Ai21ChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, @Nullable String modelId) {
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput);
this.model = Objects.requireNonNull(model);
this.modelId = modelId;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxTokens(model.getServiceSettings().modelId(), params));
unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxTokens(modelId, params));
builder.endObject();
return builder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,178 @@

package org.elasticsearch.xpack.inference.services.googlevertexai;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.ai21.completion.Ai21ChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.ai21.request.Ai21ChatCompletionRequestEntity;
import org.elasticsearch.xpack.inference.services.anthropic.AnthropicChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.anthropic.AnthropicResponseHandler;
import org.elasticsearch.xpack.inference.services.anthropic.response.AnthropicChatCompletionResponseEntity;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionTaskSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.request.completion.GoogleModelGardenAnthropicChatCompletionRequestEntity;
import org.elasticsearch.xpack.inference.services.googlevertexai.request.completion.GoogleVertexAiUnifiedChatCompletionRequestEntity;
import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiCompletionResponseEntity;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequestEntity;
import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.llama.completion.LlamaCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.llama.request.completion.LlamaChatCompletionRequestEntity;
import org.elasticsearch.xpack.inference.services.mistral.MistralUnifiedChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.mistral.request.completion.MistralChatCompletionRequestEntity;
import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;

import java.util.Locale;

/**
* Enum representing the supported model garden providers.
*/
public enum GoogleModelGardenProvider {
GOOGLE,
ANTHROPIC;
GOOGLE(
CompletionResponseHandlerHolder.GOOGLE_VERTEX_AI_COMPLETION_HANDLER,
ChatCompletionResponseHandlerHolder.GOOGLE_VERTEX_AI_CHAT_COMPLETION_HANDLER,
(unifiedChatInput, modelId, taskSettings) -> new GoogleVertexAiUnifiedChatCompletionRequestEntity(
unifiedChatInput,
taskSettings.thinkingConfig()
)
),
ANTHROPIC(
CompletionResponseHandlerHolder.ANTHROPIC_COMPLETION_HANDLER,
ChatCompletionResponseHandlerHolder.ANTHROPIC_CHAT_COMPLETION_HANDLER,
(unifiedChatInput, modelId, taskSettings) -> new GoogleModelGardenAnthropicChatCompletionRequestEntity(
unifiedChatInput,
taskSettings
)
),
META(
CompletionResponseHandlerHolder.META_COMPLETION_HANDLER,
ChatCompletionResponseHandlerHolder.META_CHAT_COMPLETION_HANDLER,
(unifiedChatInput, modelId, taskSettings) -> new LlamaChatCompletionRequestEntity(unifiedChatInput, modelId)
),
HUGGING_FACE(
CompletionResponseHandlerHolder.HUGGING_FACE_COMPLETION_HANDLER,
ChatCompletionResponseHandlerHolder.HUGGING_FACE_CHAT_COMPLETION_HANDLER,
(unifiedChatInput, modelId, taskSettings) -> new HuggingFaceUnifiedChatCompletionRequestEntity(unifiedChatInput, modelId)
),
MISTRAL(
CompletionResponseHandlerHolder.MISTRAL_COMPLETION_HANDLER,
ChatCompletionResponseHandlerHolder.MISTRAL_CHAT_COMPLETION_HANDLER,
(unifiedChatInput, modelId, taskSettings) -> new MistralChatCompletionRequestEntity(unifiedChatInput, modelId)
),
AI21(
CompletionResponseHandlerHolder.AI21_COMPLETION_HANDLER,
ChatCompletionResponseHandlerHolder.AI21_CHAT_COMPLETION_HANDLER,
(unifiedChatInput, modelId, taskSettings) -> new Ai21ChatCompletionRequestEntity(unifiedChatInput, modelId)
);

private final ResponseHandler completionResponseHandler;
private final ResponseHandler chatCompletionResponseHandler;
private final RequestEntityCreator entityCreator;

/**
 * Binds a provider constant to its two response handlers and its request-entity factory.
 *
 * @param completionResponseHandler     handler later returned by {@link #getCompletionResponseHandler()}
 * @param chatCompletionResponseHandler handler later returned by {@link #getChatCompletionResponseHandler()}
 * @param entityCreator                 factory later invoked by {@code createRequestEntity}
 */
GoogleModelGardenProvider(
    ResponseHandler completionResponseHandler,
    ResponseHandler chatCompletionResponseHandler,
    RequestEntityCreator entityCreator
) {
    // Plain field assignments; the order is irrelevant.
    this.entityCreator = entityCreator;
    this.chatCompletionResponseHandler = chatCompletionResponseHandler;
    this.completionResponseHandler = completionResponseHandler;
}

/**
 * @return the completion response handler this provider constant was constructed with
 */
public ResponseHandler getCompletionResponseHandler() {
    return this.completionResponseHandler;
}

/**
 * @return the chat completion response handler this provider constant was constructed with
 */
public ResponseHandler getChatCompletionResponseHandler() {
    return this.chatCompletionResponseHandler;
}

/**
 * Creates the request entity for this provider by delegating to the creator supplied at construction.
 * All three arguments are forwarded unchanged; which of them are actually used depends on the
 * provider's creator.
 *
 * @param unifiedChatInput the chat input to forward
 * @param modelId          the model identifier to forward; may be {@code null}
 * @param taskSettings     the task settings to forward
 * @return the entity produced by this provider's creator
 */
public ToXContentObject createRequestEntity(
    UnifiedChatInput unifiedChatInput,
    @Nullable String modelId,
    GoogleVertexAiChatCompletionTaskSettings taskSettings
) {
    final ToXContentObject requestEntity = this.entityCreator.create(unifiedChatInput, modelId, taskSettings);
    return requestEntity;
}

/**
* Holds the response handlers that the provider constants pass as their first constructor argument.
* Keeping them in a nested class presumably lets the enum constants reference them from their
* constructor arguments, which static fields declared later in the enum itself could not support.
*/
private static class CompletionResponseHandlerHolder {
// Handler referenced by the GOOGLE constant.
// NOTE(review): the meaning of the trailing boolean is defined by the handler constructor, which is not visible here.
static final ResponseHandler GOOGLE_VERTEX_AI_COMPLETION_HANDLER = new GoogleVertexAiResponseHandler(
"Google Vertex AI completion",
GoogleVertexAiCompletionResponseEntity::fromResponse,
GoogleVertexAiUnifiedChatCompletionResponseHandler.GoogleVertexAiErrorResponse::fromResponse,
true
);

// Handler referenced by the ANTHROPIC constant.
static final ResponseHandler ANTHROPIC_COMPLETION_HANDLER = new AnthropicResponseHandler(
"Google Model Garden Anthropic completion",
AnthropicChatCompletionResponseEntity::fromResponse,
true
);

// Handler referenced by the META constant; parses responses with the OpenAI chat completion response entity.
static final ResponseHandler META_COMPLETION_HANDLER = new LlamaCompletionResponseHandler(
"Google Model Garden Meta completion",
OpenAiChatCompletionResponseEntity::fromResponse
);

// NOTE(review): NAME is unrelated to response handling and is not referenced inside this class;
// it looks like it belongs to the enclosing enum (or is a leftover line) — confirm its placement.
public static final String NAME = "google_model_garden_provider";
// Handler referenced by the HUGGING_FACE constant. Unlike the MISTRAL and AI21 handlers below,
// no error-response parser is passed explicitly.
static final ResponseHandler HUGGING_FACE_COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler(
"Google Model Garden Hugging Face completion",
OpenAiChatCompletionResponseEntity::fromResponse
);

// Handler referenced by the MISTRAL constant; errors are parsed with ErrorResponse::fromResponse.
static final ResponseHandler MISTRAL_COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler(
"Google Model Garden Mistral completion",
OpenAiChatCompletionResponseEntity::fromResponse,
ErrorResponse::fromResponse
);

// Handler referenced by the AI21 constant; errors are parsed with ErrorResponse::fromResponse.
static final ResponseHandler AI21_COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler(
"Google Model Garden AI21 completion",
OpenAiChatCompletionResponseEntity::fromResponse,
ErrorResponse::fromResponse
);
}

/**
* Holds the response handlers that the provider constants pass as their second constructor argument.
* Keeping them in a nested class presumably lets the enum constants reference them from their
* constructor arguments, which static fields declared later in the enum itself could not support.
*/
private static class ChatCompletionResponseHandlerHolder {
// Handler referenced by the GOOGLE constant.
static final ResponseHandler GOOGLE_VERTEX_AI_CHAT_COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler(
"Google Vertex AI chat completion"
);

// Handler referenced by the ANTHROPIC constant.
static final ResponseHandler ANTHROPIC_CHAT_COMPLETION_HANDLER = new AnthropicChatCompletionResponseHandler(
"Google Model Garden Anthropic chat completion"
);

// The four handlers below all parse responses with the OpenAI chat completion response entity.
// Handler referenced by the META constant.
static final ResponseHandler META_CHAT_COMPLETION_HANDLER = new LlamaChatCompletionResponseHandler(
"Google Model Garden Meta chat completion",
OpenAiChatCompletionResponseEntity::fromResponse
);

// Handler referenced by the HUGGING_FACE constant.
static final ResponseHandler HUGGING_FACE_CHAT_COMPLETION_HANDLER = new HuggingFaceChatCompletionResponseHandler(
"Google Model Garden Hugging Face chat completion",
OpenAiChatCompletionResponseEntity::fromResponse
);

// Handler referenced by the MISTRAL constant.
// NOTE(review): this label and the AI21 one end in "chat completions" (plural) while the others
// use "chat completion" — confirm whether the difference is intentional.
static final ResponseHandler MISTRAL_CHAT_COMPLETION_HANDLER = new MistralUnifiedChatCompletionResponseHandler(
"Google Model Garden Mistral chat completions",
OpenAiChatCompletionResponseEntity::fromResponse
);

// Handler referenced by the AI21 constant.
static final ResponseHandler AI21_CHAT_COMPLETION_HANDLER = new Ai21ChatCompletionResponseHandler(
"Google Model Garden AI21 chat completions",
OpenAiChatCompletionResponseEntity::fromResponse
);
}

/**
* Factory for a provider-specific request entity.
* Each provider constant supplies an implementation as a lambda; an implementation may ignore
* arguments its entity does not need (for example, some use only the model id and others only the task settings).
*/
@FunctionalInterface
private interface RequestEntityCreator {
/**
* Creates the request entity for one chat request.
*
* @param unifiedChatInput the chat input to build the entity from
* @param modelId the model identifier; may be {@code null}
* @param taskSettings the Google Vertex AI chat completion task settings
* @return the serializable request entity
*/
ToXContentObject create(
UnifiedChatInput unifiedChatInput,
@Nullable String modelId,
GoogleVertexAiChatCompletionTaskSettings taskSettings
);
}

public static GoogleModelGardenProvider fromString(String name) {
return valueOf(name.trim().toUpperCase(Locale.ROOT));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
Expand All @@ -43,7 +42,6 @@
import org.elasticsearch.xpack.inference.services.SenderService;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.anthropic.AnthropicChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
Expand Down Expand Up @@ -92,14 +90,6 @@ public class GoogleVertexAiService extends SenderService implements RerankingInf
InputType.INTERNAL_SEARCH
);

public static final ResponseHandler GOOGLE_VERTEX_AI_CHAT_COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler(
"Google Vertex AI chat completion"
);

public static final ResponseHandler GOOGLE_MODEL_GARDEN_ANTHROPIC_CHAT_COMPLETION_HANDLER = new AnthropicChatCompletionResponseHandler(
"Google Model Garden Anthropic chat completion"
);

@Override
public Set<TaskType> supportedStreamingTasks() {
return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION);
Expand Down Expand Up @@ -265,28 +255,16 @@ protected void doUnifiedCompletionInfer(
}
}

/**
* Helper method to create a GenericRequestManager with a specified response handler.
* @param model The GoogleVertexAiChatCompletionModel to be used for requests.
* @return A GenericRequestManager configured with the provided response handler.
*/
private GenericRequestManager<UnifiedChatInput> createRequestManager(GoogleVertexAiChatCompletionModel model) {
switch (model.getServiceSettings().provider()) {
case ANTHROPIC -> {
return createRequestManagerWithHandler(model, GOOGLE_MODEL_GARDEN_ANTHROPIC_CHAT_COMPLETION_HANDLER);
}
case GOOGLE -> {
return createRequestManagerWithHandler(model, GOOGLE_VERTEX_AI_CHAT_COMPLETION_HANDLER);
}
case null, default -> throw new ElasticsearchException(
"Unsupported Google Model Garden provider: " + model.getServiceSettings().provider()
);
}
}

private GenericRequestManager<UnifiedChatInput> createRequestManagerWithHandler(
GoogleVertexAiChatCompletionModel model,
ResponseHandler responseHandler
) {
return new GenericRequestManager<>(
getServiceComponents().threadPool(),
model,
responseHandler,
model.getServiceSettings().provider().getChatCompletionResponseHandler(),
unifiedChatInput -> new GoogleVertexAiUnifiedChatCompletionRequest(unifiedChatInput, model),
UnifiedChatInput.class
);
Expand Down
Loading