Skip to content

Commit d4cb9ff

Browse files
Jan-Kazlouski-elasticelasticsearchmachine
andauthored
Add Google Model Garden's Meta, Mistral, Hugging Face and Ai21 providers support to Inference Plugin (#135701)
* Integrate Google Model Garden providers for processing chat completion requests - Refactor Mistral, Ai21, Llama, and Hugging Face request entities to accept model IDs. - Update GoogleVertexAiActionCreator to handle multiple providers including META, HUGGING_FACE, MISTRAL, and AI21. - Enhance serialization tests for model ID handling in chat completion requests. - Introduce new response handlers for each Google Model Garden provider. * Add changelog * [CI] Auto commit changes from spotless * Move model_id null check to UnifiedCompletionRequest, fix Javadoc, fix unit tests, change remove optional stream_options field from llama requests * Refactor GoogleVertexAiUnifiedChatCompletionActionTests to simplify method signatures and improve readability * Refactor GoogleModelGardenProvider to handle response handlers and request entities * Fix typo in GoogleModelGardenProvider Google Model Garden AI21 chat completions response handler * Refactor GoogleVertexAiUnifiedChatCompletionActionTests * Refactor GoogleModelGardenProvider * Add Nullable annotation * [CI] Auto commit changes from spotless * Fix Typo --------- Co-authored-by: elasticsearchmachine <[email protected]>
1 parent afad2a5 commit d4cb9ff

22 files changed

+624
-347
lines changed

docs/changelog/135701.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 135701
2+
summary: Add Google Model Garden's Meta, Mistral, Hugging Face and Ai21 providers support to Inference Plugin
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -89,31 +89,34 @@ public record UnifiedCompletionRequest(
8989

9090
/**
9191
* Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
92-
* - Key: {@link #MODEL_FIELD}, Value: modelId
92+
* - Key: {@link #MODEL_FIELD}, Value: modelId, if modelId is not null
9393
* - Key: {@link #MAX_TOKENS_FIELD}, Value: {@link #maxCompletionTokens()}
9494
*/
95-
public static Params withMaxTokens(String modelId, Params params) {
96-
return new DelegatingMapParams(
97-
Map.ofEntries(Map.entry(MODEL_ID_PARAM, modelId), Map.entry(MAX_TOKENS_PARAM, MAX_TOKENS_FIELD)),
98-
params
99-
);
95+
public static Params withMaxTokens(@Nullable String modelId, Params params) {
96+
Map<String, String> entries = modelId != null
97+
? Map.ofEntries(Map.entry(MODEL_ID_PARAM, modelId), Map.entry(MAX_TOKENS_PARAM, MAX_TOKENS_FIELD))
98+
: Map.ofEntries(Map.entry(MAX_TOKENS_PARAM, MAX_TOKENS_FIELD));
99+
return new DelegatingMapParams(entries, params);
100100
}
101101

102102
/**
103103
* Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
104-
* - Key: {@link #MODEL_FIELD}, Value: modelId
105-
* - Key: {@link #MAX_TOKENS_FIELD}, Value: {@link #MAX_TOKENS_FIELD}
104+
* - Key: {@link #MODEL_FIELD}, Value: modelId, if modelId is not null
105+
* - Key: {@link #MAX_TOKENS_FIELD}, Value: {@link #maxCompletionTokens()}
106106
* - Key: {@link #INCLUDE_STREAM_OPTIONS_PARAM}, Value: "false"
107107
*/
108-
public static Params withMaxTokensAndSkipStreamOptionsField(String modelId, Params params) {
109-
return new DelegatingMapParams(
110-
Map.ofEntries(
108+
public static Params withMaxTokensAndSkipStreamOptionsField(@Nullable String modelId, Params params) {
109+
Map<String, String> entries = modelId != null
110+
? Map.ofEntries(
111111
Map.entry(MODEL_ID_PARAM, modelId),
112112
Map.entry(MAX_TOKENS_PARAM, MAX_TOKENS_FIELD),
113113
Map.entry(INCLUDE_STREAM_OPTIONS_PARAM, Boolean.FALSE.toString())
114-
),
115-
params
116-
);
114+
)
115+
: Map.ofEntries(
116+
Map.entry(MAX_TOKENS_PARAM, MAX_TOKENS_FIELD),
117+
Map.entry(INCLUDE_STREAM_OPTIONS_PARAM, Boolean.FALSE.toString())
118+
);
119+
return new DelegatingMapParams(entries, params);
117120
}
118121

119122
/**

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/request/Ai21ChatCompletionRequest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ public HttpRequest createHttpRequest() {
4343
HttpPost httpPost = new HttpPost(model.uri());
4444

4545
ByteArrayEntity byteEntity = new ByteArrayEntity(
46-
Strings.toString(new Ai21ChatCompletionRequestEntity(chatInput, model)).getBytes(StandardCharsets.UTF_8)
46+
Strings.toString(new Ai21ChatCompletionRequestEntity(chatInput, model.getServiceSettings().modelId()))
47+
.getBytes(StandardCharsets.UTF_8)
4748
);
4849
httpPost.setEntity(byteEntity);
4950

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/request/Ai21ChatCompletionRequestEntity.java

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,34 +7,33 @@
77

88
package org.elasticsearch.xpack.inference.services.ai21.request;
99

10+
import org.elasticsearch.core.Nullable;
1011
import org.elasticsearch.inference.UnifiedCompletionRequest;
1112
import org.elasticsearch.xcontent.ToXContentObject;
1213
import org.elasticsearch.xcontent.XContentBuilder;
1314
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
1415
import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity;
15-
import org.elasticsearch.xpack.inference.services.ai21.completion.Ai21ChatCompletionModel;
1616

1717
import java.io.IOException;
18-
import java.util.Objects;
1918

2019
/**
2120
* Ai21ChatCompletionRequestEntity is responsible for creating the request entity for Ai21 chat completion.
2221
* It implements ToXContentObject to allow serialization to XContent format.
2322
*/
2423
public class Ai21ChatCompletionRequestEntity implements ToXContentObject {
2524

26-
private final Ai21ChatCompletionModel model;
25+
private final String modelId;
2726
private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;
2827

29-
public Ai21ChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, Ai21ChatCompletionModel model) {
28+
public Ai21ChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, @Nullable String modelId) {
3029
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput);
31-
this.model = Objects.requireNonNull(model);
30+
this.modelId = modelId;
3231
}
3332

3433
@Override
3534
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
3635
builder.startObject();
37-
unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxTokens(model.getServiceSettings().modelId(), params));
36+
unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxTokens(modelId, params));
3837
builder.endObject();
3938
return builder;
4039
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleModelGardenProvider.java

Lines changed: 165 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,178 @@
77

88
package org.elasticsearch.xpack.inference.services.googlevertexai;
99

10+
import org.elasticsearch.core.Nullable;
11+
import org.elasticsearch.xcontent.ToXContentObject;
12+
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
13+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
14+
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
15+
import org.elasticsearch.xpack.inference.services.ai21.completion.Ai21ChatCompletionResponseHandler;
16+
import org.elasticsearch.xpack.inference.services.ai21.request.Ai21ChatCompletionRequestEntity;
17+
import org.elasticsearch.xpack.inference.services.anthropic.AnthropicChatCompletionResponseHandler;
18+
import org.elasticsearch.xpack.inference.services.anthropic.AnthropicResponseHandler;
19+
import org.elasticsearch.xpack.inference.services.anthropic.response.AnthropicChatCompletionResponseEntity;
20+
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionTaskSettings;
21+
import org.elasticsearch.xpack.inference.services.googlevertexai.request.completion.GoogleModelGardenAnthropicChatCompletionRequestEntity;
22+
import org.elasticsearch.xpack.inference.services.googlevertexai.request.completion.GoogleVertexAiUnifiedChatCompletionRequestEntity;
23+
import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiCompletionResponseEntity;
24+
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceChatCompletionResponseHandler;
25+
import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequestEntity;
26+
import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionResponseHandler;
27+
import org.elasticsearch.xpack.inference.services.llama.completion.LlamaCompletionResponseHandler;
28+
import org.elasticsearch.xpack.inference.services.llama.request.completion.LlamaChatCompletionRequestEntity;
29+
import org.elasticsearch.xpack.inference.services.mistral.MistralUnifiedChatCompletionResponseHandler;
30+
import org.elasticsearch.xpack.inference.services.mistral.request.completion.MistralChatCompletionRequestEntity;
31+
import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler;
32+
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
33+
1034
import java.util.Locale;
1135

1236
/**
1337
* Enum representing the supported model garden providers.
1438
*/
1539
public enum GoogleModelGardenProvider {
16-
GOOGLE,
17-
ANTHROPIC;
40+
GOOGLE(
41+
CompletionResponseHandlerHolder.GOOGLE_VERTEX_AI_COMPLETION_HANDLER,
42+
ChatCompletionResponseHandlerHolder.GOOGLE_VERTEX_AI_CHAT_COMPLETION_HANDLER,
43+
(unifiedChatInput, modelId, taskSettings) -> new GoogleVertexAiUnifiedChatCompletionRequestEntity(
44+
unifiedChatInput,
45+
taskSettings.thinkingConfig()
46+
)
47+
),
48+
ANTHROPIC(
49+
CompletionResponseHandlerHolder.ANTHROPIC_COMPLETION_HANDLER,
50+
ChatCompletionResponseHandlerHolder.ANTHROPIC_CHAT_COMPLETION_HANDLER,
51+
(unifiedChatInput, modelId, taskSettings) -> new GoogleModelGardenAnthropicChatCompletionRequestEntity(
52+
unifiedChatInput,
53+
taskSettings
54+
)
55+
),
56+
META(
57+
CompletionResponseHandlerHolder.META_COMPLETION_HANDLER,
58+
ChatCompletionResponseHandlerHolder.META_CHAT_COMPLETION_HANDLER,
59+
(unifiedChatInput, modelId, taskSettings) -> new LlamaChatCompletionRequestEntity(unifiedChatInput, modelId)
60+
),
61+
HUGGING_FACE(
62+
CompletionResponseHandlerHolder.HUGGING_FACE_COMPLETION_HANDLER,
63+
ChatCompletionResponseHandlerHolder.HUGGING_FACE_CHAT_COMPLETION_HANDLER,
64+
(unifiedChatInput, modelId, taskSettings) -> new HuggingFaceUnifiedChatCompletionRequestEntity(unifiedChatInput, modelId)
65+
),
66+
MISTRAL(
67+
CompletionResponseHandlerHolder.MISTRAL_COMPLETION_HANDLER,
68+
ChatCompletionResponseHandlerHolder.MISTRAL_CHAT_COMPLETION_HANDLER,
69+
(unifiedChatInput, modelId, taskSettings) -> new MistralChatCompletionRequestEntity(unifiedChatInput, modelId)
70+
),
71+
AI21(
72+
CompletionResponseHandlerHolder.AI21_COMPLETION_HANDLER,
73+
ChatCompletionResponseHandlerHolder.AI21_CHAT_COMPLETION_HANDLER,
74+
(unifiedChatInput, modelId, taskSettings) -> new Ai21ChatCompletionRequestEntity(unifiedChatInput, modelId)
75+
);
76+
77+
private final ResponseHandler completionResponseHandler;
78+
private final ResponseHandler chatCompletionResponseHandler;
79+
private final RequestEntityCreator entityCreator;
80+
81+
GoogleModelGardenProvider(
82+
ResponseHandler completionResponseHandler,
83+
ResponseHandler chatCompletionResponseHandler,
84+
RequestEntityCreator entityCreator
85+
) {
86+
this.completionResponseHandler = completionResponseHandler;
87+
this.chatCompletionResponseHandler = chatCompletionResponseHandler;
88+
this.entityCreator = entityCreator;
89+
}
90+
91+
public ResponseHandler getCompletionResponseHandler() {
92+
return completionResponseHandler;
93+
}
94+
95+
public ResponseHandler getChatCompletionResponseHandler() {
96+
return chatCompletionResponseHandler;
97+
}
98+
99+
public ToXContentObject createRequestEntity(
100+
UnifiedChatInput unifiedChatInput,
101+
@Nullable String modelId,
102+
GoogleVertexAiChatCompletionTaskSettings taskSettings
103+
) {
104+
return entityCreator.create(unifiedChatInput, modelId, taskSettings);
105+
}
106+
107+
private static class CompletionResponseHandlerHolder {
108+
static final ResponseHandler GOOGLE_VERTEX_AI_COMPLETION_HANDLER = new GoogleVertexAiResponseHandler(
109+
"Google Vertex AI completion",
110+
GoogleVertexAiCompletionResponseEntity::fromResponse,
111+
GoogleVertexAiUnifiedChatCompletionResponseHandler.GoogleVertexAiErrorResponse::fromResponse,
112+
true
113+
);
114+
115+
static final ResponseHandler ANTHROPIC_COMPLETION_HANDLER = new AnthropicResponseHandler(
116+
"Google Model Garden Anthropic completion",
117+
AnthropicChatCompletionResponseEntity::fromResponse,
118+
true
119+
);
120+
121+
static final ResponseHandler META_COMPLETION_HANDLER = new LlamaCompletionResponseHandler(
122+
"Google Model Garden Meta completion",
123+
OpenAiChatCompletionResponseEntity::fromResponse
124+
);
18125

19-
public static final String NAME = "google_model_garden_provider";
126+
static final ResponseHandler HUGGING_FACE_COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler(
127+
"Google Model Garden Hugging Face completion",
128+
OpenAiChatCompletionResponseEntity::fromResponse
129+
);
130+
131+
static final ResponseHandler MISTRAL_COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler(
132+
"Google Model Garden Mistral completion",
133+
OpenAiChatCompletionResponseEntity::fromResponse,
134+
ErrorResponse::fromResponse
135+
);
136+
137+
static final ResponseHandler AI21_COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler(
138+
"Google Model Garden AI21 completion",
139+
OpenAiChatCompletionResponseEntity::fromResponse,
140+
ErrorResponse::fromResponse
141+
);
142+
}
143+
144+
private static class ChatCompletionResponseHandlerHolder {
145+
static final ResponseHandler GOOGLE_VERTEX_AI_CHAT_COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler(
146+
"Google Vertex AI chat completion"
147+
);
148+
149+
static final ResponseHandler ANTHROPIC_CHAT_COMPLETION_HANDLER = new AnthropicChatCompletionResponseHandler(
150+
"Google Model Garden Anthropic chat completion"
151+
);
152+
153+
static final ResponseHandler META_CHAT_COMPLETION_HANDLER = new LlamaChatCompletionResponseHandler(
154+
"Google Model Garden Meta chat completion",
155+
OpenAiChatCompletionResponseEntity::fromResponse
156+
);
157+
158+
static final ResponseHandler HUGGING_FACE_CHAT_COMPLETION_HANDLER = new HuggingFaceChatCompletionResponseHandler(
159+
"Google Model Garden Hugging Face chat completion",
160+
OpenAiChatCompletionResponseEntity::fromResponse
161+
);
162+
163+
static final ResponseHandler MISTRAL_CHAT_COMPLETION_HANDLER = new MistralUnifiedChatCompletionResponseHandler(
164+
"Google Model Garden Mistral chat completions",
165+
OpenAiChatCompletionResponseEntity::fromResponse
166+
);
167+
168+
static final ResponseHandler AI21_CHAT_COMPLETION_HANDLER = new Ai21ChatCompletionResponseHandler(
169+
"Google Model Garden AI21 chat completions",
170+
OpenAiChatCompletionResponseEntity::fromResponse
171+
);
172+
}
173+
174+
@FunctionalInterface
175+
private interface RequestEntityCreator {
176+
ToXContentObject create(
177+
UnifiedChatInput unifiedChatInput,
178+
@Nullable String modelId,
179+
GoogleVertexAiChatCompletionTaskSettings taskSettings
180+
);
181+
}
20182

21183
public static GoogleModelGardenProvider fromString(String name) {
22184
return valueOf(name.trim().toUpperCase(Locale.ROOT));

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java

Lines changed: 6 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
3434
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
3535
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
36-
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
3736
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
3837
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
3938
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -43,7 +42,6 @@
4342
import org.elasticsearch.xpack.inference.services.SenderService;
4443
import org.elasticsearch.xpack.inference.services.ServiceComponents;
4544
import org.elasticsearch.xpack.inference.services.ServiceUtils;
46-
import org.elasticsearch.xpack.inference.services.anthropic.AnthropicChatCompletionResponseHandler;
4745
import org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator;
4846
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel;
4947
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
@@ -92,14 +90,6 @@ public class GoogleVertexAiService extends SenderService implements RerankingInf
9290
InputType.INTERNAL_SEARCH
9391
);
9492

95-
public static final ResponseHandler GOOGLE_VERTEX_AI_CHAT_COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler(
96-
"Google Vertex AI chat completion"
97-
);
98-
99-
public static final ResponseHandler GOOGLE_MODEL_GARDEN_ANTHROPIC_CHAT_COMPLETION_HANDLER = new AnthropicChatCompletionResponseHandler(
100-
"Google Model Garden Anthropic chat completion"
101-
);
102-
10393
@Override
10494
public Set<TaskType> supportedStreamingTasks() {
10595
return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION);
@@ -265,28 +255,16 @@ protected void doUnifiedCompletionInfer(
265255
}
266256
}
267257

258+
/**
259+
* Helper method to create a GenericRequestManager with a specified response handler.
260+
* @param model The GoogleVertexAiChatCompletionModel to be used for requests.
261+
* @return A GenericRequestManager configured with the provided response handler.
262+
*/
268263
private GenericRequestManager<UnifiedChatInput> createRequestManager(GoogleVertexAiChatCompletionModel model) {
269-
switch (model.getServiceSettings().provider()) {
270-
case ANTHROPIC -> {
271-
return createRequestManagerWithHandler(model, GOOGLE_MODEL_GARDEN_ANTHROPIC_CHAT_COMPLETION_HANDLER);
272-
}
273-
case GOOGLE -> {
274-
return createRequestManagerWithHandler(model, GOOGLE_VERTEX_AI_CHAT_COMPLETION_HANDLER);
275-
}
276-
case null, default -> throw new ElasticsearchException(
277-
"Unsupported Google Model Garden provider: " + model.getServiceSettings().provider()
278-
);
279-
}
280-
}
281-
282-
private GenericRequestManager<UnifiedChatInput> createRequestManagerWithHandler(
283-
GoogleVertexAiChatCompletionModel model,
284-
ResponseHandler responseHandler
285-
) {
286264
return new GenericRequestManager<>(
287265
getServiceComponents().threadPool(),
288266
model,
289-
responseHandler,
267+
model.getServiceSettings().provider().getChatCompletionResponseHandler(),
290268
unifiedChatInput -> new GoogleVertexAiUnifiedChatCompletionRequest(unifiedChatInput, model),
291269
UnifiedChatInput.class
292270
);

0 commit comments

Comments
 (0)