Skip to content

Commit e99a781

Browse files
Adding proxy action
1 parent 534e171 commit e99a781

File tree

21 files changed

+552
-242
lines changed

21 files changed

+552
-242
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
public class InferenceAction extends ActionType<InferenceAction.Response> {
4848

4949
public static final InferenceAction INSTANCE = new InferenceAction();
50-
public static final String NAME = "cluster:monitor/xpack/inference";
50+
public static final String NAME = "cluster:internal/xpack/inference";
5151

5252
public InferenceAction() {
5353
super(NAME);
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.inference.action;
9+
10+
import org.elasticsearch.action.ActionRequest;
11+
import org.elasticsearch.action.ActionRequestValidationException;
12+
import org.elasticsearch.action.ActionType;
13+
import org.elasticsearch.common.bytes.BytesReference;
14+
import org.elasticsearch.common.io.stream.StreamInput;
15+
import org.elasticsearch.common.io.stream.StreamOutput;
16+
import org.elasticsearch.common.xcontent.XContentHelper;
17+
import org.elasticsearch.core.TimeValue;
18+
import org.elasticsearch.inference.TaskType;
19+
import org.elasticsearch.xcontent.XContentType;
20+
21+
import java.io.IOException;
22+
import java.util.Objects;
23+
24+
/**
25+
* This action is used when making a REST request to the inference API. The transport handler
26+
* will then look at the task type in the params (or retrieve it from the persisted model if it wasn't
27+
* included in the params) to determine where this request should be routed. If the task type is chat completion
28+
* then it will be routed to the unified chat completion handler by creating the {@link UnifiedCompletionAction}.
29+
* If not, it will be passed along to {@link InferenceAction}.
30+
*/
31+
public class InferenceActionProxy extends ActionType<InferenceAction.Response> {
32+
public static final InferenceActionProxy INSTANCE = new InferenceActionProxy();
33+
public static final String NAME = "cluster:monitor/xpack/inference";
34+
35+
public InferenceActionProxy() {
36+
super(NAME);
37+
}
38+
39+
public static class Request extends ActionRequest {
40+
41+
private final TaskType taskType;
42+
private final String inferenceEntityId;
43+
private final BytesReference content;
44+
private final XContentType contentType;
45+
private final TimeValue timeout;
46+
private final boolean stream;
47+
48+
public Request(
49+
TaskType taskType,
50+
String inferenceEntityId,
51+
BytesReference content,
52+
XContentType contentType,
53+
TimeValue timeout,
54+
boolean stream
55+
) {
56+
this.taskType = taskType;
57+
this.inferenceEntityId = inferenceEntityId;
58+
this.content = content;
59+
this.contentType = contentType;
60+
this.timeout = timeout;
61+
this.stream = stream;
62+
}
63+
64+
public Request(StreamInput in) throws IOException {
65+
super(in);
66+
this.taskType = TaskType.fromStream(in);
67+
this.inferenceEntityId = in.readString();
68+
this.content = in.readBytesReference();
69+
this.contentType = in.readEnum(XContentType.class);
70+
this.timeout = in.readTimeValue();
71+
72+
// streaming is not supported yet for transport traffic
73+
this.stream = false;
74+
}
75+
76+
public TaskType getTaskType() {
77+
return taskType;
78+
}
79+
80+
public String getInferenceEntityId() {
81+
return inferenceEntityId;
82+
}
83+
84+
public BytesReference getContent() {
85+
return content;
86+
}
87+
88+
public XContentType getContentType() {
89+
return contentType;
90+
}
91+
92+
public TimeValue getTimeout() {
93+
return timeout;
94+
}
95+
96+
public boolean isStreaming() {
97+
return stream;
98+
}
99+
100+
@Override
101+
public ActionRequestValidationException validate() {
102+
// TODO confirm that we don't need any validation
103+
return null;
104+
}
105+
106+
@Override
107+
public void writeTo(StreamOutput out) throws IOException {
108+
super.writeTo(out);
109+
out.writeString(inferenceEntityId);
110+
taskType.writeTo(out);
111+
out.writeBytesReference(content);
112+
XContentHelper.writeTo(out, contentType);
113+
out.writeTimeValue(timeout);
114+
}
115+
116+
@Override
117+
public boolean equals(Object o) {
118+
if (this == o) return true;
119+
if (o == null || getClass() != o.getClass()) return false;
120+
Request request = (Request) o;
121+
return taskType == request.taskType
122+
&& Objects.equals(inferenceEntityId, request.inferenceEntityId)
123+
&& Objects.equals(content, request.content)
124+
&& contentType == request.contentType;
125+
}
126+
127+
@Override
128+
public int hashCode() {
129+
return Objects.hash(taskType, inferenceEntityId, content, contentType);
130+
}
131+
}
132+
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
public class UnifiedCompletionAction extends ActionType<InferenceAction.Response> {
2323
public static final UnifiedCompletionAction INSTANCE = new UnifiedCompletionAction();
24-
public static final String NAME = "cluster:monitor/xpack/inference/unified";
24+
public static final String NAME = "cluster:internal/xpack/inference/unified";
2525

2626
public UnifiedCompletionAction() {
2727
super(NAME);

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/store/ReservedRolesStoreTests.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4163,7 +4163,6 @@ public void testInferenceUserRole() {
41634163
assertTrue(role.cluster().check("cluster:monitor/xpack/inference", request, authentication));
41644164
assertTrue(role.cluster().check("cluster:monitor/xpack/inference/get", request, authentication));
41654165
assertFalse(role.cluster().check("cluster:admin/xpack/inference/put", request, authentication));
4166-
assertTrue(role.cluster().check("cluster:monitor/xpack/inference/unified", request, authentication));
41674166
assertFalse(role.cluster().check("cluster:admin/xpack/inference/delete", request, authentication));
41684167
assertTrue(role.cluster().check("cluster:monitor/xpack/ml/trained_models/deployment/infer", request, authentication));
41694168
assertFalse(role.cluster().check("cluster:admin/xpack/ml/trained_models/deployment/start", request, authentication));

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
5959
import org.elasticsearch.xpack.core.inference.action.GetInferenceServicesAction;
6060
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
61+
import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy;
6162
import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
6263
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
6364
import org.elasticsearch.xpack.core.inference.action.UpdateInferenceModelAction;
@@ -67,6 +68,7 @@
6768
import org.elasticsearch.xpack.inference.action.TransportGetInferenceModelAction;
6869
import org.elasticsearch.xpack.inference.action.TransportGetInferenceServicesAction;
6970
import org.elasticsearch.xpack.inference.action.TransportInferenceAction;
71+
import org.elasticsearch.xpack.inference.action.TransportInferenceActionProxy;
7072
import org.elasticsearch.xpack.inference.action.TransportInferenceUsageAction;
7173
import org.elasticsearch.xpack.inference.action.TransportPutInferenceModelAction;
7274
import org.elasticsearch.xpack.inference.action.TransportUnifiedCompletionInferenceAction;
@@ -104,7 +106,6 @@
104106
import org.elasticsearch.xpack.inference.rest.RestInferenceAction;
105107
import org.elasticsearch.xpack.inference.rest.RestPutInferenceModelAction;
106108
import org.elasticsearch.xpack.inference.rest.RestStreamInferenceAction;
107-
import org.elasticsearch.xpack.inference.rest.RestUnifiedCompletionInferenceAction;
108109
import org.elasticsearch.xpack.inference.rest.RestUpdateInferenceModelAction;
109110
import org.elasticsearch.xpack.inference.services.ServiceComponents;
110111
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchService;
@@ -195,6 +196,7 @@ public InferencePlugin(Settings settings) {
195196
public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() {
196197
return List.of(
197198
new ActionHandler<>(InferenceAction.INSTANCE, TransportInferenceAction.class),
199+
new ActionHandler<>(InferenceActionProxy.INSTANCE, TransportInferenceActionProxy.class),
198200
new ActionHandler<>(GetInferenceModelAction.INSTANCE, TransportGetInferenceModelAction.class),
199201
new ActionHandler<>(PutInferenceModelAction.INSTANCE, TransportPutInferenceModelAction.class),
200202
new ActionHandler<>(UpdateInferenceModelAction.INSTANCE, TransportUpdateInferenceModelAction.class),
@@ -226,8 +228,7 @@ public List<RestHandler> getRestHandlers(
226228
new RestUpdateInferenceModelAction(),
227229
new RestDeleteInferenceEndpointAction(),
228230
new RestGetInferenceDiagnosticsAction(),
229-
new RestGetInferenceServicesAction(),
230-
new RestUnifiedCompletionInferenceAction(threadPoolSetOnce)
231+
new RestGetInferenceServicesAction()
231232
);
232233
}
233234

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.action;
9+
10+
import org.elasticsearch.ElasticsearchStatusException;
11+
import org.elasticsearch.action.ActionListener;
12+
import org.elasticsearch.action.support.ActionFilters;
13+
import org.elasticsearch.action.support.HandledTransportAction;
14+
import org.elasticsearch.client.internal.Client;
15+
import org.elasticsearch.common.util.concurrent.EsExecutors;
16+
import org.elasticsearch.common.xcontent.XContentHelper;
17+
import org.elasticsearch.inference.TaskType;
18+
import org.elasticsearch.inference.UnparsedModel;
19+
import org.elasticsearch.injection.guice.Inject;
20+
import org.elasticsearch.rest.RestStatus;
21+
import org.elasticsearch.tasks.Task;
22+
import org.elasticsearch.transport.TransportService;
23+
import org.elasticsearch.xcontent.XContentParserConfiguration;
24+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
25+
import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy;
26+
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
27+
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
28+
29+
import java.io.IOException;
30+
31+
import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN;
32+
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
33+
34+
public class TransportInferenceActionProxy extends HandledTransportAction<InferenceActionProxy.Request, InferenceAction.Response> {
35+
private final ModelRegistry modelRegistry;
36+
private final Client client;
37+
38+
@Inject
39+
public TransportInferenceActionProxy(
40+
TransportService transportService,
41+
ActionFilters actionFilters,
42+
ModelRegistry modelRegistry,
43+
Client client
44+
) {
45+
super(
46+
InferenceActionProxy.NAME,
47+
transportService,
48+
actionFilters,
49+
InferenceActionProxy.Request::new,
50+
EsExecutors.DIRECT_EXECUTOR_SERVICE
51+
);
52+
53+
this.modelRegistry = modelRegistry;
54+
this.client = client;
55+
}
56+
57+
@Override
58+
protected void doExecute(Task task, InferenceActionProxy.Request request, ActionListener<InferenceAction.Response> listener) {
59+
try {
60+
ActionListener<UnparsedModel> getModelListener = listener.delegateFailureAndWrap((l, unparsedModel) -> {
61+
if (unparsedModel.taskType() == TaskType.CHAT_COMPLETION) {
62+
sendUnifiedCompletionRequest(request, l);
63+
} else {
64+
sendInferenceActionRequest(request, l);
65+
}
66+
});
67+
68+
if (request.getTaskType() == TaskType.ANY) {
69+
modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener);
70+
} else if (request.getTaskType() == TaskType.CHAT_COMPLETION) {
71+
sendUnifiedCompletionRequest(request, listener);
72+
} else {
73+
sendInferenceActionRequest(request, listener);
74+
}
75+
} catch (Exception e) {
76+
listener.onFailure(e);
77+
}
78+
}
79+
80+
private void sendUnifiedCompletionRequest(InferenceActionProxy.Request request, ActionListener<InferenceAction.Response> listener)
81+
throws IOException {
82+
83+
if (request.isStreaming() == false) {
84+
throw new ElasticsearchStatusException(
85+
"The [chat_completion] task type only supports streaming, please try again with the _stream API",
86+
RestStatus.BAD_REQUEST
87+
);
88+
}
89+
90+
UnifiedCompletionAction.Request unifiedRequest;
91+
try (var parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, request.getContent(), request.getContentType())) {
92+
unifiedRequest = UnifiedCompletionAction.Request.parseRequest(
93+
request.getInferenceEntityId(),
94+
request.getTaskType(),
95+
request.getTimeout(),
96+
parser
97+
);
98+
}
99+
100+
executeAsyncWithOrigin(client, INFERENCE_ORIGIN, UnifiedCompletionAction.INSTANCE, unifiedRequest, listener);
101+
}
102+
103+
private void sendInferenceActionRequest(InferenceActionProxy.Request request, ActionListener<InferenceAction.Response> listener)
104+
throws IOException {
105+
InferenceAction.Request.Builder inferenceActionRequestBuilder;
106+
try (var parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, request.getContent(), request.getContentType())) {
107+
inferenceActionRequestBuilder = InferenceAction.Request.parseRequest(
108+
request.getInferenceEntityId(),
109+
request.getTaskType(),
110+
parser
111+
);
112+
inferenceActionRequestBuilder.setInferenceTimeout(request.getTimeout()).setStream(request.isStreaming());
113+
}
114+
115+
executeAsyncWithOrigin(client, INFERENCE_ORIGIN, InferenceAction.INSTANCE, inferenceActionRequestBuilder.build(), listener);
116+
}
117+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.elasticsearch.rest.RestChannel;
1616
import org.elasticsearch.rest.RestRequest;
1717
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
18+
import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy;
1819

1920
import java.io.IOException;
2021

@@ -41,21 +42,22 @@ static TimeValue parseTimeout(RestRequest restRequest) {
4142
@Override
4243
protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
4344
var params = parseParams(restRequest);
45+
var content = restRequest.requiredContent();
46+
var inferTimeout = parseTimeout(restRequest);
4447

45-
InferenceAction.Request.Builder requestBuilder;
46-
try (var parser = restRequest.contentParser()) {
47-
requestBuilder = InferenceAction.Request.parseRequest(params.inferenceEntityId(), params.taskType(), parser);
48-
}
48+
var request = new InferenceActionProxy.Request(
49+
params.taskType(),
50+
params.inferenceEntityId(),
51+
content,
52+
restRequest.getXContentType(),
53+
inferTimeout,
54+
shouldStream()
55+
);
4956

50-
var inferTimeout = parseTimeout(restRequest);
51-
requestBuilder.setInferenceTimeout(inferTimeout);
52-
var request = prepareInferenceRequest(requestBuilder);
53-
return channel -> client.execute(InferenceAction.INSTANCE, request, listener(channel));
57+
return channel -> client.execute(InferenceActionProxy.INSTANCE, request, listener(channel));
5458
}
5559

56-
protected InferenceAction.Request prepareInferenceRequest(InferenceAction.Request.Builder builder) {
57-
return builder.build();
58-
}
60+
protected abstract boolean shouldStream();
5961

6062
protected abstract ActionListener<InferenceAction.Response> listener(RestChannel channel);
6163
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,22 +24,13 @@ public final class Paths {
2424
static final String INFERENCE_SERVICES_PATH = "_inference/_services";
2525
static final String TASK_TYPE_INFERENCE_SERVICES_PATH = "_inference/_services/{" + TASK_TYPE + "}";
2626

27-
static final String STREAM_INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}/_stream";
27+
public static final String STREAM_SUFFIX = "_stream";
28+
static final String STREAM_INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}/" + STREAM_SUFFIX;
2829
static final String STREAM_TASK_TYPE_INFERENCE_ID_PATH = "_inference/{"
2930
+ TASK_TYPE_OR_INFERENCE_ID
3031
+ "}/{"
3132
+ INFERENCE_ID
32-
+ "}/_stream";
33-
34-
// TODO remove the _unified path
35-
public static final String UNIFIED_SUFFIX = "_unified";
36-
static final String UNIFIED_INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}/" + UNIFIED_SUFFIX;
37-
static final String UNIFIED_TASK_TYPE_INFERENCE_ID_PATH = "_inference/{"
38-
+ TASK_TYPE_OR_INFERENCE_ID
39-
+ "}/{"
40-
+ INFERENCE_ID
41-
+ "}/"
42-
+ UNIFIED_SUFFIX;
33+
+ "}/" + STREAM_SUFFIX;
4334

4435
private Paths() {
4536

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ public List<Route> routes() {
3232
return List.of(new Route(POST, INFERENCE_ID_PATH), new Route(POST, TASK_TYPE_INFERENCE_ID_PATH));
3333
}
3434

35+
@Override
36+
protected boolean shouldStream() {
37+
return false;
38+
}
39+
3540
@Override
3641
protected ActionListener<InferenceAction.Response> listener(RestChannel channel) {
3742
return new RestChunkedToXContentListener<>(channel);

0 commit comments

Comments
 (0)