Skip to content

Commit 1e0eb20

Browse files
Adding separate transport classes
1 parent 1e30c6d commit 1e0eb20

File tree

12 files changed

+212 -314 lines changed

12 files changed

+212 -314 lines changed
Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,12 +5,24 @@
55
* 2.0.
66
*/
77

8-
package org.elasticsearch.xpack.inference.action;
8+
package org.elasticsearch.xpack.core.inference.action;
99

1010
import org.elasticsearch.action.ActionRequest;
11+
import org.elasticsearch.common.io.stream.StreamInput;
1112
import org.elasticsearch.inference.TaskType;
1213

14+
import java.io.IOException;
15+
1316
public abstract class BaseInferenceActionRequest extends ActionRequest {
17+
18+
public BaseInferenceActionRequest() {
19+
super();
20+
}
21+
22+
public BaseInferenceActionRequest(StreamInput in) throws IOException {
23+
super(in);
24+
}
25+
1426
public abstract boolean isStreaming();
1527

1628
public abstract TaskType getTaskType();

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,6 @@
1010
import org.elasticsearch.ElasticsearchStatusException;
1111
import org.elasticsearch.TransportVersion;
1212
import org.elasticsearch.TransportVersions;
13-
import org.elasticsearch.action.ActionRequest;
1413
import org.elasticsearch.action.ActionRequestValidationException;
1514
import org.elasticsearch.action.ActionResponse;
1615
import org.elasticsearch.action.ActionType;
@@ -55,7 +54,7 @@ public InferenceAction() {
5554
super(NAME);
5655
}
5756

58-
public static class Request extends ActionRequest {
57+
public static class Request extends BaseInferenceActionRequest {
5958

6059
public static final TimeValue DEFAULT_TIMEOUT = TimeValue.timeValueSeconds(30);
6160
public static final ParseField INPUT = new ParseField("input");

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java

Lines changed: 17 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -7,11 +7,11 @@
77

88
package org.elasticsearch.xpack.core.inference.action;
99

10-
import org.elasticsearch.action.ActionRequest;
1110
import org.elasticsearch.action.ActionRequestValidationException;
1211
import org.elasticsearch.action.ActionType;
1312
import org.elasticsearch.common.io.stream.StreamInput;
1413
import org.elasticsearch.common.io.stream.StreamOutput;
14+
import org.elasticsearch.core.TimeValue;
1515
import org.elasticsearch.inference.TaskType;
1616
import org.elasticsearch.xcontent.XContentParser;
1717

@@ -26,27 +26,30 @@ public UnifiedCompletionAction() {
2626
super(NAME);
2727
}
2828

29-
public static class Request extends ActionRequest {
30-
public static Request parseRequest(String inferenceEntityId, TaskType taskType, XContentParser parser) throws IOException {
29+
public static class Request extends BaseInferenceActionRequest {
30+
public static Request parseRequest(String inferenceEntityId, TaskType taskType, TimeValue timeout, XContentParser parser) throws IOException {
3131
var unifiedRequest = UnifiedCompletionRequest.PARSER.apply(parser, null);
32-
return new Request(inferenceEntityId, taskType, unifiedRequest);
32+
return new Request(inferenceEntityId, taskType, unifiedRequest, timeout);
3333
}
3434

3535
private final String inferenceEntityId;
3636
private final TaskType taskType;
3737
private final UnifiedCompletionRequest unifiedCompletionRequest;
38+
private final TimeValue timeout;
3839

39-
public Request(String inferenceEntityId, TaskType taskType, UnifiedCompletionRequest unifiedCompletionRequest) {
40+
public Request(String inferenceEntityId, TaskType taskType, UnifiedCompletionRequest unifiedCompletionRequest, TimeValue timeout) {
4041
this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId);
4142
this.taskType = Objects.requireNonNull(taskType);
4243
this.unifiedCompletionRequest = Objects.requireNonNull(unifiedCompletionRequest);
44+
this.timeout = Objects.requireNonNull(timeout);
4345
}
4446

4547
public Request(StreamInput in) throws IOException {
4648
super(in);
4749
this.inferenceEntityId = in.readString();
4850
this.taskType = TaskType.fromStream(in);
4951
this.unifiedCompletionRequest = new UnifiedCompletionRequest(in);
52+
this.timeout = in.readTimeValue();
5053
}
5154

5255
public TaskType getTaskType() {
@@ -62,7 +65,11 @@ public UnifiedCompletionRequest getUnifiedCompletionRequest() {
6265
}
6366

6467
public boolean isStreaming() {
65-
return Objects.requireNonNullElse(unifiedCompletionRequest.stream(), false);
68+
return true;
69+
}
70+
71+
public TimeValue getTimeout() {
72+
return timeout;
6673
}
6774

6875
@Override
@@ -94,6 +101,7 @@ public void writeTo(StreamOutput out) throws IOException {
94101
out.writeString(inferenceEntityId);
95102
taskType.writeTo(out);
96103
unifiedCompletionRequest.writeTo(out);
104+
out.writeTimeValue(timeout);
97105
}
98106

99107
@Override
@@ -102,12 +110,13 @@ public boolean equals(Object o) {
102110
Request request = (Request) o;
103111
return Objects.equals(inferenceEntityId, request.inferenceEntityId)
104112
&& taskType == request.taskType
105-
&& Objects.equals(unifiedCompletionRequest, request.unifiedCompletionRequest);
113+
&& Objects.equals(unifiedCompletionRequest, request.unifiedCompletionRequest) &&
114+
Objects.equals(timeout, request.timeout);
106115
}
107116

108117
@Override
109118
public int hashCode() {
110-
return Objects.hash(inferenceEntityId, taskType, unifiedCompletionRequest);
119+
return Objects.hash(inferenceEntityId, taskType, unifiedCompletionRequest, timeout);
111120
}
112121
}
113122

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequest.java

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,6 @@ public record UnifiedCompletionRequest(
3232
@Nullable Long maxCompletionTokens,
3333
@Nullable Integer n,
3434
@Nullable Stop stop,
35-
@Nullable Boolean stream,
3635
@Nullable Float temperature,
3736
@Nullable ToolChoice toolChoice,
3837
@Nullable List<Tool> tool,
@@ -49,7 +48,6 @@ public record UnifiedCompletionRequest(
4948
(Long) args[2],
5049
(Integer) args[3],
5150
(Stop) args[4],
52-
(Boolean) args[5],
5351
(Float) args[6],
5452
(ToolChoice) args[7],
5553
(List<Tool>) args[8],
@@ -64,7 +62,6 @@ public record UnifiedCompletionRequest(
6462
PARSER.declareLong(optionalConstructorArg(), new ParseField("max_completion_tokens"));
6563
PARSER.declareInt(optionalConstructorArg(), new ParseField("n"));
6664
PARSER.declareField(optionalConstructorArg(), (p, c) -> parseStop(p), new ParseField("stop"), ObjectParser.ValueType.VALUE_ARRAY);
67-
PARSER.declareBoolean(optionalConstructorArg(), new ParseField("stream"));
6865
PARSER.declareFloat(optionalConstructorArg(), new ParseField("temperature"));
6966
PARSER.declareField(
7067
optionalConstructorArg(),
@@ -84,7 +81,6 @@ public UnifiedCompletionRequest(StreamInput in) throws IOException {
8481
in.readOptionalVLong(),
8582
in.readOptionalVInt(),
8683
in.readOptionalNamedWriteable(Stop.class),
87-
in.readOptionalBoolean(),
8884
in.readOptionalFloat(),
8985
in.readOptionalNamedWriteable(ToolChoice.class),
9086
in.readCollectionAsImmutableList(Tool::new),
@@ -100,7 +96,6 @@ public void writeTo(StreamOutput out) throws IOException {
10096
out.writeOptionalVLong(maxCompletionTokens);
10197
out.writeOptionalVInt(n);
10298
out.writeOptionalNamedWriteable(stop);
103-
out.writeOptionalBoolean(stream);
10499
out.writeOptionalFloat(temperature);
105100
out.writeOptionalNamedWriteable(toolChoice);
106101
out.writeOptionalCollection(tool);

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/store/ReservedRolesStoreTests.java

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4236,6 +4236,7 @@ public void testInferenceUserRole() {
42364236
assertTrue(role.cluster().check("cluster:monitor/xpack/inference", request, authentication));
42374237
assertTrue(role.cluster().check("cluster:monitor/xpack/inference/get", request, authentication));
42384238
assertFalse(role.cluster().check("cluster:admin/xpack/inference/put", request, authentication));
4239+
assertTrue(role.cluster().check("cluster:monitor/xpack/inference/unified", request, authentication));
42394240
assertFalse(role.cluster().check("cluster:admin/xpack/inference/delete", request, authentication));
42404241
assertTrue(role.cluster().check("cluster:monitor/xpack/ml/trained_models/deployment/infer", request, authentication));
42414242
assertFalse(role.cluster().check("cluster:admin/xpack/ml/trained_models/deployment/start", request, authentication));

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,7 @@
4848
import org.elasticsearch.xpack.core.inference.action.GetInferenceServicesAction;
4949
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
5050
import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
51+
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
5152
import org.elasticsearch.xpack.core.inference.action.UpdateInferenceModelAction;
5253
import org.elasticsearch.xpack.inference.action.TransportDeleteInferenceEndpointAction;
5354
import org.elasticsearch.xpack.inference.action.TransportGetInferenceDiagnosticsAction;
@@ -56,6 +57,7 @@
5657
import org.elasticsearch.xpack.inference.action.TransportInferenceAction;
5758
import org.elasticsearch.xpack.inference.action.TransportInferenceUsageAction;
5859
import org.elasticsearch.xpack.inference.action.TransportPutInferenceModelAction;
60+
import org.elasticsearch.xpack.inference.action.TransportUnifiedCompletionInferenceAction;
5961
import org.elasticsearch.xpack.inference.action.TransportUpdateInferenceModelAction;
6062
import org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter;
6163
import org.elasticsearch.xpack.inference.common.Truncator;
@@ -152,6 +154,7 @@ public InferencePlugin(Settings settings) {
152154
public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() {
153155
return List.of(
154156
new ActionHandler<>(InferenceAction.INSTANCE, TransportInferenceAction.class),
157+
new ActionHandler<>(UnifiedCompletionAction.INSTANCE, TransportUnifiedCompletionInferenceAction.class),
155158
new ActionHandler<>(GetInferenceModelAction.INSTANCE, TransportGetInferenceModelAction.class),
156159
new ActionHandler<>(PutInferenceModelAction.INSTANCE, TransportPutInferenceModelAction.class),
157160
new ActionHandler<>(UpdateInferenceModelAction.INSTANCE, TransportUpdateInferenceModelAction.class),

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java

Lines changed: 36 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -23,10 +23,10 @@
2323
import org.elasticsearch.inference.Model;
2424
import org.elasticsearch.inference.TaskType;
2525
import org.elasticsearch.inference.UnparsedModel;
26-
import org.elasticsearch.injection.guice.Inject;
2726
import org.elasticsearch.rest.RestStatus;
2827
import org.elasticsearch.tasks.Task;
2928
import org.elasticsearch.transport.TransportService;
29+
import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest;
3030
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
3131
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
3232
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
@@ -40,8 +40,8 @@
4040
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
4141
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes;
4242

43-
public abstract class BaseTransportInferenceAction<T extends BaseInferenceActionRequest> extends HandledTransportAction<
44-
T,
43+
public abstract class BaseTransportInferenceAction<Request extends BaseInferenceActionRequest> extends HandledTransportAction<
44+
Request,
4545
InferenceAction.Response> {
4646

4747
private static final Logger log = LogManager.getLogger(BaseTransportInferenceAction.class);
@@ -52,16 +52,14 @@ public abstract class BaseTransportInferenceAction<T extends BaseInferenceAction
5252
private final InferenceStats inferenceStats;
5353
private final StreamingTaskManager streamingTaskManager;
5454

55-
// TODO remove the inject here?
56-
@Inject
5755
public BaseTransportInferenceAction(
5856
TransportService transportService,
5957
ActionFilters actionFilters,
6058
ModelRegistry modelRegistry,
6159
InferenceServiceRegistry serviceRegistry,
6260
InferenceStats inferenceStats,
6361
StreamingTaskManager streamingTaskManager,
64-
Writeable.Reader<T> requestReader
62+
Writeable.Reader<Request> requestReader
6563
) {
6664
super(InferenceAction.NAME, transportService, actionFilters, requestReader, EsExecutors.DIRECT_EXECUTOR_SERVICE);
6765
this.modelRegistry = modelRegistry;
@@ -71,7 +69,7 @@ public BaseTransportInferenceAction(
7169
}
7270

7371
@Override
74-
protected void doExecute(Task task, T request, ActionListener<InferenceAction.Response> listener) {
72+
protected void doExecute(Task task, Request request, ActionListener<InferenceAction.Response> listener) {
7573
var timer = InferenceTimer.start();
7674

7775
var getModelListener = ActionListener.wrap((UnparsedModel unparsedModel) -> {
@@ -92,7 +90,7 @@ protected void doExecute(Task task, T request, ActionListener<InferenceAction.Re
9290
}
9391

9492
if (isInvalidTaskTypeForInferenceEndpoint(request, unparsedModel)) {
95-
var e = incompatibleUnifiedModeTaskTypeException(request.getTaskType());
93+
var e = createIncompatibleTaskTypeException(request, unparsedModel);
9694
recordMetrics(unparsedModel, timer, e);
9795
listener.onFailure(e);
9896
return;
@@ -118,22 +116,9 @@ protected void doExecute(Task task, T request, ActionListener<InferenceAction.Re
118116
modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener);
119117
}
120118

121-
protected abstract boolean isInvalidTaskTypeForInferenceEndpoint(T request, UnparsedModel unparsedModel);
119+
protected abstract boolean isInvalidTaskTypeForInferenceEndpoint(Request request, UnparsedModel unparsedModel);
122120

123-
protected abstract ElasticsearchStatusException createIncompatibleTaskTypeException(T request, UnparsedModel unparsedModel);
124-
125-
private boolean isInvalidTaskTypeForUnifiedCompletionMode(T request, UnparsedModel unparsedModel) {
126-
return request.isUnifiedCompletionMode() && request.getTaskType() != TaskType.COMPLETION;
127-
}
128-
129-
private static ElasticsearchStatusException incompatibleUnifiedModeTaskTypeException(TaskType requested) {
130-
return new ElasticsearchStatusException(
131-
"Incompatible task_type for unified API, the requested type [{}] must be one of [{}]",
132-
RestStatus.BAD_REQUEST,
133-
requested,
134-
TaskType.COMPLETION.toString()
135-
);
136-
}
121+
protected abstract ElasticsearchStatusException createIncompatibleTaskTypeException(Request request, UnparsedModel unparsedModel);
137122

138123
private void recordMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) {
139124
try {
@@ -145,7 +130,7 @@ private void recordMetrics(UnparsedModel model, InferenceTimer timer, @Nullable
145130

146131
private void inferOnServiceWithMetrics(
147132
Model model,
148-
InferenceAction.Request request,
133+
Request request,
149134
InferenceService service,
150135
InferenceTimer timer,
151136
ActionListener<InferenceAction.Response> listener
@@ -178,43 +163,43 @@ private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwabl
178163
}
179164
}
180165

181-
private void inferOnService(
182-
Model model,
183-
InferenceAction.Request request,
184-
InferenceService service,
185-
ActionListener<InferenceServiceResults> listener
186-
) {
187-
Runnable inferenceRunnable = inferRunnable(model, request, service, listener);
188-
166+
private void inferOnService(Model model, Request request, InferenceService service, ActionListener<InferenceServiceResults> listener) {
189167
if (request.isStreaming() == false || service.canStream(request.getTaskType())) {
190-
inferenceRunnable.run();
168+
doInference(model, request, service, listener);
191169
} else {
192170
listener.onFailure(unsupportedStreamingTaskException(request, service));
193171
}
194172
}
195173

196-
private static Runnable inferRunnable(
174+
// private static Runnable inferRunnable(
175+
// Model model,
176+
// T request,
177+
// InferenceService service,
178+
// ActionListener<InferenceServiceResults> listener
179+
// ) {
180+
// return request.isUnifiedCompletionMode()
181+
// // TODO add parameters
182+
// ? () -> service.completionInfer(model, null, request.getInferenceTimeout(), listener)
183+
// : () -> service.infer(
184+
// model,
185+
// request.getQuery(),
186+
// request.getInput(),
187+
// request.isStreaming(),
188+
// request.getTaskSettings(),
189+
// request.getInputType(),
190+
// request.getInferenceTimeout(),
191+
// listener
192+
// );
193+
// }
194+
195+
protected abstract void doInference(
197196
Model model,
198-
InferenceAction.Request request,
197+
Request request,
199198
InferenceService service,
200199
ActionListener<InferenceServiceResults> listener
201-
) {
202-
return request.isUnifiedCompletionMode()
203-
// TODO add parameters
204-
? () -> service.completionInfer(model, null, request.getInferenceTimeout(), listener)
205-
: () -> service.infer(
206-
model,
207-
request.getQuery(),
208-
request.getInput(),
209-
request.isStreaming(),
210-
request.getTaskSettings(),
211-
request.getInputType(),
212-
request.getInferenceTimeout(),
213-
listener
214-
);
215-
}
200+
);
216201

217-
private ElasticsearchStatusException unsupportedStreamingTaskException(InferenceAction.Request request, InferenceService service) {
202+
private ElasticsearchStatusException unsupportedStreamingTaskException(Request request, InferenceService service) {
218203
var supportedTasks = service.supportedStreamingTasks();
219204
if (supportedTasks.isEmpty()) {
220205
return new ElasticsearchStatusException(

0 commit comments

Comments
 (0)