diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java index f88909ba4208e..f2b2c563d7519 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java @@ -47,7 +47,7 @@ public class InferenceAction extends ActionType { public static final InferenceAction INSTANCE = new InferenceAction(); - public static final String NAME = "cluster:monitor/xpack/inference"; + public static final String NAME = "cluster:internal/xpack/inference"; public InferenceAction() { super(NAME); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceActionProxy.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceActionProxy.java new file mode 100644 index 0000000000000..68cd39f26b456 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceActionProxy.java @@ -0,0 +1,133 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.action; + +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xcontent.XContentType; + +import java.io.IOException; +import java.util.Objects; + +/** + * This action is used when making a REST request to the inference API. The transport handler + * will then look at the task type in the params (or retrieve it from the persisted model if it wasn't + * included in the params) to determine where this request should be routed. If the task type is chat completion + * then it will be routed to the unified chat completion handler by creating the {@link UnifiedCompletionAction}. + * If not, it will be passed along to {@link InferenceAction}. + */ +public class InferenceActionProxy extends ActionType { + public static final InferenceActionProxy INSTANCE = new InferenceActionProxy(); + public static final String NAME = "cluster:monitor/xpack/inference/post"; + + public InferenceActionProxy() { + super(NAME); + } + + public static class Request extends ActionRequest { + + private final TaskType taskType; + private final String inferenceEntityId; + private final BytesReference content; + private final XContentType contentType; + private final TimeValue timeout; + private final boolean stream; + + public Request( + TaskType taskType, + String inferenceEntityId, + BytesReference content, + XContentType contentType, + TimeValue timeout, + boolean stream + ) { + this.taskType = taskType; + this.inferenceEntityId = inferenceEntityId; + this.content = content; + this.contentType = contentType; + this.timeout = timeout; + this.stream = stream; + } + + public Request(StreamInput in) throws IOException { + super(in); + this.taskType = TaskType.fromStream(in); + this.inferenceEntityId = in.readString(); + this.content = in.readBytesReference(); + this.contentType = in.readEnum(XContentType.class); + this.timeout = in.readTimeValue(); + + // streaming is not supported yet for transport traffic + this.stream = false; + } + + public TaskType getTaskType() { + return taskType; + } + + public String getInferenceEntityId() { + return inferenceEntityId; + } + + public BytesReference getContent() { + return content; + } + + public XContentType getContentType() { + return contentType; + } + + public TimeValue getTimeout() { + return timeout; + } + + public boolean isStreaming() { + return stream; + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(inferenceEntityId); + taskType.writeTo(out); + out.writeBytesReference(content); + XContentHelper.writeTo(out, contentType); + out.writeTimeValue(timeout); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Request request = (Request) o; + return taskType == request.taskType + && Objects.equals(inferenceEntityId, request.inferenceEntityId) + && Objects.equals(content, request.content) + && contentType == request.contentType + && timeout == request.timeout + && stream == request.stream; + } + + @Override + public int hashCode() { + return Objects.hash(taskType, inferenceEntityId, content, contentType, timeout, stream); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java index f5c852a0450ae..43c84ad914c2a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java @@ -21,7 +21,7 @@ public class UnifiedCompletionAction extends ActionType { public static final UnifiedCompletionAction INSTANCE = new UnifiedCompletionAction(); - public static final String NAME = "cluster:monitor/xpack/inference/unified"; + public static final String NAME = "cluster:internal/xpack/inference/unified"; public UnifiedCompletionAction() { super(NAME); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/store/ReservedRolesStoreTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/store/ReservedRolesStoreTests.java index 107953557f3ea..ab5a1bbd8c897 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/store/ReservedRolesStoreTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/store/ReservedRolesStoreTests.java @@ -4140,7 +4140,7 @@ public void testInferenceAdminRole() { assertThat(roleDescriptor.getMetadata(), hasEntry("_reserved", true)); Role role = Role.buildFromRoleDescriptor(roleDescriptor, new FieldPermissionsCache(Settings.EMPTY), RESTRICTED_INDICES); - assertTrue(role.cluster().check("cluster:monitor/xpack/inference", request, authentication)); + assertTrue(role.cluster().check("cluster:monitor/xpack/inference/post", request, authentication)); assertTrue(role.cluster().check("cluster:monitor/xpack/inference/get", request, authentication)); assertTrue(role.cluster().check("cluster:admin/xpack/inference/put", request, authentication)); assertTrue(role.cluster().check("cluster:admin/xpack/inference/delete", request, authentication)); @@ -4160,10 +4160,9 @@ public void testInferenceUserRole() { assertThat(roleDescriptor.getMetadata(), hasEntry("_reserved", true)); Role role = Role.buildFromRoleDescriptor(roleDescriptor, new FieldPermissionsCache(Settings.EMPTY), RESTRICTED_INDICES); - assertTrue(role.cluster().check("cluster:monitor/xpack/inference", request, authentication)); + assertTrue(role.cluster().check("cluster:monitor/xpack/inference/post", request, authentication)); assertTrue(role.cluster().check("cluster:monitor/xpack/inference/get", request, authentication)); assertFalse(role.cluster().check("cluster:admin/xpack/inference/put", request, authentication)); - assertTrue(role.cluster().check("cluster:monitor/xpack/inference/unified", request, authentication)); assertFalse(role.cluster().check("cluster:admin/xpack/inference/delete", request, authentication)); assertTrue(role.cluster().check("cluster:monitor/xpack/ml/trained_models/deployment/infer", request, authentication)); assertFalse(role.cluster().check("cluster:admin/xpack/ml/trained_models/deployment/start", request, authentication)); diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java index 950ff196e5136..36a4b95a7ca23 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -360,8 +360,7 @@ protected Deque unifiedCompletionInferOnMockService( List input, @Nullable Consumer responseConsumerCallback ) throws Exception { - var route = randomBoolean() ? "_stream" : "_unified"; // TODO remove unified route - var endpoint = Strings.format("_inference/%s/%s/%s", taskType, modelId, route); + var endpoint = Strings.format("_inference/%s/%s/_stream", taskType, modelId); return callAsyncUnified(endpoint, input, "user", responseConsumerCallback); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 23df62caab430..e3604351c1937 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -58,6 +58,7 @@ import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; import org.elasticsearch.xpack.core.inference.action.GetInferenceServicesAction; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy; import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; import org.elasticsearch.xpack.core.inference.action.UpdateInferenceModelAction; @@ -67,6 +68,7 @@ import org.elasticsearch.xpack.inference.action.TransportGetInferenceModelAction; import org.elasticsearch.xpack.inference.action.TransportGetInferenceServicesAction; import org.elasticsearch.xpack.inference.action.TransportInferenceAction; +import org.elasticsearch.xpack.inference.action.TransportInferenceActionProxy; import org.elasticsearch.xpack.inference.action.TransportInferenceUsageAction; import org.elasticsearch.xpack.inference.action.TransportPutInferenceModelAction; import org.elasticsearch.xpack.inference.action.TransportUnifiedCompletionInferenceAction; @@ -104,7 +106,6 @@ import org.elasticsearch.xpack.inference.rest.RestInferenceAction; import org.elasticsearch.xpack.inference.rest.RestPutInferenceModelAction; import org.elasticsearch.xpack.inference.rest.RestStreamInferenceAction; -import org.elasticsearch.xpack.inference.rest.RestUnifiedCompletionInferenceAction; import org.elasticsearch.xpack.inference.rest.RestUpdateInferenceModelAction; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchService; @@ -195,6 +196,7 @@ public InferencePlugin(Settings settings) { public List> getActions() { return List.of( new ActionHandler<>(InferenceAction.INSTANCE, TransportInferenceAction.class), + new ActionHandler<>(InferenceActionProxy.INSTANCE, TransportInferenceActionProxy.class), new ActionHandler<>(GetInferenceModelAction.INSTANCE, TransportGetInferenceModelAction.class), new ActionHandler<>(PutInferenceModelAction.INSTANCE, TransportPutInferenceModelAction.class), new ActionHandler<>(UpdateInferenceModelAction.INSTANCE, TransportUpdateInferenceModelAction.class), @@ -226,8 +228,7 @@ public List getRestHandlers( new RestUpdateInferenceModelAction(), new RestDeleteInferenceEndpointAction(), new RestGetInferenceDiagnosticsAction(), - new RestGetInferenceServicesAction(), - new RestUnifiedCompletionInferenceAction(threadPoolSetOnce) + new RestGetInferenceServicesAction() ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java new file mode 100644 index 0000000000000..6d46f834d4873 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java @@ -0,0 +1,125 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.action; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.injection.guice.Inject; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy; +import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; + +import java.io.IOException; + +import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN; +import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; + +public class TransportInferenceActionProxy extends HandledTransportAction { + private final ModelRegistry modelRegistry; + private final Client client; + + @Inject + public TransportInferenceActionProxy( + TransportService transportService, + ActionFilters actionFilters, + ModelRegistry modelRegistry, + Client client + ) { + super( + InferenceActionProxy.NAME, + transportService, + actionFilters, + InferenceActionProxy.Request::new, + EsExecutors.DIRECT_EXECUTOR_SERVICE + ); + + this.modelRegistry = modelRegistry; + this.client = client; + } + + @Override + protected void doExecute(Task task, InferenceActionProxy.Request request, ActionListener listener) { + try { + ActionListener getModelListener = listener.delegateFailureAndWrap((l, unparsedModel) -> { + if (unparsedModel.taskType() == TaskType.CHAT_COMPLETION) { + sendUnifiedCompletionRequest(request, l); + } else { + sendInferenceActionRequest(request, l); + } + }); + + if (request.getTaskType() == TaskType.ANY) { + modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener); + } else if (request.getTaskType() == TaskType.CHAT_COMPLETION) { + sendUnifiedCompletionRequest(request, listener); + } else { + sendInferenceActionRequest(request, listener); + } + } catch (Exception e) { + listener.onFailure(e); + } + } + + private void sendUnifiedCompletionRequest(InferenceActionProxy.Request request, ActionListener listener) { + // format any validation exceptions from the rest -> transport path as UnifiedChatCompletionException + var unifiedErrorFormatListener = listener.delegateResponse((l, e) -> l.onFailure(UnifiedChatCompletionException.fromThrowable(e))); + + try { + if (request.isStreaming() == false) { + throw new ElasticsearchStatusException( + "The [chat_completion] task type only supports streaming, please try again with the _stream API", + RestStatus.BAD_REQUEST + ); + } + + UnifiedCompletionAction.Request unifiedRequest; + try ( + var parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, request.getContent(), request.getContentType()) + ) { + unifiedRequest = UnifiedCompletionAction.Request.parseRequest( + request.getInferenceEntityId(), + request.getTaskType(), + request.getTimeout(), + parser + ); + } + + executeAsyncWithOrigin(client, INFERENCE_ORIGIN, UnifiedCompletionAction.INSTANCE, unifiedRequest, unifiedErrorFormatListener); + } catch (Exception e) { + unifiedErrorFormatListener.onFailure(e); + } + } + + private void sendInferenceActionRequest(InferenceActionProxy.Request request, ActionListener listener) + throws IOException { + InferenceAction.Request.Builder inferenceActionRequestBuilder; + try (var parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, request.getContent(), request.getContentType())) { + inferenceActionRequestBuilder = InferenceAction.Request.parseRequest( + request.getInferenceEntityId(), + request.getTaskType(), + parser + ); + inferenceActionRequestBuilder.setInferenceTimeout(request.getTimeout()).setStream(request.isStreaming()); + } + + executeAsyncWithOrigin(client, INFERENCE_ORIGIN, InferenceAction.INSTANCE, inferenceActionRequestBuilder.build(), listener); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java index d911158e82296..06a0849b91d4e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java @@ -15,6 +15,7 @@ import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy; import java.io.IOException; @@ -41,21 +42,22 @@ static TimeValue parseTimeout(RestRequest restRequest) { @Override protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { var params = parseParams(restRequest); + var content = restRequest.requiredContent(); + var inferTimeout = parseTimeout(restRequest); - InferenceAction.Request.Builder requestBuilder; - try (var parser = restRequest.contentParser()) { - requestBuilder = InferenceAction.Request.parseRequest(params.inferenceEntityId(), params.taskType(), parser); - } + var request = new InferenceActionProxy.Request( + params.taskType(), + params.inferenceEntityId(), + content, + restRequest.getXContentType(), + inferTimeout, + shouldStream() + ); - var inferTimeout = parseTimeout(restRequest); - requestBuilder.setInferenceTimeout(inferTimeout); - var request = prepareInferenceRequest(requestBuilder); - return channel -> client.execute(InferenceAction.INSTANCE, request, listener(channel)); + return channel -> client.execute(InferenceActionProxy.INSTANCE, request, ActionListener.withRef(listener(channel), content)); } - protected InferenceAction.Request prepareInferenceRequest(InferenceAction.Request.Builder builder) { - return builder.build(); - } + protected abstract boolean shouldStream(); protected abstract ActionListener listener(RestChannel channel); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java index 7f43676dfb5f0..b1edec79dfb72 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java @@ -24,22 +24,14 @@ public final class Paths { static final String INFERENCE_SERVICES_PATH = "_inference/_services"; static final String TASK_TYPE_INFERENCE_SERVICES_PATH = "_inference/_services/{" + TASK_TYPE + "}"; - static final String STREAM_INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}/_stream"; + public static final String STREAM_SUFFIX = "_stream"; + static final String STREAM_INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}/" + STREAM_SUFFIX; static final String STREAM_TASK_TYPE_INFERENCE_ID_PATH = "_inference/{" - + TASK_TYPE_OR_INFERENCE_ID - + "}/{" - + INFERENCE_ID - + "}/_stream"; - - // TODO remove the _unified path - public static final String UNIFIED_SUFFIX = "_unified"; - static final String UNIFIED_INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}/" + UNIFIED_SUFFIX; - static final String UNIFIED_TASK_TYPE_INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}/{" + INFERENCE_ID + "}/" - + UNIFIED_SUFFIX; + + STREAM_SUFFIX; private Paths() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java index 0fbc2f8214cbb..55083dcd4c888 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java @@ -32,6 +32,11 @@ public List routes() { return List.of(new Route(POST, INFERENCE_ID_PATH), new Route(POST, TASK_TYPE_INFERENCE_ID_PATH)); } + @Override + protected boolean shouldStream() { + return false; + } + @Override protected ActionListener listener(RestChannel channel) { return new RestChunkedToXContentListener<>(channel); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java index 518056365d88b..f37f4e9fb1f9b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java @@ -9,17 +9,12 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.client.internal.node.NodeClient; -import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestChannel; -import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.Scope; import org.elasticsearch.rest.ServerlessScope; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; -import java.io.IOException; import java.util.List; import java.util.Objects; @@ -46,41 +41,13 @@ public List routes() { return List.of(new Route(POST, STREAM_INFERENCE_ID_PATH), new Route(POST, STREAM_TASK_TYPE_INFERENCE_ID_PATH)); } - @Override - protected InferenceAction.Request prepareInferenceRequest(InferenceAction.Request.Builder builder) { - return builder.setStream(true).build(); - } - @Override protected ActionListener listener(RestChannel channel) { return new ServerSentEventsRestActionListener(channel, threadPool); } @Override - protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { - var params = parseParams(restRequest); - var inferTimeout = parseTimeout(restRequest); - - if (params.taskType() == TaskType.CHAT_COMPLETION) { - UnifiedCompletionAction.Request request; - try (var parser = restRequest.contentParser()) { - request = UnifiedCompletionAction.Request.parseRequest(params.inferenceEntityId(), params.taskType(), inferTimeout, parser); - } - - return channel -> client.execute( - UnifiedCompletionAction.INSTANCE, - request, - new ServerSentEventsRestActionListener(channel, threadPool) - ); - } else { - InferenceAction.Request.Builder requestBuilder; - try (var parser = restRequest.contentParser()) { - requestBuilder = InferenceAction.Request.parseRequest(params.inferenceEntityId(), params.taskType(), parser); - } - - requestBuilder.setInferenceTimeout(inferTimeout); - var request = prepareInferenceRequest(requestBuilder); - return channel -> client.execute(InferenceAction.INSTANCE, request, listener(channel)); - } + protected boolean shouldStream() { + return true; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java deleted file mode 100644 index 0efd31a6832cc..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.rest; - -import org.apache.lucene.util.SetOnce; -import org.elasticsearch.client.internal.node.NodeClient; -import org.elasticsearch.rest.BaseRestHandler; -import org.elasticsearch.rest.RestRequest; -import org.elasticsearch.rest.Scope; -import org.elasticsearch.rest.ServerlessScope; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; -import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; - -import java.io.IOException; -import java.util.List; -import java.util.Objects; - -import static org.elasticsearch.rest.RestRequest.Method.POST; -import static org.elasticsearch.xpack.inference.rest.Paths.UNIFIED_INFERENCE_ID_PATH; -import static org.elasticsearch.xpack.inference.rest.Paths.UNIFIED_TASK_TYPE_INFERENCE_ID_PATH; - -@ServerlessScope(Scope.PUBLIC) -public class RestUnifiedCompletionInferenceAction extends BaseRestHandler { - private final SetOnce threadPool; - - public RestUnifiedCompletionInferenceAction(SetOnce threadPool) { - super(); - this.threadPool = Objects.requireNonNull(threadPool); - } - - @Override - public String getName() { - return "unified_inference_action"; - } - - @Override - public List routes() { - return List.of(new Route(POST, UNIFIED_INFERENCE_ID_PATH), new Route(POST, UNIFIED_TASK_TYPE_INFERENCE_ID_PATH)); - } - - @Override - protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { - var params = BaseInferenceAction.parseParams(restRequest); - - var inferTimeout = BaseInferenceAction.parseTimeout(restRequest); - - UnifiedCompletionAction.Request request; - try (var parser = restRequest.contentParser()) { - request = UnifiedCompletionAction.Request.parseRequest(params.inferenceEntityId(), params.taskType(), inferTimeout, parser); - } - - return channel -> client.execute( - UnifiedCompletionAction.INSTANCE, - request, - new ServerSentEventsRestActionListener(channel, threadPool).delegateResponse((l, e) -> { - // format any validation exceptions from the rest -> transport path as UnifiedChatCompletionException - l.onFailure(UnifiedChatCompletionException.fromThrowable(e)); - }) - ); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 1ddae3cc8df95..13d641101a1cf 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -42,7 +42,7 @@ import static org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings.ENABLED; import static org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings.MAX_NUMBER_OF_ALLOCATIONS; import static org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings.MIN_NUMBER_OF_ALLOCATIONS; -import static org.elasticsearch.xpack.inference.rest.Paths.UNIFIED_SUFFIX; +import static org.elasticsearch.xpack.inference.rest.Paths.STREAM_SUFFIX; import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; public final class ServiceUtils { @@ -796,7 +796,7 @@ public static String useChatCompletionUrlMessage(Model model) { model.getTaskType(), model.getTaskType(), model.getInferenceEntityId(), - UNIFIED_SUFFIX + STREAM_SUFFIX ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java index e4de3d6beb800..1c4306c4edd46 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java @@ -29,7 +29,7 @@ public static ModelValidator buildModelValidator(TaskType taskType) { case SPARSE_EMBEDDING, RERANK, ANY -> { return new SimpleModelValidator(new SimpleServiceIntegrationValidator()); } - default -> throw new IllegalArgumentException(Strings.format("Can't validate inference model of for task type %s ", taskType)); + default -> throw new IllegalArgumentException(Strings.format("Can't validate inference model for task type %s", taskType)); } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxyTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxyTests.java new file mode 100644 index 0000000000000..a9e6ec55a6224 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxyTests.java @@ -0,0 +1,191 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.action; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy; +import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.junit.After; +import org.junit.Before; + +import java.util.Collections; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class TransportInferenceActionProxyTests extends ESTestCase { + private Client client; + private ThreadPool threadPool; + private TransportInferenceActionProxy action; + private ModelRegistry modelRegistry; + + @Before + public void setUp() throws Exception { + super.setUp(); + client = mock(Client.class); + threadPool = new TestThreadPool("test"); + when(client.threadPool()).thenReturn(threadPool); + modelRegistry = mock(ModelRegistry.class); + + action = new TransportInferenceActionProxy(mock(TransportService.class), mock(ActionFilters.class), modelRegistry, client); + } + + @After + public void tearDown() throws Exception { + super.tearDown(); + terminate(threadPool); + } + + public void testExecutesAUnifiedCompletionRequest_WhenTaskTypeIsChatCompletion_InRequest() { + String requestJson = """ + { + "model": "gpt-4o", + "messages": [ + { + "role": "user", + "content": [ + { + "text": "some text", + "type": "string" + } + ] + } + ] + } + """; + + @SuppressWarnings("unchecked") + ActionListener listener = (ActionListener) mock(ActionListener.class); + var request = new InferenceActionProxy.Request( + TaskType.CHAT_COMPLETION, + "id", + new BytesArray(requestJson), + XContentType.JSON, + TimeValue.ONE_MINUTE, + true + ); + + action.doExecute(mock(Task.class), request, listener); + + verify(client, times(1)).execute(eq(UnifiedCompletionAction.INSTANCE), any(), any()); + } + + public void testExecutesAUnifiedCompletionRequest_WhenTaskTypeIsChatCompletion_FromStorage() { + String requestJson = """ + { + "model": "gpt-4o", + "messages": [ + { + "role": "user", + "content": [ + { + "text": "some text", + "type": "string" + } + ] + } + ] + } + """; + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse( + new UnparsedModel("id", TaskType.CHAT_COMPLETION, "service", Collections.emptyMap(), Collections.emptyMap()) + ); + + return Void.TYPE; + }).when(modelRegistry).getModelWithSecrets(any(), any()); + + var listener = new PlainActionFuture(); + var request = new InferenceActionProxy.Request( + TaskType.ANY, + "id", + new BytesArray(requestJson), + XContentType.JSON, + TimeValue.ONE_MINUTE, + true + ); + + action.doExecute(mock(Task.class), request, listener); + + verify(client, times(1)).execute(eq(UnifiedCompletionAction.INSTANCE), any(), any()); + } + + public void testExecutesAnInferenceAction_WhenTaskTypeIsCompletion_InRequest() { + String requestJson = """ + { + "input": ["some text"] + } + """; + + @SuppressWarnings("unchecked") + ActionListener listener = (ActionListener) mock(ActionListener.class); + var request = new InferenceActionProxy.Request( + TaskType.COMPLETION, + "id", + new BytesArray(requestJson), + XContentType.JSON, + TimeValue.ONE_MINUTE, + true + ); + + action.doExecute(mock(Task.class), request, listener); + + verify(client, times(1)).execute(eq(InferenceAction.INSTANCE), any(), any()); + } + + public void testExecutesAnInferenceAction_WhenTaskTypeIsCompletion_FromStorage() { + String requestJson = """ + { + "input": ["some text"] + } + """; + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(new UnparsedModel("id", TaskType.COMPLETION, "service", Collections.emptyMap(), Collections.emptyMap())); + + return Void.TYPE; + }).when(modelRegistry).getModelWithSecrets(any(), any()); + + var listener = new PlainActionFuture(); + var request = new InferenceActionProxy.Request( + TaskType.ANY, + "id", + new BytesArray(requestJson), + XContentType.JSON, + TimeValue.ONE_MINUTE, + true + ); + + action.doExecute(mock(Task.class), request, listener); + + verify(client, times(1)).execute(eq(InferenceAction.INSTANCE), any(), any()); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java index 5528c80066b0a..4961778a03726 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java @@ -21,6 +21,7 @@ import org.elasticsearch.test.rest.RestActionTestCase; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults; import org.junit.Before; @@ -42,6 +43,11 @@ public class BaseInferenceActionTests extends RestActionTestCase { @Before public void setUpAction() { controller().registerHandler(new BaseInferenceAction() { + @Override + protected boolean shouldStream() { + return false; + } + @Override protected ActionListener listener(RestChannel channel) { return new RestChunkedToXContentListener<>(channel); @@ -102,10 +108,10 @@ public void testParseTimeout_ReturnsDefaultTimeout() { public void testUsesDefaultTimeout() { SetOnce executeCalled = new SetOnce<>(); verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> { - assertThat(actionRequest, instanceOf(InferenceAction.Request.class)); + assertThat(actionRequest, instanceOf(InferenceActionProxy.Request.class)); - var request = (InferenceAction.Request) actionRequest; - assertThat(request.getInferenceTimeout(), is(InferenceAction.Request.DEFAULT_TIMEOUT)); + var request = (InferenceActionProxy.Request) actionRequest; + assertThat(request.getTimeout(), is(InferenceAction.Request.DEFAULT_TIMEOUT)); executeCalled.set(true); return createResponse(); @@ -122,10 +128,10 @@ public void testUsesDefaultTimeout() { public void testUses3SecondTimeoutFromParams() { SetOnce executeCalled = new SetOnce<>(); verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> { - assertThat(actionRequest, instanceOf(InferenceAction.Request.class)); + assertThat(actionRequest, instanceOf(InferenceActionProxy.Request.class)); - var request = (InferenceAction.Request) actionRequest; - assertThat(request.getInferenceTimeout(), is(TimeValue.timeValueSeconds(3))); + var request = (InferenceActionProxy.Request) actionRequest; + assertThat(request.getTimeout(), is(TimeValue.timeValueSeconds(3))); executeCalled.set(true); return createResponse(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestInferenceActionProxyTests.java similarity index 90% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestInferenceActionTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestInferenceActionProxyTests.java index 1b0df1b4a20da..433e33fe15210 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestInferenceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestInferenceActionProxyTests.java @@ -13,7 +13,7 @@ import org.elasticsearch.test.rest.FakeRestRequest; import org.elasticsearch.test.rest.RestActionTestCase; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy; import org.junit.Before; import static org.elasticsearch.xpack.inference.rest.BaseInferenceActionTests.createResponse; @@ -21,7 +21,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; -public class RestInferenceActionTests extends RestActionTestCase { +public class RestInferenceActionProxyTests extends RestActionTestCase { @Before public void setUpAction() { @@ -31,9 +31,9 @@ public void setUpAction() { public void testStreamIsFalse() { SetOnce executeCalled = new SetOnce<>(); verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> { - assertThat(actionRequest, instanceOf(InferenceAction.Request.class)); + assertThat(actionRequest, instanceOf(InferenceActionProxy.Request.class)); - var request = (InferenceAction.Request) actionRequest; + var request = (InferenceActionProxy.Request) actionRequest; assertThat(request.isStreaming(), is(false)); executeCalled.set(true); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceActionTests.java index f67680ef6b625..e69dd3fda6240 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceActionTests.java @@ -9,13 +9,18 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.rest.AbstractRestChannel; +import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestResponse; import org.elasticsearch.test.rest.FakeRestRequest; import org.elasticsearch.test.rest.RestActionTestCase; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy; import org.junit.After; import org.junit.Before; @@ -42,9 +47,9 @@ public void tearDownAction() { public void testStreamIsTrue() { SetOnce executeCalled = new SetOnce<>(); verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> { - assertThat(actionRequest, instanceOf(InferenceAction.Request.class)); + assertThat(actionRequest, instanceOf(InferenceActionProxy.Request.class)); - var request = (InferenceAction.Request) actionRequest; + var request = (InferenceActionProxy.Request) actionRequest; assertThat(request.isStreaming(), is(true)); executeCalled.set(true); @@ -58,4 +63,50 @@ public void testStreamIsTrue() { dispatchRequest(inferenceRequest); assertThat(executeCalled.get(), equalTo(true)); } + + public void testStreamIsTrue_ChatCompletion() { + SetOnce executeCalled = new SetOnce<>(); + verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> { + assertThat(actionRequest, instanceOf(InferenceActionProxy.Request.class)); + + var request = (InferenceActionProxy.Request) actionRequest; + assertThat(request.isStreaming(), is(true)); + + executeCalled.set(true); + return createResponse(); + })); + + var requestBody = """ + { + "messages": [ + { + "content": "abc", + "role": "user" + } + ] + } + """; + + RestRequest inferenceRequest = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath("_inference/chat_completion/test/_stream") + .withContent(new BytesArray(requestBody), XContentType.JSON) + .build(); + + final SetOnce responseSetOnce = new SetOnce<>(); + dispatchRequest(inferenceRequest, new AbstractRestChannel(inferenceRequest, true) { + @Override + public void sendResponse(RestResponse response) { + responseSetOnce.set(response); + } + }); + + // the response content will be null when there is no error + assertNull(responseSetOnce.get().content()); + assertThat(executeCalled.get(), equalTo(true)); + } + + private void dispatchRequest(final RestRequest request, final RestChannel channel) { + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + controller().dispatchRequest(request, channel, threadContext); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java deleted file mode 100644 index 9dc23c890c14d..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.rest; - -import org.apache.lucene.util.SetOnce; -import org.elasticsearch.common.bytes.BytesArray; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.rest.AbstractRestChannel; -import org.elasticsearch.rest.RestChannel; -import org.elasticsearch.rest.RestRequest; -import org.elasticsearch.rest.RestResponse; -import org.elasticsearch.test.rest.FakeRestRequest; -import org.elasticsearch.test.rest.RestActionTestCase; -import org.elasticsearch.threadpool.TestThreadPool; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; -import org.junit.After; -import org.junit.Before; - -import static org.elasticsearch.xpack.inference.rest.BaseInferenceActionTests.createResponse; -import static org.hamcrest.CoreMatchers.is; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.instanceOf; - -public class RestUnifiedCompletionInferenceActionTests extends RestActionTestCase { - private final SetOnce threadPool = new SetOnce<>(); - - @Before - public void setUpAction() { - threadPool.set(new TestThreadPool(getTestName())); - controller().registerHandler(new RestUnifiedCompletionInferenceAction(threadPool)); - } - - @After - public void tearDownAction() { - terminate(threadPool.get()); - } - - public void testStreamIsTrue() { - SetOnce executeCalled = new SetOnce<>(); - verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> { - assertThat(actionRequest, instanceOf(UnifiedCompletionAction.Request.class)); - - var request = (UnifiedCompletionAction.Request) actionRequest; - assertThat(request.isStreaming(), is(true)); - - executeCalled.set(true); - return createResponse(); - })); - - var requestBody = """ - { - "messages": [ - { - "content": "abc", - "role": "user" - } - ] - } - """; - - RestRequest inferenceRequest = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) - .withPath("_inference/completion/test/_unified") - .withContent(new BytesArray(requestBody), XContentType.JSON) - .build(); - - final SetOnce responseSetOnce = new SetOnce<>(); - dispatchRequest(inferenceRequest, new AbstractRestChannel(inferenceRequest, true) { - @Override - public void sendResponse(RestResponse response) { - responseSetOnce.set(response); - } - }); - - // the response content will be null when there is no error - assertNull(responseSetOnce.get().content()); - assertThat(executeCalled.get(), equalTo(true)); - } - - private void dispatchRequest(final RestRequest request, final RestChannel channel) { - ThreadContext threadContext = new ThreadContext(Settings.EMPTY); - controller().dispatchRequest(request, channel, threadContext); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index cdbf4c678d793..e7eccc371f941 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -444,7 +444,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws "Inference entity [model_id] does not support task type [chat_completion] " + "for inference, the task type must be one of [sparse_embedding]. " + "The task type for the inference entity is chat_completion, " - + "please use the _inference/chat_completion/model_id/_unified URL." + + "please use the _inference/chat_completion/model_id/_stream URL." ) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index feb74e8524c87..13aff10e3148a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -936,7 +936,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws "Inference entity [model_id] does not support task type [chat_completion] " + "for inference, the task type must be one of [text_embedding, completion]. " + "The task type for the inference entity is chat_completion, " - + "please use the _inference/chat_completion/model_id/_unified URL." + + "please use the _inference/chat_completion/model_id/_stream URL." ) ); diff --git a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java index 5d3b13b9d451a..7e571b6db2f92 100644 --- a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java +++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java @@ -327,6 +327,8 @@ public class Constants { "cluster:admin/xpack/watcher/settings/update", "cluster:admin/xpack/watcher/watch/put", "cluster:internal/remote_cluster/nodes", + "cluster:internal/xpack/inference", + "cluster:internal/xpack/inference/unified", "cluster:internal/xpack/ml/coordinatedinference", "cluster:internal/xpack/ml/datafeed/isolate", "cluster:internal/xpack/ml/datafeed/running_state", @@ -386,9 +388,8 @@ public class Constants { "cluster:monitor/xpack/enrich/stats", "cluster:monitor/xpack/eql/stats/dist", "cluster:monitor/xpack/esql/stats/dist", - "cluster:monitor/xpack/inference", + "cluster:monitor/xpack/inference/post", "cluster:monitor/xpack/inference/get", - "cluster:monitor/xpack/inference/unified", "cluster:monitor/xpack/inference/diagnostics/get", "cluster:monitor/xpack/inference/services/get", "cluster:monitor/xpack/info",