-
Notifications
You must be signed in to change notification settings - Fork 25.6k
[ML] Inference API removing _unified and using _stream instead #121804
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
jonathan-buttner
merged 12 commits into
elastic:main
from
jonathan-buttner:ml-proxy-action
Feb 7, 2025
Merged
Changes from 10 commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
e99a781
Adding proxy action
jonathan-buttner 0c50d84
[CI] Auto commit changes from spotless
7d10123
Incrementing reference count for body content and fixing tests
jonathan-buttner a044c3f
Merge branch 'ml-proxy-action' of github.com:jonathan-buttner/elastic…
jonathan-buttner ab5cc0a
[CI] Auto commit changes from spotless
456ea68
Refactoring
jonathan-buttner 4df20b5
Merge branch 'ml-proxy-action' of github.com:jonathan-buttner/elastic…
jonathan-buttner 0d4b801
Merge branch 'main' into ml-proxy-action
jonathan-buttner 246b646
Update x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/…
jonathan-buttner edf0d22
Addressing feedback
jonathan-buttner bd7f3d6
Merge branch 'main' into ml-proxy-action
jonathan-buttner 65cb320
Merge branch 'main' into ml-proxy-action
jonathan-buttner File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
133 changes: 133 additions & 0 deletions
133
...ore/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceActionProxy.java
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,133 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the Elastic License | ||
| * 2.0; you may not use this file except in compliance with the Elastic License | ||
| * 2.0. | ||
| */ | ||
|
|
||
| package org.elasticsearch.xpack.core.inference.action; | ||
|
|
||
| import org.elasticsearch.action.ActionRequest; | ||
| import org.elasticsearch.action.ActionRequestValidationException; | ||
| import org.elasticsearch.action.ActionType; | ||
| import org.elasticsearch.common.bytes.BytesReference; | ||
| import org.elasticsearch.common.io.stream.StreamInput; | ||
| import org.elasticsearch.common.io.stream.StreamOutput; | ||
| import org.elasticsearch.common.xcontent.XContentHelper; | ||
| import org.elasticsearch.core.TimeValue; | ||
| import org.elasticsearch.inference.TaskType; | ||
| import org.elasticsearch.xcontent.XContentType; | ||
|
|
||
| import java.io.IOException; | ||
| import java.util.Objects; | ||
|
|
||
| /** | ||
| * This action is used when making a REST request to the inference API. The transport handler | ||
| * will then look at the task type in the params (or retrieve it from the persisted model if it wasn't | ||
| * included in the params) to determine where this request should be routed. If the task type is chat completion | ||
| * then it will be routed to the unified chat completion handler by creating the {@link UnifiedCompletionAction}. | ||
| * If not, it will be passed along to {@link InferenceAction}. | ||
| */ | ||
| public class InferenceActionProxy extends ActionType<InferenceAction.Response> { | ||
| public static final InferenceActionProxy INSTANCE = new InferenceActionProxy(); | ||
| public static final String NAME = "cluster:monitor/xpack/inference/post"; | ||
|
|
||
| public InferenceActionProxy() { | ||
| super(NAME); | ||
| } | ||
|
|
||
| public static class Request extends ActionRequest { | ||
|
|
||
| private final TaskType taskType; | ||
| private final String inferenceEntityId; | ||
| private final BytesReference content; | ||
| private final XContentType contentType; | ||
| private final TimeValue timeout; | ||
| private final boolean stream; | ||
|
|
||
| public Request( | ||
| TaskType taskType, | ||
| String inferenceEntityId, | ||
| BytesReference content, | ||
| XContentType contentType, | ||
| TimeValue timeout, | ||
| boolean stream | ||
| ) { | ||
| this.taskType = taskType; | ||
| this.inferenceEntityId = inferenceEntityId; | ||
| this.content = content; | ||
| this.contentType = contentType; | ||
| this.timeout = timeout; | ||
| this.stream = stream; | ||
| } | ||
|
|
||
| public Request(StreamInput in) throws IOException { | ||
| super(in); | ||
| this.taskType = TaskType.fromStream(in); | ||
| this.inferenceEntityId = in.readString(); | ||
| this.content = in.readBytesReference(); | ||
| this.contentType = in.readEnum(XContentType.class); | ||
| this.timeout = in.readTimeValue(); | ||
|
|
||
| // streaming is not supported yet for transport traffic | ||
| this.stream = false; | ||
| } | ||
|
|
||
| public TaskType getTaskType() { | ||
| return taskType; | ||
| } | ||
|
|
||
| public String getInferenceEntityId() { | ||
| return inferenceEntityId; | ||
| } | ||
|
|
||
| public BytesReference getContent() { | ||
| return content; | ||
| } | ||
|
|
||
| public XContentType getContentType() { | ||
| return contentType; | ||
| } | ||
|
|
||
| public TimeValue getTimeout() { | ||
| return timeout; | ||
| } | ||
|
|
||
| public boolean isStreaming() { | ||
| return stream; | ||
| } | ||
|
|
||
| @Override | ||
| public ActionRequestValidationException validate() { | ||
| return null; | ||
| } | ||
|
|
||
| @Override | ||
| public void writeTo(StreamOutput out) throws IOException { | ||
| super.writeTo(out); | ||
| out.writeString(inferenceEntityId); | ||
| taskType.writeTo(out); | ||
| out.writeBytesReference(content); | ||
| XContentHelper.writeTo(out, contentType); | ||
| out.writeTimeValue(timeout); | ||
| } | ||
|
|
||
| @Override | ||
| public boolean equals(Object o) { | ||
| if (this == o) return true; | ||
| if (o == null || getClass() != o.getClass()) return false; | ||
| Request request = (Request) o; | ||
| return taskType == request.taskType | ||
| && Objects.equals(inferenceEntityId, request.inferenceEntityId) | ||
| && Objects.equals(content, request.content) | ||
| && contentType == request.contentType | ||
| && timeout == request.timeout | ||
| && stream == request.stream; | ||
| } | ||
|
|
||
| @Override | ||
| public int hashCode() { | ||
| return Objects.hash(taskType, inferenceEntityId, content, contentType, timeout, stream); | ||
| } | ||
| } | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
125 changes: 125 additions & 0 deletions
125
...src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,125 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the Elastic License | ||
| * 2.0; you may not use this file except in compliance with the Elastic License | ||
| * 2.0. | ||
| */ | ||
|
|
||
| package org.elasticsearch.xpack.inference.action; | ||
|
|
||
| import org.elasticsearch.ElasticsearchStatusException; | ||
| import org.elasticsearch.action.ActionListener; | ||
| import org.elasticsearch.action.support.ActionFilters; | ||
| import org.elasticsearch.action.support.HandledTransportAction; | ||
| import org.elasticsearch.client.internal.Client; | ||
| import org.elasticsearch.common.util.concurrent.EsExecutors; | ||
| import org.elasticsearch.common.xcontent.XContentHelper; | ||
| import org.elasticsearch.inference.TaskType; | ||
| import org.elasticsearch.inference.UnparsedModel; | ||
| import org.elasticsearch.injection.guice.Inject; | ||
| import org.elasticsearch.rest.RestStatus; | ||
| import org.elasticsearch.tasks.Task; | ||
| import org.elasticsearch.transport.TransportService; | ||
| import org.elasticsearch.xcontent.XContentParserConfiguration; | ||
| import org.elasticsearch.xpack.core.inference.action.InferenceAction; | ||
| import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy; | ||
| import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; | ||
| import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; | ||
| import org.elasticsearch.xpack.inference.registry.ModelRegistry; | ||
|
|
||
| import java.io.IOException; | ||
|
|
||
| import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN; | ||
| import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; | ||
|
|
||
| public class TransportInferenceActionProxy extends HandledTransportAction<InferenceActionProxy.Request, InferenceAction.Response> { | ||
| private final ModelRegistry modelRegistry; | ||
| private final Client client; | ||
|
|
||
| @Inject | ||
| public TransportInferenceActionProxy( | ||
| TransportService transportService, | ||
| ActionFilters actionFilters, | ||
| ModelRegistry modelRegistry, | ||
| Client client | ||
| ) { | ||
| super( | ||
| InferenceActionProxy.NAME, | ||
| transportService, | ||
| actionFilters, | ||
| InferenceActionProxy.Request::new, | ||
| EsExecutors.DIRECT_EXECUTOR_SERVICE | ||
| ); | ||
|
|
||
| this.modelRegistry = modelRegistry; | ||
| this.client = client; | ||
| } | ||
|
|
||
| @Override | ||
| protected void doExecute(Task task, InferenceActionProxy.Request request, ActionListener<InferenceAction.Response> listener) { | ||
| try { | ||
| ActionListener<UnparsedModel> getModelListener = listener.delegateFailureAndWrap((l, unparsedModel) -> { | ||
| if (unparsedModel.taskType() == TaskType.CHAT_COMPLETION) { | ||
| sendUnifiedCompletionRequest(request, l); | ||
| } else { | ||
| sendInferenceActionRequest(request, l); | ||
| } | ||
| }); | ||
|
|
||
| if (request.getTaskType() == TaskType.ANY) { | ||
| modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener); | ||
| } else if (request.getTaskType() == TaskType.CHAT_COMPLETION) { | ||
| sendUnifiedCompletionRequest(request, listener); | ||
| } else { | ||
| sendInferenceActionRequest(request, listener); | ||
| } | ||
| } catch (Exception e) { | ||
| listener.onFailure(e); | ||
| } | ||
| } | ||
|
|
||
| private void sendUnifiedCompletionRequest(InferenceActionProxy.Request request, ActionListener<InferenceAction.Response> listener) { | ||
| // format any validation exceptions from the rest -> transport path as UnifiedChatCompletionException | ||
| var unifiedErrorFormatListener = listener.delegateResponse((l, e) -> l.onFailure(UnifiedChatCompletionException.fromThrowable(e))); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice |
||
|
|
||
| try { | ||
| if (request.isStreaming() == false) { | ||
| throw new ElasticsearchStatusException( | ||
| "The [chat_completion] task type only supports streaming, please try again with the _stream API", | ||
| RestStatus.BAD_REQUEST | ||
| ); | ||
| } | ||
|
|
||
| UnifiedCompletionAction.Request unifiedRequest; | ||
| try ( | ||
| var parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, request.getContent(), request.getContentType()) | ||
| ) { | ||
| unifiedRequest = UnifiedCompletionAction.Request.parseRequest( | ||
| request.getInferenceEntityId(), | ||
| request.getTaskType(), | ||
| request.getTimeout(), | ||
| parser | ||
| ); | ||
| } | ||
|
|
||
| executeAsyncWithOrigin(client, INFERENCE_ORIGIN, UnifiedCompletionAction.INSTANCE, unifiedRequest, unifiedErrorFormatListener); | ||
| } catch (Exception e) { | ||
| unifiedErrorFormatListener.onFailure(e); | ||
| } | ||
| } | ||
|
|
||
| private void sendInferenceActionRequest(InferenceActionProxy.Request request, ActionListener<InferenceAction.Response> listener) | ||
| throws IOException { | ||
| InferenceAction.Request.Builder inferenceActionRequestBuilder; | ||
| try (var parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, request.getContent(), request.getContentType())) { | ||
| inferenceActionRequestBuilder = InferenceAction.Request.parseRequest( | ||
| request.getInferenceEntityId(), | ||
| request.getTaskType(), | ||
| parser | ||
| ); | ||
| inferenceActionRequestBuilder.setInferenceTimeout(request.getTimeout()).setStream(request.isStreaming()); | ||
| } | ||
|
|
||
| executeAsyncWithOrigin(client, INFERENCE_ORIGIN, InferenceAction.INSTANCE, inferenceActionRequestBuilder.build(), listener); | ||
| } | ||
| } | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@davidkyle just wanted to confirm that this is what we want to do here right? Changing it to internal?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
++ yes internal is good