From e99a781cde456eab5f0febfc0e4b6108ecad94b7 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 5 Feb 2025 13:40:01 -0500 Subject: [PATCH 1/7] Adding proxy action --- .../inference/action/InferenceAction.java | 2 +- .../action/InferenceActionProxy.java | 132 ++++++++++++ .../action/UnifiedCompletionAction.java | 2 +- .../authz/store/ReservedRolesStoreTests.java | 1 - .../xpack/inference/InferencePlugin.java | 7 +- .../action/TransportInferenceActionProxy.java | 117 +++++++++++ .../inference/rest/BaseInferenceAction.java | 24 ++- .../xpack/inference/rest/Paths.java | 15 +- .../inference/rest/RestInferenceAction.java | 5 + .../rest/RestStreamInferenceAction.java | 44 +--- .../RestUnifiedCompletionInferenceAction.java | 67 ------ .../inference/services/ServiceUtils.java | 4 +- .../validation/ModelValidatorBuilder.java | 2 +- .../TransportInferenceActionProxyTests.java | 191 ++++++++++++++++++ .../rest/BaseInferenceActionTests.java | 18 +- ...ava => RestInferenceActionProxyTests.java} | 8 +- .../rest/RestStreamInferenceActionTests.java | 57 +++++- ...UnifiedCompletionInferenceActionTests.java | 91 --------- .../elastic/ElasticInferenceServiceTests.java | 2 +- .../services/openai/OpenAiServiceTests.java | 2 +- .../xpack/security/operator/Constants.java | 3 +- 21 files changed, 552 insertions(+), 242 deletions(-) create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceActionProxy.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxyTests.java rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/{RestInferenceActionTests.java => RestInferenceActionProxyTests.java} (90%) delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java index f88909ba4208e..f2b2c563d7519 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java @@ -47,7 +47,7 @@ public class InferenceAction extends ActionType { public static final InferenceAction INSTANCE = new InferenceAction(); - public static final String NAME = "cluster:monitor/xpack/inference"; + public static final String NAME = "cluster:internal/xpack/inference"; public InferenceAction() { super(NAME); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceActionProxy.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceActionProxy.java new file mode 100644 index 0000000000000..4e5898927d1ea --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceActionProxy.java @@ -0,0 +1,132 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.action; + +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xcontent.XContentType; + +import java.io.IOException; +import java.util.Objects; + +/** + * This action is used when making a REST request to the inference API. The transport handler + * will then look at the task type in the params (or retrieve it from the persisted model if it wasn't + * included in the params) to determine where this request should be routed. If the task type is chat completion + * then it will be routed to the unified chat completion handler by creating the {@link UnifiedCompletionAction}. + * If not, it will be passed along to {@link InferenceAction}. + */ +public class InferenceActionProxy extends ActionType { + public static final InferenceActionProxy INSTANCE = new InferenceActionProxy(); + public static final String NAME = "cluster:monitor/xpack/inference"; + + public InferenceActionProxy() { + super(NAME); + } + + public static class Request extends ActionRequest { + + private final TaskType taskType; + private final String inferenceEntityId; + private final BytesReference content; + private final XContentType contentType; + private final TimeValue timeout; + private final boolean stream; + + public Request( + TaskType taskType, + String inferenceEntityId, + BytesReference content, + XContentType contentType, + TimeValue timeout, + boolean stream + ) { + this.taskType = taskType; + this.inferenceEntityId = inferenceEntityId; + this.content = content; + this.contentType = contentType; + this.timeout = timeout; + this.stream = stream; + } + + public Request(StreamInput in) throws IOException { + super(in); + this.taskType = TaskType.fromStream(in); + this.inferenceEntityId = in.readString(); + this.content = in.readBytesReference(); + this.contentType = in.readEnum(XContentType.class); + this.timeout = in.readTimeValue(); + + // streaming is not supported yet for transport traffic + this.stream = false; + } + + public TaskType getTaskType() { + return taskType; + } + + public String getInferenceEntityId() { + return inferenceEntityId; + } + + public BytesReference getContent() { + return content; + } + + public XContentType getContentType() { + return contentType; + } + + public TimeValue getTimeout() { + return timeout; + } + + public boolean isStreaming() { + return stream; + } + + @Override + public ActionRequestValidationException validate() { + // TODO confirm that we don't need any validation + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(inferenceEntityId); + taskType.writeTo(out); + out.writeBytesReference(content); + XContentHelper.writeTo(out, contentType); + out.writeTimeValue(timeout); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Request request = (Request) o; + return taskType == request.taskType + && Objects.equals(inferenceEntityId, request.inferenceEntityId) + && Objects.equals(content, request.content) + && contentType == request.contentType; + } + + @Override + public int hashCode() { + return Objects.hash(taskType, inferenceEntityId, content, contentType); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java index f5c852a0450ae..43c84ad914c2a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java @@ -21,7 +21,7 @@ public class UnifiedCompletionAction extends ActionType { public static final UnifiedCompletionAction INSTANCE = new UnifiedCompletionAction(); - public static final String NAME = "cluster:monitor/xpack/inference/unified"; + public static final String NAME = "cluster:internal/xpack/inference/unified"; public UnifiedCompletionAction() { super(NAME); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/store/ReservedRolesStoreTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/store/ReservedRolesStoreTests.java index 107953557f3ea..4dbdaf2931495 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/store/ReservedRolesStoreTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/store/ReservedRolesStoreTests.java @@ -4163,7 +4163,6 @@ public void testInferenceUserRole() { assertTrue(role.cluster().check("cluster:monitor/xpack/inference", request, authentication)); assertTrue(role.cluster().check("cluster:monitor/xpack/inference/get", request, authentication)); assertFalse(role.cluster().check("cluster:admin/xpack/inference/put", request, authentication)); - assertTrue(role.cluster().check("cluster:monitor/xpack/inference/unified", request, authentication)); assertFalse(role.cluster().check("cluster:admin/xpack/inference/delete", request, authentication)); assertTrue(role.cluster().check("cluster:monitor/xpack/ml/trained_models/deployment/infer", request, authentication)); assertFalse(role.cluster().check("cluster:admin/xpack/ml/trained_models/deployment/start", request, authentication)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 23df62caab430..e3604351c1937 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -58,6 +58,7 @@ import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; import org.elasticsearch.xpack.core.inference.action.GetInferenceServicesAction; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy; import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; import org.elasticsearch.xpack.core.inference.action.UpdateInferenceModelAction; @@ -67,6 +68,7 @@ import org.elasticsearch.xpack.inference.action.TransportGetInferenceModelAction; import org.elasticsearch.xpack.inference.action.TransportGetInferenceServicesAction; import org.elasticsearch.xpack.inference.action.TransportInferenceAction; +import org.elasticsearch.xpack.inference.action.TransportInferenceActionProxy; import org.elasticsearch.xpack.inference.action.TransportInferenceUsageAction; import org.elasticsearch.xpack.inference.action.TransportPutInferenceModelAction; import org.elasticsearch.xpack.inference.action.TransportUnifiedCompletionInferenceAction; @@ -104,7 +106,6 @@ import org.elasticsearch.xpack.inference.rest.RestInferenceAction; import org.elasticsearch.xpack.inference.rest.RestPutInferenceModelAction; import org.elasticsearch.xpack.inference.rest.RestStreamInferenceAction; -import org.elasticsearch.xpack.inference.rest.RestUnifiedCompletionInferenceAction; import org.elasticsearch.xpack.inference.rest.RestUpdateInferenceModelAction; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchService; @@ -195,6 +196,7 @@ public InferencePlugin(Settings settings) { public List> getActions() { return List.of( new ActionHandler<>(InferenceAction.INSTANCE, TransportInferenceAction.class), + new ActionHandler<>(InferenceActionProxy.INSTANCE, TransportInferenceActionProxy.class), new ActionHandler<>(GetInferenceModelAction.INSTANCE, TransportGetInferenceModelAction.class), new ActionHandler<>(PutInferenceModelAction.INSTANCE, TransportPutInferenceModelAction.class), new ActionHandler<>(UpdateInferenceModelAction.INSTANCE, TransportUpdateInferenceModelAction.class), @@ -226,8 +228,7 @@ public List getRestHandlers( new RestUpdateInferenceModelAction(), new RestDeleteInferenceEndpointAction(), new RestGetInferenceDiagnosticsAction(), - new RestGetInferenceServicesAction(), - new RestUnifiedCompletionInferenceAction(threadPoolSetOnce) + new RestGetInferenceServicesAction() ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java new file mode 100644 index 0000000000000..2102239d29d3e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java @@ -0,0 +1,117 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.action; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.injection.guice.Inject; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy; +import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; + +import java.io.IOException; + +import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN; +import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; + +public class TransportInferenceActionProxy extends HandledTransportAction { + private final ModelRegistry modelRegistry; + private final Client client; + + @Inject + public TransportInferenceActionProxy( + TransportService transportService, + ActionFilters actionFilters, + ModelRegistry modelRegistry, + Client client + ) { + super( + InferenceActionProxy.NAME, + transportService, + actionFilters, + InferenceActionProxy.Request::new, + EsExecutors.DIRECT_EXECUTOR_SERVICE + ); + + this.modelRegistry = modelRegistry; + this.client = client; + } + + @Override + protected void doExecute(Task task, InferenceActionProxy.Request request, ActionListener listener) { + try { + ActionListener getModelListener = listener.delegateFailureAndWrap((l, unparsedModel) -> { + if (unparsedModel.taskType() == TaskType.CHAT_COMPLETION) { + sendUnifiedCompletionRequest(request, l); + } else { + sendInferenceActionRequest(request, l); + } + }); + + if (request.getTaskType() == TaskType.ANY) { + modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener); + } else if (request.getTaskType() == TaskType.CHAT_COMPLETION) { + sendUnifiedCompletionRequest(request, listener); + } else { + sendInferenceActionRequest(request, listener); + } + } catch (Exception e) { + listener.onFailure(e); + } + } + + private void sendUnifiedCompletionRequest(InferenceActionProxy.Request request, ActionListener listener) + throws IOException { + + if (request.isStreaming() == false) { + throw new ElasticsearchStatusException( + "The [chat_completion] task type only supports streaming, please try again with the _stream API", + RestStatus.BAD_REQUEST + ); + } + + UnifiedCompletionAction.Request unifiedRequest; + try (var parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, request.getContent(), request.getContentType())) { + unifiedRequest = UnifiedCompletionAction.Request.parseRequest( + request.getInferenceEntityId(), + request.getTaskType(), + request.getTimeout(), + parser + ); + } + + executeAsyncWithOrigin(client, INFERENCE_ORIGIN, UnifiedCompletionAction.INSTANCE, unifiedRequest, listener); + } + + private void sendInferenceActionRequest(InferenceActionProxy.Request request, ActionListener listener) + throws IOException { + InferenceAction.Request.Builder inferenceActionRequestBuilder; + try (var parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, request.getContent(), request.getContentType())) { + inferenceActionRequestBuilder = InferenceAction.Request.parseRequest( + request.getInferenceEntityId(), + request.getTaskType(), + parser + ); + inferenceActionRequestBuilder.setInferenceTimeout(request.getTimeout()).setStream(request.isStreaming()); + } + + executeAsyncWithOrigin(client, INFERENCE_ORIGIN, InferenceAction.INSTANCE, inferenceActionRequestBuilder.build(), listener); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java index d911158e82296..569156fb28020 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java @@ -15,6 +15,7 @@ import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy; import java.io.IOException; @@ -41,21 +42,22 @@ static TimeValue parseTimeout(RestRequest restRequest) { @Override protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { var params = parseParams(restRequest); + var content = restRequest.requiredContent(); + var inferTimeout = parseTimeout(restRequest); - InferenceAction.Request.Builder requestBuilder; - try (var parser = restRequest.contentParser()) { - requestBuilder = InferenceAction.Request.parseRequest(params.inferenceEntityId(), params.taskType(), parser); - } + var request = new InferenceActionProxy.Request( + params.taskType(), + params.inferenceEntityId(), + content, + restRequest.getXContentType(), + inferTimeout, + shouldStream() + ); - var inferTimeout = parseTimeout(restRequest); - requestBuilder.setInferenceTimeout(inferTimeout); - var request = prepareInferenceRequest(requestBuilder); - return channel -> client.execute(InferenceAction.INSTANCE, request, listener(channel)); + return channel -> client.execute(InferenceActionProxy.INSTANCE, request, listener(channel)); } - protected InferenceAction.Request prepareInferenceRequest(InferenceAction.Request.Builder builder) { - return builder.build(); - } + protected abstract boolean shouldStream(); protected abstract ActionListener listener(RestChannel channel); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java index 7f43676dfb5f0..d81df04eb9eef 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java @@ -24,22 +24,13 @@ public final class Paths { static final String INFERENCE_SERVICES_PATH = "_inference/_services"; static final String TASK_TYPE_INFERENCE_SERVICES_PATH = "_inference/_services/{" + TASK_TYPE + "}"; - static final String STREAM_INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}/_stream"; + public static final String STREAM_SUFFIX = "_stream"; + static final String STREAM_INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}/" + STREAM_SUFFIX; static final String STREAM_TASK_TYPE_INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}/{" + INFERENCE_ID - + "}/_stream"; - - // TODO remove the _unified path - public static final String UNIFIED_SUFFIX = "_unified"; - static final String UNIFIED_INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}/" + UNIFIED_SUFFIX; - static final String UNIFIED_TASK_TYPE_INFERENCE_ID_PATH = "_inference/{" - + TASK_TYPE_OR_INFERENCE_ID - + "}/{" - + INFERENCE_ID - + "}/" - + UNIFIED_SUFFIX; + + "}/" + STREAM_SUFFIX; private Paths() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java index 0fbc2f8214cbb..55083dcd4c888 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java @@ -32,6 +32,11 @@ public List routes() { return List.of(new Route(POST, INFERENCE_ID_PATH), new Route(POST, TASK_TYPE_INFERENCE_ID_PATH)); } + @Override + protected boolean shouldStream() { + return false; + } + @Override protected ActionListener listener(RestChannel channel) { return new RestChunkedToXContentListener<>(channel); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java index 518056365d88b..9fbedcda889bf 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java @@ -9,17 +9,13 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.client.internal.node.NodeClient; -import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestChannel; -import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.Scope; import org.elasticsearch.rest.ServerlessScope; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; -import java.io.IOException; import java.util.List; import java.util.Objects; @@ -46,41 +42,17 @@ public List routes() { return List.of(new Route(POST, STREAM_INFERENCE_ID_PATH), new Route(POST, STREAM_TASK_TYPE_INFERENCE_ID_PATH)); } - @Override - protected InferenceAction.Request prepareInferenceRequest(InferenceAction.Request.Builder builder) { - return builder.setStream(true).build(); - } - @Override protected ActionListener listener(RestChannel channel) { - return new ServerSentEventsRestActionListener(channel, threadPool); + // TODO confirm with Pat that this will work + return new ServerSentEventsRestActionListener(channel, threadPool).delegateResponse((l, e) -> { + // format any validation exceptions from the rest -> transport path as UnifiedChatCompletionException + l.onFailure(UnifiedChatCompletionException.fromThrowable(e)); + }); } @Override - protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { - var params = parseParams(restRequest); - var inferTimeout = parseTimeout(restRequest); - - if (params.taskType() == TaskType.CHAT_COMPLETION) { - UnifiedCompletionAction.Request request; - try (var parser = restRequest.contentParser()) { - request = UnifiedCompletionAction.Request.parseRequest(params.inferenceEntityId(), params.taskType(), inferTimeout, parser); - } - - return channel -> client.execute( - UnifiedCompletionAction.INSTANCE, - request, - new ServerSentEventsRestActionListener(channel, threadPool) - ); - } else { - InferenceAction.Request.Builder requestBuilder; - try (var parser = restRequest.contentParser()) { - requestBuilder = InferenceAction.Request.parseRequest(params.inferenceEntityId(), params.taskType(), parser); - } - - requestBuilder.setInferenceTimeout(inferTimeout); - var request = prepareInferenceRequest(requestBuilder); - return channel -> client.execute(InferenceAction.INSTANCE, request, listener(channel)); - } + protected boolean shouldStream() { + return true; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java deleted file mode 100644 index 0efd31a6832cc..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.rest; - -import org.apache.lucene.util.SetOnce; -import org.elasticsearch.client.internal.node.NodeClient; -import org.elasticsearch.rest.BaseRestHandler; -import org.elasticsearch.rest.RestRequest; -import org.elasticsearch.rest.Scope; -import org.elasticsearch.rest.ServerlessScope; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; -import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; - -import java.io.IOException; -import java.util.List; -import java.util.Objects; - -import static org.elasticsearch.rest.RestRequest.Method.POST; -import static org.elasticsearch.xpack.inference.rest.Paths.UNIFIED_INFERENCE_ID_PATH; -import static org.elasticsearch.xpack.inference.rest.Paths.UNIFIED_TASK_TYPE_INFERENCE_ID_PATH; - -@ServerlessScope(Scope.PUBLIC) -public class RestUnifiedCompletionInferenceAction extends BaseRestHandler { - private final SetOnce threadPool; - - public RestUnifiedCompletionInferenceAction(SetOnce threadPool) { - super(); - this.threadPool = Objects.requireNonNull(threadPool); - } - - @Override - public String getName() { - return "unified_inference_action"; - } - - @Override - public List routes() { - return List.of(new Route(POST, UNIFIED_INFERENCE_ID_PATH), new Route(POST, UNIFIED_TASK_TYPE_INFERENCE_ID_PATH)); - } - - @Override - protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { - var params = BaseInferenceAction.parseParams(restRequest); - - var inferTimeout = BaseInferenceAction.parseTimeout(restRequest); - - UnifiedCompletionAction.Request request; - try (var parser = restRequest.contentParser()) { - request = UnifiedCompletionAction.Request.parseRequest(params.inferenceEntityId(), params.taskType(), inferTimeout, parser); - } - - return channel -> client.execute( - UnifiedCompletionAction.INSTANCE, - request, - new ServerSentEventsRestActionListener(channel, threadPool).delegateResponse((l, e) -> { - // format any validation exceptions from the rest -> transport path as UnifiedChatCompletionException - l.onFailure(UnifiedChatCompletionException.fromThrowable(e)); - }) - ); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 1ddae3cc8df95..13d641101a1cf 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -42,7 +42,7 @@ import static org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings.ENABLED; import static org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings.MAX_NUMBER_OF_ALLOCATIONS; import static org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings.MIN_NUMBER_OF_ALLOCATIONS; -import static org.elasticsearch.xpack.inference.rest.Paths.UNIFIED_SUFFIX; +import static org.elasticsearch.xpack.inference.rest.Paths.STREAM_SUFFIX; import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; public final class ServiceUtils { @@ -796,7 +796,7 @@ public static String useChatCompletionUrlMessage(Model model) { model.getTaskType(), model.getTaskType(), model.getInferenceEntityId(), - UNIFIED_SUFFIX + STREAM_SUFFIX ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java index e4de3d6beb800..1c4306c4edd46 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java @@ -29,7 +29,7 @@ public static ModelValidator buildModelValidator(TaskType taskType) { case SPARSE_EMBEDDING, RERANK, ANY -> { return new SimpleModelValidator(new SimpleServiceIntegrationValidator()); } - default -> throw new IllegalArgumentException(Strings.format("Can't validate inference model of for task type %s ", taskType)); + default -> throw new IllegalArgumentException(Strings.format("Can't validate inference model for task type %s", taskType)); } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxyTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxyTests.java new file mode 100644 index 0000000000000..a9e6ec55a6224 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxyTests.java @@ -0,0 +1,191 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.action; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy; +import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.junit.After; +import org.junit.Before; + +import java.util.Collections; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class TransportInferenceActionProxyTests extends ESTestCase { + private Client client; + private ThreadPool threadPool; + private TransportInferenceActionProxy action; + private ModelRegistry modelRegistry; + + @Before + public void setUp() throws Exception { + super.setUp(); + client = mock(Client.class); + threadPool = new TestThreadPool("test"); + when(client.threadPool()).thenReturn(threadPool); + modelRegistry = mock(ModelRegistry.class); + + action = new TransportInferenceActionProxy(mock(TransportService.class), mock(ActionFilters.class), modelRegistry, client); + } + + @After + public void tearDown() throws Exception { + super.tearDown(); + terminate(threadPool); + } + + public void testExecutesAUnifiedCompletionRequest_WhenTaskTypeIsChatCompletion_InRequest() { + String requestJson = """ + { + "model": "gpt-4o", + "messages": [ + { + "role": "user", + "content": [ + { + "text": "some text", + "type": "string" + } + ] + } + ] + } + """; + + @SuppressWarnings("unchecked") + ActionListener listener = (ActionListener) mock(ActionListener.class); + var request = new InferenceActionProxy.Request( + TaskType.CHAT_COMPLETION, + "id", + new BytesArray(requestJson), + XContentType.JSON, + TimeValue.ONE_MINUTE, + true + ); + + action.doExecute(mock(Task.class), request, listener); + + verify(client, times(1)).execute(eq(UnifiedCompletionAction.INSTANCE), any(), any()); + } + + public void testExecutesAUnifiedCompletionRequest_WhenTaskTypeIsChatCompletion_FromStorage() { + String requestJson = """ + { + "model": "gpt-4o", + "messages": [ + { + "role": "user", + "content": [ + { + "text": "some text", + "type": "string" + } + ] + } + ] + } + """; + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse( + new UnparsedModel("id", TaskType.CHAT_COMPLETION, "service", Collections.emptyMap(), Collections.emptyMap()) + ); + + return Void.TYPE; + }).when(modelRegistry).getModelWithSecrets(any(), any()); + + var listener = new PlainActionFuture(); + var request = new InferenceActionProxy.Request( + TaskType.ANY, + "id", + new BytesArray(requestJson), + XContentType.JSON, + TimeValue.ONE_MINUTE, + true + ); + + action.doExecute(mock(Task.class), request, listener); + + verify(client, times(1)).execute(eq(UnifiedCompletionAction.INSTANCE), any(), any()); + } + + public void testExecutesAnInferenceAction_WhenTaskTypeIsCompletion_InRequest() { + String requestJson = """ + { + "input": ["some text"] + } + """; + + @SuppressWarnings("unchecked") + ActionListener listener = (ActionListener) mock(ActionListener.class); + var request = new InferenceActionProxy.Request( + TaskType.COMPLETION, + "id", + new BytesArray(requestJson), + XContentType.JSON, + TimeValue.ONE_MINUTE, + true + ); + + action.doExecute(mock(Task.class), request, listener); + + verify(client, times(1)).execute(eq(InferenceAction.INSTANCE), any(), any()); + } + + public void testExecutesAnInferenceAction_WhenTaskTypeIsCompletion_FromStorage() { + String requestJson = """ + { + "input": ["some text"] + } + """; + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(new UnparsedModel("id", TaskType.COMPLETION, "service", Collections.emptyMap(), Collections.emptyMap())); + + return Void.TYPE; + }).when(modelRegistry).getModelWithSecrets(any(), any()); + + var listener = new PlainActionFuture(); + var request = new InferenceActionProxy.Request( + TaskType.ANY, + "id", + new BytesArray(requestJson), + XContentType.JSON, + TimeValue.ONE_MINUTE, + true + ); + + action.doExecute(mock(Task.class), request, listener); + + verify(client, times(1)).execute(eq(InferenceAction.INSTANCE), any(), any()); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java index 9e6718172a2f0..61d28fae402ba 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java @@ -21,6 +21,7 @@ import org.elasticsearch.test.rest.RestActionTestCase; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy; import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults; import org.junit.Before; @@ -43,6 +44,11 @@ public class BaseInferenceActionTests extends RestActionTestCase { @Before public void setUpAction() { controller().registerHandler(new BaseInferenceAction() { + @Override + protected boolean shouldStream() { + return false; + } + @Override protected ActionListener listener(RestChannel channel) { return new RestChunkedToXContentListener<>(channel); @@ -103,10 +109,10 @@ public void testParseTimeout_ReturnsDefaultTimeout() { public void testUsesDefaultTimeout() { SetOnce executeCalled = new SetOnce<>(); verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> { - assertThat(actionRequest, instanceOf(InferenceAction.Request.class)); + assertThat(actionRequest, instanceOf(InferenceActionProxy.Request.class)); - var request = (InferenceAction.Request) actionRequest; - assertThat(request.getInferenceTimeout(), is(InferenceAction.Request.DEFAULT_TIMEOUT)); + var request = (InferenceActionProxy.Request) actionRequest; + assertThat(request.getTimeout(), is(InferenceAction.Request.DEFAULT_TIMEOUT)); executeCalled.set(true); return createResponse(); @@ -123,10 +129,10 @@ public void testUsesDefaultTimeout() { public void testUses3SecondTimeoutFromParams() { SetOnce executeCalled = new SetOnce<>(); verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> { - assertThat(actionRequest, instanceOf(InferenceAction.Request.class)); + assertThat(actionRequest, instanceOf(InferenceActionProxy.Request.class)); - var request = (InferenceAction.Request) actionRequest; - assertThat(request.getInferenceTimeout(), is(TimeValue.timeValueSeconds(3))); + var request = (InferenceActionProxy.Request) actionRequest; + assertThat(request.getTimeout(), is(TimeValue.timeValueSeconds(3))); executeCalled.set(true); return createResponse(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestInferenceActionProxyTests.java similarity index 90% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestInferenceActionTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestInferenceActionProxyTests.java index 1b0df1b4a20da..433e33fe15210 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestInferenceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestInferenceActionProxyTests.java @@ -13,7 +13,7 @@ import org.elasticsearch.test.rest.FakeRestRequest; import org.elasticsearch.test.rest.RestActionTestCase; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy; import org.junit.Before; import static org.elasticsearch.xpack.inference.rest.BaseInferenceActionTests.createResponse; @@ -21,7 +21,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; -public class RestInferenceActionTests extends RestActionTestCase { +public class RestInferenceActionProxyTests extends RestActionTestCase { @Before public void setUpAction() { @@ -31,9 +31,9 @@ public void setUpAction() { public void testStreamIsFalse() { SetOnce executeCalled = new SetOnce<>(); verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> { - assertThat(actionRequest, instanceOf(InferenceAction.Request.class)); + assertThat(actionRequest, instanceOf(InferenceActionProxy.Request.class)); - var request = (InferenceAction.Request) actionRequest; + var request = (InferenceActionProxy.Request) actionRequest; assertThat(request.isStreaming(), is(false)); executeCalled.set(true); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceActionTests.java index f67680ef6b625..e69dd3fda6240 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceActionTests.java @@ -9,13 +9,18 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.rest.AbstractRestChannel; +import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestResponse; import org.elasticsearch.test.rest.FakeRestRequest; import org.elasticsearch.test.rest.RestActionTestCase; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy; import org.junit.After; import org.junit.Before; @@ -42,9 +47,9 @@ public void tearDownAction() { public void testStreamIsTrue() { SetOnce executeCalled = new SetOnce<>(); verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> { - assertThat(actionRequest, instanceOf(InferenceAction.Request.class)); + assertThat(actionRequest, instanceOf(InferenceActionProxy.Request.class)); - var request = (InferenceAction.Request) actionRequest; + var request = (InferenceActionProxy.Request) actionRequest; assertThat(request.isStreaming(), is(true)); executeCalled.set(true); @@ -58,4 +63,50 @@ public void testStreamIsTrue() { dispatchRequest(inferenceRequest); assertThat(executeCalled.get(), equalTo(true)); } + + public void testStreamIsTrue_ChatCompletion() { + SetOnce executeCalled = new SetOnce<>(); + verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> { + assertThat(actionRequest, instanceOf(InferenceActionProxy.Request.class)); + + var request = (InferenceActionProxy.Request) actionRequest; + assertThat(request.isStreaming(), is(true)); + + executeCalled.set(true); + return createResponse(); + })); + + var requestBody = """ + { + "messages": [ + { + "content": "abc", + "role": "user" + } + ] + } + """; + + RestRequest inferenceRequest = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath("_inference/chat_completion/test/_stream") + .withContent(new BytesArray(requestBody), XContentType.JSON) + .build(); + + final SetOnce responseSetOnce = new SetOnce<>(); + dispatchRequest(inferenceRequest, new AbstractRestChannel(inferenceRequest, true) { + @Override + public void sendResponse(RestResponse response) { + responseSetOnce.set(response); + } + }); + + // the response content will be null when there is no error + assertNull(responseSetOnce.get().content()); + assertThat(executeCalled.get(), equalTo(true)); + } + + private void dispatchRequest(final RestRequest request, final RestChannel channel) { + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + controller().dispatchRequest(request, channel, threadContext); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java deleted file mode 100644 index 9dc23c890c14d..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.rest; - -import org.apache.lucene.util.SetOnce; -import org.elasticsearch.common.bytes.BytesArray; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.rest.AbstractRestChannel; -import org.elasticsearch.rest.RestChannel; -import org.elasticsearch.rest.RestRequest; -import org.elasticsearch.rest.RestResponse; -import org.elasticsearch.test.rest.FakeRestRequest; -import org.elasticsearch.test.rest.RestActionTestCase; -import org.elasticsearch.threadpool.TestThreadPool; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; -import org.junit.After; -import org.junit.Before; - -import static org.elasticsearch.xpack.inference.rest.BaseInferenceActionTests.createResponse; -import static org.hamcrest.CoreMatchers.is; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.instanceOf; - -public class RestUnifiedCompletionInferenceActionTests extends RestActionTestCase { - private final SetOnce threadPool = new SetOnce<>(); - - @Before - public void setUpAction() { - threadPool.set(new TestThreadPool(getTestName())); - controller().registerHandler(new RestUnifiedCompletionInferenceAction(threadPool)); - } - - @After - public void tearDownAction() { - terminate(threadPool.get()); - } - - public void testStreamIsTrue() { - SetOnce executeCalled = new SetOnce<>(); - verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> { - assertThat(actionRequest, instanceOf(UnifiedCompletionAction.Request.class)); - - var request = (UnifiedCompletionAction.Request) actionRequest; - assertThat(request.isStreaming(), is(true)); - - executeCalled.set(true); - return createResponse(); - })); - - var requestBody = """ - { - "messages": [ - { - "content": "abc", - "role": "user" - } - ] - } - """; - - RestRequest inferenceRequest = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) - .withPath("_inference/completion/test/_unified") - .withContent(new BytesArray(requestBody), XContentType.JSON) - .build(); - - final SetOnce responseSetOnce = new SetOnce<>(); - dispatchRequest(inferenceRequest, new AbstractRestChannel(inferenceRequest, true) { - @Override - public void sendResponse(RestResponse response) { - responseSetOnce.set(response); - } - }); - - // the response content will be null when there is no error - assertNull(responseSetOnce.get().content()); - assertThat(executeCalled.get(), equalTo(true)); - } - - private void dispatchRequest(final RestRequest request, final RestChannel channel) { - ThreadContext threadContext = new ThreadContext(Settings.EMPTY); - controller().dispatchRequest(request, channel, threadContext); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 5d66486731f5e..743a3fb666ecd 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -436,7 +436,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws "Inference entity [model_id] does not support task type [chat_completion] " + "for inference, the task type must be one of [sparse_embedding]. " + "The task type for the inference entity is chat_completion, " - + "please use the _inference/chat_completion/model_id/_unified URL." + + "please use the _inference/chat_completion/model_id/_stream URL." ) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index 34539042c1f0b..687f5430904e4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -937,7 +937,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws "Inference entity [model_id] does not support task type [chat_completion] " + "for inference, the task type must be one of [text_embedding, completion]. " + "The task type for the inference entity is chat_completion, " - + "please use the _inference/chat_completion/model_id/_unified URL." + + "please use the _inference/chat_completion/model_id/_stream URL." ) ); diff --git a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java index 5d3b13b9d451a..59310159b402d 100644 --- a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java +++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java @@ -327,6 +327,8 @@ public class Constants { "cluster:admin/xpack/watcher/settings/update", "cluster:admin/xpack/watcher/watch/put", "cluster:internal/remote_cluster/nodes", + "cluster:internal/xpack/inference", + "cluster:internal/xpack/inference/unified", "cluster:internal/xpack/ml/coordinatedinference", "cluster:internal/xpack/ml/datafeed/isolate", "cluster:internal/xpack/ml/datafeed/running_state", @@ -388,7 +390,6 @@ public class Constants { "cluster:monitor/xpack/esql/stats/dist", "cluster:monitor/xpack/inference", "cluster:monitor/xpack/inference/get", - "cluster:monitor/xpack/inference/unified", "cluster:monitor/xpack/inference/diagnostics/get", "cluster:monitor/xpack/inference/services/get", "cluster:monitor/xpack/info", From 0c50d84d56d7a3289c263f840546820a8d9de78d Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Wed, 5 Feb 2025 18:49:31 +0000 Subject: [PATCH 2/7] [CI] Auto commit changes from spotless --- .../java/org/elasticsearch/xpack/inference/rest/Paths.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java index d81df04eb9eef..b1edec79dfb72 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java @@ -30,7 +30,8 @@ public final class Paths { + TASK_TYPE_OR_INFERENCE_ID + "}/{" + INFERENCE_ID - + "}/" + STREAM_SUFFIX; + + "}/" + + STREAM_SUFFIX; private Paths() { From 7d1012397ca22c7bf1d02cd914a25218e610e6a0 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 5 Feb 2025 15:59:46 -0500 Subject: [PATCH 3/7] Incrementing reference count for body content and fixing tests --- .../inference/InferenceBaseRestTest.java | 3 +- .../action/TransportInferenceActionProxy.java | 43 +++++++++++-------- .../inference/rest/BaseInferenceAction.java | 2 +- .../rest/RestStreamInferenceAction.java | 7 +-- 4 files changed, 28 insertions(+), 27 deletions(-) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java index 950ff196e5136..36a4b95a7ca23 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -360,8 +360,7 @@ protected Deque unifiedCompletionInferOnMockService( List input, @Nullable Consumer responseConsumerCallback ) throws Exception { - var route = randomBoolean() ? "_stream" : "_unified"; // TODO remove unified route - var endpoint = Strings.format("_inference/%s/%s/%s", taskType, modelId, route); + var endpoint = Strings.format("_inference/%s/%s/_stream", taskType, modelId); return callAsyncUnified(endpoint, input, "user", responseConsumerCallback); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java index 2102239d29d3e..eee9778746560 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java @@ -24,6 +24,7 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy; import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import java.io.IOException; @@ -77,27 +78,33 @@ protected void doExecute(Task task, InferenceActionProxy.Request request, Action } } - private void sendUnifiedCompletionRequest(InferenceActionProxy.Request request, ActionListener listener) - throws IOException { + private void sendUnifiedCompletionRequest(InferenceActionProxy.Request request, ActionListener listener) { + // format any validation exceptions from the rest -> transport path as UnifiedChatCompletionException + var unifiedErrorFormatListener = listener.delegateResponse((l, e) -> l.onFailure(UnifiedChatCompletionException.fromThrowable(e))); - if (request.isStreaming() == false) { - throw new ElasticsearchStatusException( - "The [chat_completion] task type only supports streaming, please try again with the _stream API", - RestStatus.BAD_REQUEST - ); - } + try { + if (request.isStreaming() == false) { + unifiedErrorFormatListener.onFailure(new ElasticsearchStatusException( + "The [chat_completion] task type only supports streaming, please try again with the _stream API", + RestStatus.BAD_REQUEST + )); + return; + } - UnifiedCompletionAction.Request unifiedRequest; - try (var parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, request.getContent(), request.getContentType())) { - unifiedRequest = UnifiedCompletionAction.Request.parseRequest( - request.getInferenceEntityId(), - request.getTaskType(), - request.getTimeout(), - parser - ); - } + UnifiedCompletionAction.Request unifiedRequest; + try (var parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, request.getContent(), request.getContentType())) { + unifiedRequest = UnifiedCompletionAction.Request.parseRequest( + request.getInferenceEntityId(), + request.getTaskType(), + request.getTimeout(), + parser + ); + } - executeAsyncWithOrigin(client, INFERENCE_ORIGIN, UnifiedCompletionAction.INSTANCE, unifiedRequest, listener); + executeAsyncWithOrigin(client, INFERENCE_ORIGIN, UnifiedCompletionAction.INSTANCE, unifiedRequest, unifiedErrorFormatListener); + } catch (Exception e) { + unifiedErrorFormatListener.onFailure(e); + } } private void sendInferenceActionRequest(InferenceActionProxy.Request request, ActionListener listener) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java index 569156fb28020..06a0849b91d4e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java @@ -54,7 +54,7 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient shouldStream() ); - return channel -> client.execute(InferenceActionProxy.INSTANCE, request, listener(channel)); + return channel -> client.execute(InferenceActionProxy.INSTANCE, request, ActionListener.withRef(listener(channel), content)); } protected abstract boolean shouldStream(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java index 9fbedcda889bf..f37f4e9fb1f9b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java @@ -14,7 +14,6 @@ import org.elasticsearch.rest.ServerlessScope; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import java.util.List; import java.util.Objects; @@ -44,11 +43,7 @@ public List routes() { @Override protected ActionListener listener(RestChannel channel) { - // TODO confirm with Pat that this will work - return new ServerSentEventsRestActionListener(channel, threadPool).delegateResponse((l, e) -> { - // format any validation exceptions from the rest -> transport path as UnifiedChatCompletionException - l.onFailure(UnifiedChatCompletionException.fromThrowable(e)); - }); + return new ServerSentEventsRestActionListener(channel, threadPool); } @Override From ab5cc0a4a76d60d025c775ceee662d2654b07bfe Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Wed, 5 Feb 2025 21:05:51 +0000 Subject: [PATCH 4/7] [CI] Auto commit changes from spotless --- .../action/TransportInferenceActionProxy.java | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java index eee9778746560..fa84f1d58a9fb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java @@ -84,15 +84,19 @@ private void sendUnifiedCompletionRequest(InferenceActionProxy.Request request, try { if (request.isStreaming() == false) { - unifiedErrorFormatListener.onFailure(new ElasticsearchStatusException( - "The [chat_completion] task type only supports streaming, please try again with the _stream API", - RestStatus.BAD_REQUEST - )); + unifiedErrorFormatListener.onFailure( + new ElasticsearchStatusException( + "The [chat_completion] task type only supports streaming, please try again with the _stream API", + RestStatus.BAD_REQUEST + ) + ); return; } UnifiedCompletionAction.Request unifiedRequest; - try (var parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, request.getContent(), request.getContentType())) { + try ( + var parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, request.getContent(), request.getContentType()) + ) { unifiedRequest = UnifiedCompletionAction.Request.parseRequest( request.getInferenceEntityId(), request.getTaskType(), From 456ea680b543fec1ee8b1aee1e7ff453a31935ac Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 6 Feb 2025 08:16:55 -0500 Subject: [PATCH 5/7] Refactoring --- .../inference/action/TransportInferenceActionProxy.java | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java index eee9778746560..b38996df3861c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java @@ -84,11 +84,10 @@ private void sendUnifiedCompletionRequest(InferenceActionProxy.Request request, try { if (request.isStreaming() == false) { - unifiedErrorFormatListener.onFailure(new ElasticsearchStatusException( + throw new ElasticsearchStatusException( "The [chat_completion] task type only supports streaming, please try again with the _stream API", RestStatus.BAD_REQUEST - )); - return; + ); } UnifiedCompletionAction.Request unifiedRequest; From 246b646efa47e1f660fe24fedd18de4d362e7cfb Mon Sep 17 00:00:00 2001 From: Jonathan Buttner <56361221+jonathan-buttner@users.noreply.github.com> Date: Thu, 6 Feb 2025 16:12:33 -0500 Subject: [PATCH 6/7] Update x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceActionProxy.java Co-authored-by: David Kyle --- .../xpack/core/inference/action/InferenceActionProxy.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceActionProxy.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceActionProxy.java index 4e5898927d1ea..28ebca6375f6d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceActionProxy.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceActionProxy.java @@ -126,7 +126,7 @@ public boolean equals(Object o) { @Override public int hashCode() { - return Objects.hash(taskType, inferenceEntityId, content, contentType); + return Objects.hash(taskType, inferenceEntityId, content, contentType, timeout, stream); } } } From edf0d2259cd405907b4236a685ff229a4c7e0087 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 6 Feb 2025 16:40:09 -0500 Subject: [PATCH 7/7] Addressing feedback --- .../xpack/core/inference/action/InferenceActionProxy.java | 7 ++++--- .../core/security/authz/store/ReservedRolesStoreTests.java | 4 ++-- .../elasticsearch/xpack/security/operator/Constants.java | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceActionProxy.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceActionProxy.java index 28ebca6375f6d..68cd39f26b456 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceActionProxy.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceActionProxy.java @@ -30,7 +30,7 @@ */ public class InferenceActionProxy extends ActionType { public static final InferenceActionProxy INSTANCE = new InferenceActionProxy(); - public static final String NAME = "cluster:monitor/xpack/inference"; + public static final String NAME = "cluster:monitor/xpack/inference/post"; public InferenceActionProxy() { super(NAME); @@ -99,7 +99,6 @@ public boolean isStreaming() { @Override public ActionRequestValidationException validate() { - // TODO confirm that we don't need any validation return null; } @@ -121,7 +120,9 @@ public boolean equals(Object o) { return taskType == request.taskType && Objects.equals(inferenceEntityId, request.inferenceEntityId) && Objects.equals(content, request.content) - && contentType == request.contentType; + && contentType == request.contentType + && timeout == request.timeout + && stream == request.stream; } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/store/ReservedRolesStoreTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/store/ReservedRolesStoreTests.java index 4dbdaf2931495..ab5a1bbd8c897 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/store/ReservedRolesStoreTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/store/ReservedRolesStoreTests.java @@ -4140,7 +4140,7 @@ public void testInferenceAdminRole() { assertThat(roleDescriptor.getMetadata(), hasEntry("_reserved", true)); Role role = Role.buildFromRoleDescriptor(roleDescriptor, new FieldPermissionsCache(Settings.EMPTY), RESTRICTED_INDICES); - assertTrue(role.cluster().check("cluster:monitor/xpack/inference", request, authentication)); + assertTrue(role.cluster().check("cluster:monitor/xpack/inference/post", request, authentication)); assertTrue(role.cluster().check("cluster:monitor/xpack/inference/get", request, authentication)); assertTrue(role.cluster().check("cluster:admin/xpack/inference/put", request, authentication)); assertTrue(role.cluster().check("cluster:admin/xpack/inference/delete", request, authentication)); @@ -4160,7 +4160,7 @@ public void testInferenceUserRole() { assertThat(roleDescriptor.getMetadata(), hasEntry("_reserved", true)); Role role = Role.buildFromRoleDescriptor(roleDescriptor, new FieldPermissionsCache(Settings.EMPTY), RESTRICTED_INDICES); - assertTrue(role.cluster().check("cluster:monitor/xpack/inference", request, authentication)); + assertTrue(role.cluster().check("cluster:monitor/xpack/inference/post", request, authentication)); assertTrue(role.cluster().check("cluster:monitor/xpack/inference/get", request, authentication)); assertFalse(role.cluster().check("cluster:admin/xpack/inference/put", request, authentication)); assertFalse(role.cluster().check("cluster:admin/xpack/inference/delete", request, authentication)); diff --git a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java index 59310159b402d..7e571b6db2f92 100644 --- a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java +++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java @@ -388,7 +388,7 @@ public class Constants { "cluster:monitor/xpack/enrich/stats", "cluster:monitor/xpack/eql/stats/dist", "cluster:monitor/xpack/esql/stats/dist", - "cluster:monitor/xpack/inference", + "cluster:monitor/xpack/inference/post", "cluster:monitor/xpack/inference/get", "cluster:monitor/xpack/inference/diagnostics/get", "cluster:monitor/xpack/inference/services/get",