-
Notifications
You must be signed in to change notification settings - Fork 25.5k
Add Google Model Garden's Meta, Mistral, Hugging Face and Ai21 providers support to Inference Plugin #135701
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Google Model Garden's Meta, Mistral, Hugging Face and Ai21 providers support to Inference Plugin #135701
Changes from 6 commits
af60919
129f74e
52fbe36
f0e382a
4a2abaf
05a4c4f
8a47467
b75d352
41224bd
8cd7eb1
6c05a61
2d8ced7
ac65a3a
1824cb8
3353b85
767f4d3
8b25612
f66fb24
35ec9e8
d874fc9
f3a947e
ace886c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
pr: 135701 | ||
summary: Add Google Model Garden's Meta, Mistral, Hugging Face and Ai21 providers support to Inference Plugin | ||
area: Machine Learning | ||
type: enhancement | ||
issues: [] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,7 @@ | |
import org.elasticsearch.xpack.inference.external.action.ExecutableAction; | ||
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; | ||
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; | ||
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; | ||
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; | ||
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; | ||
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; | ||
|
@@ -28,6 +29,9 @@ | |
import org.elasticsearch.xpack.inference.services.googlevertexai.request.completion.GoogleVertexAiUnifiedChatCompletionRequest; | ||
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel; | ||
import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiCompletionResponseEntity; | ||
import org.elasticsearch.xpack.inference.services.llama.completion.LlamaCompletionResponseHandler; | ||
import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler; | ||
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; | ||
|
||
import java.util.Map; | ||
import java.util.Objects; | ||
|
@@ -54,6 +58,28 @@ public class GoogleVertexAiActionCreator implements GoogleVertexAiActionVisitor | |
true | ||
); | ||
|
||
static final ResponseHandler GOOGLE_MODEL_GARDEN_META_COMPLETION_HANDLER = new LlamaCompletionResponseHandler( | ||
"Google Model Garden Meta completion", | ||
OpenAiChatCompletionResponseEntity::fromResponse | ||
); | ||
|
||
static final ResponseHandler GOOGLE_MODEL_GARDEN_HUGGING_FACE_COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler( | ||
"Google Model Garden Hugging Face completion", | ||
OpenAiChatCompletionResponseEntity::fromResponse | ||
); | ||
|
||
static final ResponseHandler GOOGLE_MODEL_GARDEN_MISTRAL_COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler( | ||
"Google Model Garden Mistral completion", | ||
OpenAiChatCompletionResponseEntity::fromResponse, | ||
ErrorResponse::fromResponse | ||
); | ||
|
||
static final ResponseHandler GOOGLE_MODEL_GARDEN_AI21_COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler( | ||
"Google Model Garden AI21 completion", | ||
OpenAiChatCompletionResponseEntity::fromResponse, | ||
ErrorResponse::fromResponse | ||
); | ||
|
||
static final String USER_ROLE = "user"; | ||
|
||
public GoogleVertexAiActionCreator(Sender sender, ServiceComponents serviceComponents) { | ||
|
@@ -91,11 +117,23 @@ public ExecutableAction create(GoogleVertexAiChatCompletionModel model, Map<Stri | |
|
||
private GenericRequestManager<ChatCompletionInput> createRequestManager(GoogleVertexAiChatCompletionModel model) { | ||
switch (model.getServiceSettings().provider()) { | ||
case GOOGLE -> { | ||
|
||
return createRequestManagerWithHandler(model, GOOGLE_VERTEX_AI_COMPLETION_HANDLER); | ||
} | ||
case ANTHROPIC -> { | ||
return createRequestManagerWithHandler(model, GOOGLE_MODEL_GARDEN_ANTHROPIC_COMPLETION_HANDLER); | ||
} | ||
case GOOGLE -> { | ||
return createRequestManagerWithHandler(model, GOOGLE_VERTEX_AI_COMPLETION_HANDLER); | ||
case META -> { | ||
return createRequestManagerWithHandler(model, GOOGLE_MODEL_GARDEN_META_COMPLETION_HANDLER); | ||
} | ||
case HUGGING_FACE -> { | ||
return createRequestManagerWithHandler(model, GOOGLE_MODEL_GARDEN_HUGGING_FACE_COMPLETION_HANDLER); | ||
} | ||
case MISTRAL -> { | ||
return createRequestManagerWithHandler(model, GOOGLE_MODEL_GARDEN_MISTRAL_COMPLETION_HANDLER); | ||
} | ||
case AI21 -> { | ||
return createRequestManagerWithHandler(model, GOOGLE_MODEL_GARDEN_AI21_COMPLETION_HANDLER); | ||
} | ||
case null, default -> throw new ElasticsearchException( | ||
"Unsupported Google Model Garden provider: " + model.getServiceSettings().provider() | ||
|
@@ -104,14 +142,14 @@ private GenericRequestManager<ChatCompletionInput> createRequestManager(GoogleVe | |
} | ||
|
||
private GenericRequestManager<ChatCompletionInput> createRequestManagerWithHandler( | ||
GoogleVertexAiChatCompletionModel overriddenModel, | ||
GoogleVertexAiChatCompletionModel model, | ||
ResponseHandler responseHandler | ||
) { | ||
return new GenericRequestManager<>( | ||
serviceComponents.threadPool(), | ||
overriddenModel, | ||
model, | ||
responseHandler, | ||
inputs -> new GoogleVertexAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), overriddenModel), | ||
inputs -> new GoogleVertexAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model), | ||
ChatCompletionInput.class | ||
); | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be `Value: {@link #maxCompletionTokens()}`. The existing Javadoc on line 113 in this file has a similar mistake.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.