-
Notifications
You must be signed in to change notification settings - Fork 25.6k
[ML] Inference API removing _unified and using _stream instead #121804
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
e99a781
0c50d84
7d10123
a044c3f
ab5cc0a
456ea68
4df20b5
0d4b801
246b646
edf0d22
bd7f3d6
65cb320
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,132 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the Elastic License | ||
| * 2.0; you may not use this file except in compliance with the Elastic License | ||
| * 2.0. | ||
| */ | ||
|
|
||
| package org.elasticsearch.xpack.core.inference.action; | ||
|
|
||
| import org.elasticsearch.action.ActionRequest; | ||
| import org.elasticsearch.action.ActionRequestValidationException; | ||
| import org.elasticsearch.action.ActionType; | ||
| import org.elasticsearch.common.bytes.BytesReference; | ||
| import org.elasticsearch.common.io.stream.StreamInput; | ||
| import org.elasticsearch.common.io.stream.StreamOutput; | ||
| import org.elasticsearch.common.xcontent.XContentHelper; | ||
| import org.elasticsearch.core.TimeValue; | ||
| import org.elasticsearch.inference.TaskType; | ||
| import org.elasticsearch.xcontent.XContentType; | ||
|
|
||
| import java.io.IOException; | ||
| import java.util.Objects; | ||
|
|
||
| /** | ||
| * This action is used when making a REST request to the inference API. The transport handler | ||
| * will then look at the task type in the params (or retrieve it from the persisted model if it wasn't | ||
| * included in the params) to determine where this request should be routed. If the task type is chat completion | ||
| * then it will be routed to the unified chat completion handler by creating the {@link UnifiedCompletionAction}. | ||
| * If not, it will be passed along to {@link InferenceAction}. | ||
| */ | ||
| public class InferenceActionProxy extends ActionType<InferenceAction.Response> { | ||
| public static final InferenceActionProxy INSTANCE = new InferenceActionProxy(); | ||
| public static final String NAME = "cluster:monitor/xpack/inference"; | ||
|
||
|
|
||
| public InferenceActionProxy() { | ||
| super(NAME); | ||
| } | ||
|
|
||
| public static class Request extends ActionRequest { | ||
|
|
||
| private final TaskType taskType; | ||
| private final String inferenceEntityId; | ||
| private final BytesReference content; | ||
| private final XContentType contentType; | ||
| private final TimeValue timeout; | ||
| private final boolean stream; | ||
|
|
||
| public Request( | ||
| TaskType taskType, | ||
| String inferenceEntityId, | ||
| BytesReference content, | ||
| XContentType contentType, | ||
| TimeValue timeout, | ||
| boolean stream | ||
| ) { | ||
| this.taskType = taskType; | ||
| this.inferenceEntityId = inferenceEntityId; | ||
| this.content = content; | ||
| this.contentType = contentType; | ||
| this.timeout = timeout; | ||
| this.stream = stream; | ||
| } | ||
|
|
||
| public Request(StreamInput in) throws IOException { | ||
| super(in); | ||
| this.taskType = TaskType.fromStream(in); | ||
| this.inferenceEntityId = in.readString(); | ||
| this.content = in.readBytesReference(); | ||
| this.contentType = in.readEnum(XContentType.class); | ||
| this.timeout = in.readTimeValue(); | ||
|
|
||
| // streaming is not supported yet for transport traffic | ||
| this.stream = false; | ||
| } | ||
|
|
||
| public TaskType getTaskType() { | ||
| return taskType; | ||
| } | ||
|
|
||
| public String getInferenceEntityId() { | ||
| return inferenceEntityId; | ||
| } | ||
|
|
||
| public BytesReference getContent() { | ||
| return content; | ||
| } | ||
|
|
||
| public XContentType getContentType() { | ||
| return contentType; | ||
| } | ||
|
|
||
| public TimeValue getTimeout() { | ||
| return timeout; | ||
| } | ||
|
|
||
| public boolean isStreaming() { | ||
| return stream; | ||
| } | ||
|
|
||
| @Override | ||
| public ActionRequestValidationException validate() { | ||
| // TODO confirm that we don't need any validation | ||
| return null; | ||
| } | ||
|
|
||
| @Override | ||
| public void writeTo(StreamOutput out) throws IOException { | ||
| super.writeTo(out); | ||
| out.writeString(inferenceEntityId); | ||
| taskType.writeTo(out); | ||
| out.writeBytesReference(content); | ||
| XContentHelper.writeTo(out, contentType); | ||
| out.writeTimeValue(timeout); | ||
| } | ||
|
|
||
| @Override | ||
| public boolean equals(Object o) { | ||
| if (this == o) return true; | ||
| if (o == null || getClass() != o.getClass()) return false; | ||
| Request request = (Request) o; | ||
| return taskType == request.taskType | ||
| && Objects.equals(inferenceEntityId, request.inferenceEntityId) | ||
| && Objects.equals(content, request.content) | ||
| && contentType == request.contentType; | ||
| } | ||
|
|
||
| @Override | ||
| public int hashCode() { | ||
| return Objects.hash(taskType, inferenceEntityId, content, contentType); | ||
jonathan-buttner marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,125 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the Elastic License | ||
| * 2.0; you may not use this file except in compliance with the Elastic License | ||
| * 2.0. | ||
| */ | ||
|
|
||
| package org.elasticsearch.xpack.inference.action; | ||
|
|
||
| import org.elasticsearch.ElasticsearchStatusException; | ||
| import org.elasticsearch.action.ActionListener; | ||
| import org.elasticsearch.action.support.ActionFilters; | ||
| import org.elasticsearch.action.support.HandledTransportAction; | ||
| import org.elasticsearch.client.internal.Client; | ||
| import org.elasticsearch.common.util.concurrent.EsExecutors; | ||
| import org.elasticsearch.common.xcontent.XContentHelper; | ||
| import org.elasticsearch.inference.TaskType; | ||
| import org.elasticsearch.inference.UnparsedModel; | ||
| import org.elasticsearch.injection.guice.Inject; | ||
| import org.elasticsearch.rest.RestStatus; | ||
| import org.elasticsearch.tasks.Task; | ||
| import org.elasticsearch.transport.TransportService; | ||
| import org.elasticsearch.xcontent.XContentParserConfiguration; | ||
| import org.elasticsearch.xpack.core.inference.action.InferenceAction; | ||
| import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy; | ||
| import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; | ||
| import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; | ||
| import org.elasticsearch.xpack.inference.registry.ModelRegistry; | ||
|
|
||
| import java.io.IOException; | ||
|
|
||
| import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN; | ||
| import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; | ||
|
|
||
| public class TransportInferenceActionProxy extends HandledTransportAction<InferenceActionProxy.Request, InferenceAction.Response> { | ||
| private final ModelRegistry modelRegistry; | ||
| private final Client client; | ||
|
|
||
| @Inject | ||
| public TransportInferenceActionProxy( | ||
| TransportService transportService, | ||
| ActionFilters actionFilters, | ||
| ModelRegistry modelRegistry, | ||
| Client client | ||
| ) { | ||
| super( | ||
| InferenceActionProxy.NAME, | ||
| transportService, | ||
| actionFilters, | ||
| InferenceActionProxy.Request::new, | ||
| EsExecutors.DIRECT_EXECUTOR_SERVICE | ||
| ); | ||
|
|
||
| this.modelRegistry = modelRegistry; | ||
| this.client = client; | ||
| } | ||
|
|
||
| @Override | ||
| protected void doExecute(Task task, InferenceActionProxy.Request request, ActionListener<InferenceAction.Response> listener) { | ||
| try { | ||
| ActionListener<UnparsedModel> getModelListener = listener.delegateFailureAndWrap((l, unparsedModel) -> { | ||
| if (unparsedModel.taskType() == TaskType.CHAT_COMPLETION) { | ||
| sendUnifiedCompletionRequest(request, l); | ||
| } else { | ||
| sendInferenceActionRequest(request, l); | ||
| } | ||
| }); | ||
|
|
||
| if (request.getTaskType() == TaskType.ANY) { | ||
| modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener); | ||
| } else if (request.getTaskType() == TaskType.CHAT_COMPLETION) { | ||
| sendUnifiedCompletionRequest(request, listener); | ||
| } else { | ||
| sendInferenceActionRequest(request, listener); | ||
| } | ||
| } catch (Exception e) { | ||
| listener.onFailure(e); | ||
| } | ||
| } | ||
|
|
||
| private void sendUnifiedCompletionRequest(InferenceActionProxy.Request request, ActionListener<InferenceAction.Response> listener) { | ||
| // format any validation exceptions from the rest -> transport path as UnifiedChatCompletionException | ||
| var unifiedErrorFormatListener = listener.delegateResponse((l, e) -> l.onFailure(UnifiedChatCompletionException.fromThrowable(e))); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice |
||
|
|
||
| try { | ||
| if (request.isStreaming() == false) { | ||
| throw new ElasticsearchStatusException( | ||
| "The [chat_completion] task type only supports streaming, please try again with the _stream API", | ||
| RestStatus.BAD_REQUEST | ||
| ); | ||
| } | ||
|
|
||
| UnifiedCompletionAction.Request unifiedRequest; | ||
| try ( | ||
| var parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, request.getContent(), request.getContentType()) | ||
| ) { | ||
| unifiedRequest = UnifiedCompletionAction.Request.parseRequest( | ||
| request.getInferenceEntityId(), | ||
| request.getTaskType(), | ||
| request.getTimeout(), | ||
| parser | ||
| ); | ||
| } | ||
|
|
||
| executeAsyncWithOrigin(client, INFERENCE_ORIGIN, UnifiedCompletionAction.INSTANCE, unifiedRequest, unifiedErrorFormatListener); | ||
| } catch (Exception e) { | ||
| unifiedErrorFormatListener.onFailure(e); | ||
| } | ||
| } | ||
|
|
||
| private void sendInferenceActionRequest(InferenceActionProxy.Request request, ActionListener<InferenceAction.Response> listener) | ||
| throws IOException { | ||
| InferenceAction.Request.Builder inferenceActionRequestBuilder; | ||
| try (var parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, request.getContent(), request.getContentType())) { | ||
| inferenceActionRequestBuilder = InferenceAction.Request.parseRequest( | ||
| request.getInferenceEntityId(), | ||
| request.getTaskType(), | ||
| parser | ||
| ); | ||
| inferenceActionRequestBuilder.setInferenceTimeout(request.getTimeout()).setStream(request.isStreaming()); | ||
| } | ||
|
|
||
| executeAsyncWithOrigin(client, INFERENCE_ORIGIN, InferenceAction.INSTANCE, inferenceActionRequestBuilder.build(), listener); | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@davidkyle just wanted to confirm that this is what we want to do here right? Changing it to internal?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
++ yes internal is good