@@ -47,7 +47,7 @@
public class InferenceAction extends ActionType<InferenceAction.Response> {

public static final InferenceAction INSTANCE = new InferenceAction();
public static final String NAME = "cluster:monitor/xpack/inference";
Contributor Author: @davidkyle just wanted to confirm that this is what we want to do here, right? Changing it to internal?

Member: ++ yes internal is good

public static final String NAME = "cluster:internal/xpack/inference";

public InferenceAction() {
super(NAME);
@@ -0,0 +1,132 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.inference.action;

import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.XContentType;

import java.io.IOException;
import java.util.Objects;

/**
* This action is used when making a REST request to the inference API. The transport handler
* will then look at the task type in the params (or retrieve it from the persisted model if it wasn't
* included in the params) to determine where this request should be routed. If the task type is chat completion
* then it will be routed to the unified chat completion handler by creating the {@link UnifiedCompletionAction}.
* If not, it will be passed along to {@link InferenceAction}.
*/
public class InferenceActionProxy extends ActionType<InferenceAction.Response> {
public static final InferenceActionProxy INSTANCE = new InferenceActionProxy();
public static final String NAME = "cluster:monitor/xpack/inference";
Member: I'm not sure it is safe to reuse the old InferenceAction name here when it has different request and response classes. I'm thinking about a mixed cluster. To be safe, please give it a new name.

Get, Put and Delete are cluster:monitor/xpack/inference/[get|put|delete], so this could be cluster:monitor/xpack/inference/post or cluster:monitor/xpack/inference/inference.

Contributor Author: Ah ok, I'll switch to post 👍


public InferenceActionProxy() {
super(NAME);
}

public static class Request extends ActionRequest {

private final TaskType taskType;
private final String inferenceEntityId;
private final BytesReference content;
private final XContentType contentType;
private final TimeValue timeout;
private final boolean stream;

public Request(
TaskType taskType,
String inferenceEntityId,
BytesReference content,
XContentType contentType,
TimeValue timeout,
boolean stream
) {
this.taskType = taskType;
this.inferenceEntityId = inferenceEntityId;
this.content = content;
this.contentType = contentType;
this.timeout = timeout;
this.stream = stream;
}

public Request(StreamInput in) throws IOException {
super(in);
this.taskType = TaskType.fromStream(in);
this.inferenceEntityId = in.readString();
this.content = in.readBytesReference();
this.contentType = in.readEnum(XContentType.class);
this.timeout = in.readTimeValue();

// streaming is not supported yet for transport traffic
this.stream = false;
}

public TaskType getTaskType() {
return taskType;
}

public String getInferenceEntityId() {
return inferenceEntityId;
}

public BytesReference getContent() {
return content;
}

public XContentType getContentType() {
return contentType;
}

public TimeValue getTimeout() {
return timeout;
}

public boolean isStreaming() {
return stream;
}

@Override
public ActionRequestValidationException validate() {
// TODO confirm that we don't need any validation
return null;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
// keep the write order in sync with the StreamInput constructor above
taskType.writeTo(out);
out.writeString(inferenceEntityId);
out.writeBytesReference(content);
XContentHelper.writeTo(out, contentType);
out.writeTimeValue(timeout);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Request request = (Request) o;
return taskType == request.taskType
&& Objects.equals(inferenceEntityId, request.inferenceEntityId)
&& Objects.equals(content, request.content)
&& contentType == request.contentType
&& Objects.equals(timeout, request.timeout)
&& stream == request.stream;
}

@Override
public int hashCode() {
return Objects.hash(taskType, inferenceEntityId, content, contentType, timeout, stream);
}
}
}
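As a usage sketch (not code from this PR), the proxy is designed to carry the raw REST body through unparsed and let the transport handler decide how to interpret it. The endpoint id, task type, body, and timeout below are invented for illustration; imports would mirror those in the file above plus BytesArray, ActionListener, and Client.

// Illustrative only: build a proxy request from a raw JSON body and execute it through a Client.
static void inferViaProxy(Client client) {
    BytesReference body = new BytesArray("{\"input\": [\"some text to embed\"]}");
    InferenceActionProxy.Request proxyRequest = new InferenceActionProxy.Request(
        TaskType.TEXT_EMBEDDING,        // task type from the REST path, or TaskType.ANY to look it up
        "my-endpoint",                  // inference endpoint id (illustrative)
        body,                           // unparsed request body, interpreted later by the transport handler
        XContentType.JSON,
        TimeValue.timeValueSeconds(30),
        false                           // streaming is only used on the chat completion path
    );
    client.execute(
        InferenceActionProxy.INSTANCE,
        proxyRequest,
        ActionListener.wrap(response -> { /* use InferenceAction.Response */ }, e -> { /* handle the failure */ })
    );
}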
@@ -21,7 +21,7 @@

public class UnifiedCompletionAction extends ActionType<InferenceAction.Response> {
public static final UnifiedCompletionAction INSTANCE = new UnifiedCompletionAction();
public static final String NAME = "cluster:monitor/xpack/inference/unified";
public static final String NAME = "cluster:internal/xpack/inference/unified";

public UnifiedCompletionAction() {
super(NAME);
@@ -4163,7 +4163,6 @@ public void testInferenceUserRole() {
assertTrue(role.cluster().check("cluster:monitor/xpack/inference", request, authentication));
assertTrue(role.cluster().check("cluster:monitor/xpack/inference/get", request, authentication));
assertFalse(role.cluster().check("cluster:admin/xpack/inference/put", request, authentication));
assertTrue(role.cluster().check("cluster:monitor/xpack/inference/unified", request, authentication));
assertFalse(role.cluster().check("cluster:admin/xpack/inference/delete", request, authentication));
assertTrue(role.cluster().check("cluster:monitor/xpack/ml/trained_models/deployment/infer", request, authentication));
assertFalse(role.cluster().check("cluster:admin/xpack/ml/trained_models/deployment/start", request, authentication));
@@ -360,8 +360,7 @@ protected Deque<ServerSentEvent> unifiedCompletionInferOnMockService(
List<String> input,
@Nullable Consumer<Response> responseConsumerCallback
) throws Exception {
var route = randomBoolean() ? "_stream" : "_unified"; // TODO remove unified route
var endpoint = Strings.format("_inference/%s/%s/%s", taskType, modelId, route);
var endpoint = Strings.format("_inference/%s/%s/_stream", taskType, modelId);
return callAsyncUnified(endpoint, input, "user", responseConsumerCallback);
}

@@ -58,6 +58,7 @@
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
import org.elasticsearch.xpack.core.inference.action.GetInferenceServicesAction;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy;
import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
import org.elasticsearch.xpack.core.inference.action.UpdateInferenceModelAction;
@@ -67,6 +68,7 @@
import org.elasticsearch.xpack.inference.action.TransportGetInferenceModelAction;
import org.elasticsearch.xpack.inference.action.TransportGetInferenceServicesAction;
import org.elasticsearch.xpack.inference.action.TransportInferenceAction;
import org.elasticsearch.xpack.inference.action.TransportInferenceActionProxy;
import org.elasticsearch.xpack.inference.action.TransportInferenceUsageAction;
import org.elasticsearch.xpack.inference.action.TransportPutInferenceModelAction;
import org.elasticsearch.xpack.inference.action.TransportUnifiedCompletionInferenceAction;
@@ -104,7 +106,6 @@
import org.elasticsearch.xpack.inference.rest.RestInferenceAction;
import org.elasticsearch.xpack.inference.rest.RestPutInferenceModelAction;
import org.elasticsearch.xpack.inference.rest.RestStreamInferenceAction;
import org.elasticsearch.xpack.inference.rest.RestUnifiedCompletionInferenceAction;
import org.elasticsearch.xpack.inference.rest.RestUpdateInferenceModelAction;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchService;
@@ -195,6 +196,7 @@ public InferencePlugin(Settings settings) {
public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() {
return List.of(
new ActionHandler<>(InferenceAction.INSTANCE, TransportInferenceAction.class),
new ActionHandler<>(InferenceActionProxy.INSTANCE, TransportInferenceActionProxy.class),
new ActionHandler<>(GetInferenceModelAction.INSTANCE, TransportGetInferenceModelAction.class),
new ActionHandler<>(PutInferenceModelAction.INSTANCE, TransportPutInferenceModelAction.class),
new ActionHandler<>(UpdateInferenceModelAction.INSTANCE, TransportUpdateInferenceModelAction.class),
@@ -226,8 +228,7 @@ public List<RestHandler> getRestHandlers(
new RestUpdateInferenceModelAction(),
new RestDeleteInferenceEndpointAction(),
new RestGetInferenceDiagnosticsAction(),
new RestGetInferenceServicesAction(),
new RestUnifiedCompletionInferenceAction(threadPoolSetOnce)
new RestGetInferenceServicesAction()
);
}

@@ -0,0 +1,125 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.action;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.injection.guice.Inject;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy;
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;

import java.io.IOException;

import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;

public class TransportInferenceActionProxy extends HandledTransportAction<InferenceActionProxy.Request, InferenceAction.Response> {
private final ModelRegistry modelRegistry;
private final Client client;

@Inject
public TransportInferenceActionProxy(
TransportService transportService,
ActionFilters actionFilters,
ModelRegistry modelRegistry,
Client client
) {
super(
InferenceActionProxy.NAME,
transportService,
actionFilters,
InferenceActionProxy.Request::new,
EsExecutors.DIRECT_EXECUTOR_SERVICE
);

this.modelRegistry = modelRegistry;
this.client = client;
}

@Override
protected void doExecute(Task task, InferenceActionProxy.Request request, ActionListener<InferenceAction.Response> listener) {
try {
ActionListener<UnparsedModel> getModelListener = listener.delegateFailureAndWrap((l, unparsedModel) -> {
if (unparsedModel.taskType() == TaskType.CHAT_COMPLETION) {
sendUnifiedCompletionRequest(request, l);
} else {
sendInferenceActionRequest(request, l);
}
});

if (request.getTaskType() == TaskType.ANY) {
modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener);
} else if (request.getTaskType() == TaskType.CHAT_COMPLETION) {
sendUnifiedCompletionRequest(request, listener);
} else {
sendInferenceActionRequest(request, listener);
}
} catch (Exception e) {
listener.onFailure(e);
}
}

private void sendUnifiedCompletionRequest(InferenceActionProxy.Request request, ActionListener<InferenceAction.Response> listener) {
// format any validation exceptions from the rest -> transport path as UnifiedChatCompletionException
var unifiedErrorFormatListener = listener.delegateResponse((l, e) -> l.onFailure(UnifiedChatCompletionException.fromThrowable(e)));
Member: nice


try {
if (request.isStreaming() == false) {
throw new ElasticsearchStatusException(
"The [chat_completion] task type only supports streaming, please try again with the _stream API",
RestStatus.BAD_REQUEST
);
}

UnifiedCompletionAction.Request unifiedRequest;
try (
var parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, request.getContent(), request.getContentType())
) {
unifiedRequest = UnifiedCompletionAction.Request.parseRequest(
request.getInferenceEntityId(),
request.getTaskType(),
request.getTimeout(),
parser
);
}

executeAsyncWithOrigin(client, INFERENCE_ORIGIN, UnifiedCompletionAction.INSTANCE, unifiedRequest, unifiedErrorFormatListener);
} catch (Exception e) {
unifiedErrorFormatListener.onFailure(e);
}
}

private void sendInferenceActionRequest(InferenceActionProxy.Request request, ActionListener<InferenceAction.Response> listener)
throws IOException {
InferenceAction.Request.Builder inferenceActionRequestBuilder;
try (var parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, request.getContent(), request.getContentType())) {
inferenceActionRequestBuilder = InferenceAction.Request.parseRequest(
request.getInferenceEntityId(),
request.getTaskType(),
parser
);
inferenceActionRequestBuilder.setInferenceTimeout(request.getTimeout()).setStream(request.isStreaming());
}

executeAsyncWithOrigin(client, INFERENCE_ORIGIN, InferenceAction.INSTANCE, inferenceActionRequestBuilder.build(), listener);
}
}
@@ -15,6 +15,7 @@
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy;

import java.io.IOException;

@@ -41,21 +42,22 @@ static TimeValue parseTimeout(RestRequest restRequest) {
@Override
protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
var params = parseParams(restRequest);
var content = restRequest.requiredContent();
var inferTimeout = parseTimeout(restRequest);

InferenceAction.Request.Builder requestBuilder;
try (var parser = restRequest.contentParser()) {
requestBuilder = InferenceAction.Request.parseRequest(params.inferenceEntityId(), params.taskType(), parser);
}
var request = new InferenceActionProxy.Request(
params.taskType(),
params.inferenceEntityId(),
content,
restRequest.getXContentType(),
inferTimeout,
shouldStream()
);

var inferTimeout = parseTimeout(restRequest);
requestBuilder.setInferenceTimeout(inferTimeout);
var request = prepareInferenceRequest(requestBuilder);
return channel -> client.execute(InferenceAction.INSTANCE, request, listener(channel));
return channel -> client.execute(InferenceActionProxy.INSTANCE, request, ActionListener.withRef(listener(channel), content));
}

protected InferenceAction.Request prepareInferenceRequest(InferenceAction.Request.Builder builder) {
return builder.build();
}
protected abstract boolean shouldStream();

protected abstract ActionListener<InferenceAction.Response> listener(RestChannel channel);
}