Add Google Model Garden's Meta, Mistral, Hugging Face and Ai21 providers support to Inference Plugin #135701
Conversation
…n requests
- Refactor Mistral, Ai21, Llama, and Hugging Face request entities to accept model IDs.
- Update GoogleVertexAiActionCreator to handle multiple providers, including META, HUGGING_FACE, MISTRAL, and AI21.
- Enhance serialization tests for model ID handling in chat completion requests.
- Introduce new response handlers for each Google Model Garden provider.
…garden-openai-providers-integration
…garden-openai-providers-integration
…ai-providers-integration' into feature/google-model-garden-openai-providers-integration
Hugging Face (Dedicated Endpoint)
- Create Completion Endpoint
  - Perform Non-Streaming Completion
  - Perform Streaming Completion
- Create Chat Completion Endpoint
  - Perform Basic Chat Completion
  - Perform Complex Chat Completion

Hugging Face (Shared Endpoint)
- Create Completion Endpoint
  - Perform Non-Streaming Completion
  - Perform Streaming Completion
- Create Chat Completion Endpoint
  - Perform Basic Chat Completion
  - Perform Complex Chat Completion

Mistral (Serverless), tested with mistral-small-2503
- Create Completion Endpoint
  - Perform Non-Streaming Completion
  - Perform Streaming Completion
- Create Chat Completion Endpoint
  - Perform Basic Chat Completion
  - Perform Complex Chat Completion

Mistral (Dedicated Endpoint)
- Create Completion Endpoint
  - Perform Non-Streaming Completion
  - Perform Streaming Completion
- Create Chat Completion Endpoint
  - Perform Basic Chat Completion
  - Perform Complex Chat Completion

Ai21 (Serverless)
- Create Completion Endpoint
  - Perform Non-Streaming Completion
  - Perform Streaming Completion
- Create Chat Completion Endpoint
  - Perform Basic Chat Completion
  - Perform Complex Chat Completion
Testing of the providers is finished. The other configurations cannot be tested because their endpoints do not return successful results.
@jonathan-buttner @DonalEvans @dan-rubinstein
Pinging @elastic/ml-core (Team:ML)
/**
 * Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
 * - Key: {@link #MAX_TOKENS_FIELD}, Value: {@link #MAX_TOKENS_FIELD}
This should be Value: {@link #maxCompletionTokens()}. The existing Javadoc on line 113 in this file has a similar mistake.
Fixed.
if (modelId != null) {
    // Standard Hugging Face models require a model ID
    unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxTokens(modelId, params));
} else {
    // Some Hugging Face endpoints may not require a model ID
    unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxTokens(params));
}
Would it be better to move the null check into the withMaxTokens(String modelId, Params params) method, instead of having two different methods and requiring the caller to know which one should be called?
Yeah, I agree; let's have withMaxTokens(@Nullable String modelId, Params params) and do the null check there.
Done.
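For reference, the shape of the consolidated factory might look something like the sketch below. This is a simplified, hypothetical illustration of the null-handling discussed above, not the actual plugin code: plain maps stand in for ToXContent.Params, and the marker-entry convention is only a stand-in for the real serialization flags.

```java
import java.util.HashMap;
import java.util.Map;

// Simplified sketch of a single withMaxTokens entry point that absorbs the
// null check, so callers no longer pick between two overloads.
final class UnifiedCompletionRequestSketch {

    static final String MODEL_FIELD = "model";
    static final String MAX_TOKENS_FIELD = "max_completion_tokens";

    private UnifiedCompletionRequestSketch() {}

    // modelId may be null (would be annotated @Nullable in the real code).
    static Map<String, String> withMaxTokens(String modelId, Map<String, String> params) {
        Map<String, String> merged = new HashMap<>(params);
        // Marker entry telling the serializer to emit the max-tokens field.
        merged.put(MAX_TOKENS_FIELD, "true");
        // A null modelId simply omits the model entry instead of forcing the
        // caller to choose a different overload.
        if (modelId != null) {
            merged.put(MODEL_FIELD, modelId);
        }
        return merged;
    }
}
```

With this shape, the two Hugging Face branches above collapse into a single unconditional call.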
if (modelId != null) {
    // Some Mistral endpoints require the model ID to be specified in the request body
    unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxTokensAndSkipStreamOptionsField(modelId, params));
} else {
    // Some Mistral endpoints do not require the model ID to be specified
    unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxTokensAndSkipStreamOptionsField(params));
}
Another place where we could move the null check into the method on UnifiedCompletionRequest instead of creating an extra method.
Done.
testExecute_ThrowsElasticsearchException(
    null,
    null,
    null,
    GoogleModelGardenProvider.META,
    new URI("http://localhost:9200"),
    GOOGLE_MODEL_GARDEN_META_COMPLETION_HANDLER
The tests in this class could be simplified a fair bit by removing the irrelevant parameters from the createAction() method signature. projectId, location, modelId and uri can all be null in every test without affecting the behaviour of the test.
Agreed. Removed irrelevant params.
GoogleModelGardenProvider googleModelGardenProvider,
GoogleModelGardenProvider provider,
URI uri,
ResponseHandler googleModelGardenAnthropicCompletionHandler
This name is a bit misleading, since in most of the tests it's not specifically an Anthropic handler that's being passed in. This comment also applies to the similarly named parameter in testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled().
Missed that. Renamed.
Looking good, left a suggestion.
private GenericRequestManager<ChatCompletionInput> createRequestManager(GoogleVertexAiChatCompletionModel model) {
    switch (model.getServiceSettings().provider()) {
        case GOOGLE -> {
How about we try to reduce these switch statements to a map? I'm thinking we could create a Provider class, or something that has a static map of provider enums to an internal class that contains functions to construct the three different things we need these switch cases for.
Ideally we wouldn't need to pass an instance of the Provider class around. If possible, we could do something like Provider.createCompletionRequestManager(provider, model). That would grab the internal class and then call createRequestManagerWithHandler(model, handler) with the static handler that is appropriate for that provider.
Provider would have a function for each of the switch cases that we need.
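A generic sketch of that map-based dispatch might look like the following. All names here (ProviderDispatch, Bindings, bindingsFor) are illustrative stand-ins invented for this sketch, not the actual plugin types; strings and string functions stand in for the real handlers and request-entity factories.

```java
import java.util.Map;
import java.util.function.Function;

// Hypothetical sketch: one static map from provider to the per-provider
// pieces that each switch case would otherwise construct.
final class ProviderDispatch {

    enum Provider { GOOGLE, META, MISTRAL }

    // Stand-in bundle for the handlers / entity factories a provider needs.
    record Bindings(String completionHandler, Function<String, String> entityFactory) {}

    private static final Map<Provider, Bindings> BINDINGS = Map.of(
        Provider.GOOGLE, new Bindings("google-handler", model -> "google-entity:" + model),
        Provider.META, new Bindings("meta-handler", model -> "meta-entity:" + model),
        Provider.MISTRAL, new Bindings("mistral-handler", model -> "mistral-entity:" + model)
    );

    private ProviderDispatch() {}

    // One lookup replaces every switch over the provider enum.
    static Bindings bindingsFor(Provider provider) {
        return BINDINGS.get(provider);
    }
}
```

Each call site then asks the map for its bindings instead of repeating the same switch.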
Thanks. I had ideas about moving this logic somewhere else, but initially decided not to.
I've now moved the logic for getting response handlers and creating request entities into the enum itself.
I'm not a big fan of moving too much logic into enums, but having a Provider class alongside the existing GoogleModelGardenProvider enum isn't ideal either, IMHO. This way, when a new provider is added, the changes can ideally be localized to a single file. Looks good to me. Please let me know what you think of it @jonathan-buttner
…garden-openai-providers-integration
…x unit tests, change remove optional stream_options field from llama requests
…ethod signatures and improve readability
@DonalEvans @jonathan-buttner
private static final ResponseHandler AI21_CHAT_COMPLETION_HANDLER = new Ai21ChatCompletionResponseHandler(
    "Google Model Garden Ai21 chat completions",
Minor nitpick, only fix this if there are other changes required, but this should be AI21, with an uppercase I.
> but this should be AI21, with an uppercase i

Good catch. I confused it with Ai2, which has a lowercase i. Fixed now.
URI uri,
ResponseHandler handler
) {
private ExecutableAction createAction(Sender sender, GoogleModelGardenProvider provider, URI uri, ResponseHandler handler) {
The uri method parameter is also irrelevant and can be removed.
Oops, missed that. Removed now.
Looking good, one clarification for the switch statements.
 * @return the ResponseHandler associated with the provider
 */
public ResponseHandler getCompletionResponseHandler() {
    return switch (this) {
Let's remove the switches. I think we can do this by having a package-private constructor. Something like:

@FunctionalInterface
private interface RequestEntityCreator {
    ToXContentObject create(UnifiedChatInput unifiedChatInput, String modelId, GoogleVertexAiChatCompletionTaskSettings taskSettings);
}

public enum GoogleModelGardenProvider {
    GOOGLE(
        GOOGLE_VERTEX_AI_COMPLETION_HANDLER,
        GOOGLE_VERTEX_AI_CHAT_COMPLETION_HANDLER,
        (unifiedChatInput, modelId, taskSettings) -> new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput, taskSettings.thinkingConfig())
    ),
    ...

    private final ResponseHandler completionResponseHandler;
    private final ResponseHandler chatCompletionResponseHandler;
    private final RequestEntityCreator entityCreator;

    GoogleModelGardenProvider(ResponseHandler completionResponseHandler, ResponseHandler chatCompletionResponseHandler, RequestEntityCreator entityCreator) { ... }
}

Then the methods in here just return/call the appropriate methods, and we won't need the switches anymore.
Done. I had to move the handlers into nested classes, because you're not allowed to read the value of a field before its definition, and you can't declare any fields before the actual enum constants.
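As a minimal stand-alone illustration of that Java constraint and the nested-class workaround (the names below are simplified stand-ins invented for this sketch, not the real plugin types):

```java
// Enum constants must appear before any fields, and they may not reference
// static fields of the enum itself (that would be an illegal forward
// reference). Statics of a nested class are fine, because that class is
// initialized independently when first touched.
enum ProviderSketch {
    GOOGLE(Handlers.GOOGLE_HANDLER),
    META(Handlers.META_HANDLER);

    // Nested holder class for the shared handler instances.
    private static class Handlers {
        private static final Handler GOOGLE_HANDLER = new Handler("google");
        private static final Handler META_HANDLER = new Handler("meta");
    }

    private final Handler handler;

    ProviderSketch(Handler handler) {
        this.handler = handler;
    }

    Handler completionResponseHandler() {
        return handler;
    }

    // Simplified stand-in for the real ResponseHandler type.
    record Handler(String name) {}
}
```

Declaring the same handler constants directly in the enum would fail to compile, which is why the nested class is needed.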
…ompletions response handler
…ai-providers-integration' into feature/google-model-garden-openai-providers-integration
@jonathan-buttner
Thanks for the changes!
@jonathan-buttner Thanks for the approval! You can merge it when you are ready!
@elasticmachine merge upstream
There are no new commits on the base branch.
Head branch was pushed to by a user without write access
@jonathan-buttner
META (Serverless), tested with meta/llama-3.3-70b-instruct-maas
- Create Completion Endpoint
  - Perform Non-Streaming Completion
  - Perform Streaming Completion
- Create Chat Completion Endpoint
  - Perform Basic Chat Completion
  - Perform Complex Chat Completion

META (Dedicated Endpoint), tested with meta/llama-3.3-70b-instruct-maas
- Create Completion Endpoint
  - Perform Non-Streaming Completion
  - Perform Streaming Completion
- Create Chat Completion Endpoint
  - Perform Basic Chat Completion
  - Perform Complex Chat Completion
Server Error is caused by the endpoint capabilities and not the integration itself.
gradle check