Skip to content

Commit 1e15c88

Browse files
authored
support dispatching execute task; don't dispatch ML task again (#279)
* support dispatching execute task; don't dispatch ML task again Signed-off-by: Yaliang Wu <[email protected]> * remove MLPredictTaskRunner from jacoco exclusion list Signed-off-by: Yaliang Wu <[email protected]>
1 parent f47e888 commit 1e15c88

File tree

15 files changed

+267
-145
lines changed

15 files changed

+267
-145
lines changed
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.transport;
7+
8+
import lombok.Getter;
9+
import lombok.Setter;
10+
import org.opensearch.action.ActionRequest;
11+
import org.opensearch.action.ActionRequestValidationException;
12+
import org.opensearch.common.io.stream.StreamInput;
13+
import org.opensearch.common.io.stream.StreamOutput;
14+
15+
import java.io.IOException;
16+
import java.util.UUID;
17+
18+
@Getter
19+
@Setter
20+
public class MLTaskRequest extends ActionRequest {
21+
22+
protected boolean dispatchTask;
23+
protected final String requestID;
24+
25+
public MLTaskRequest(boolean dispatchTask) {
26+
this.dispatchTask = dispatchTask;
27+
this.requestID = UUID.randomUUID().toString();
28+
}
29+
30+
public MLTaskRequest(StreamInput in) throws IOException {
31+
super(in);
32+
this.requestID = in.readString();
33+
this.dispatchTask = in.readBoolean();
34+
}
35+
36+
@Override
37+
public void writeTo(StreamOutput out) throws IOException {
38+
super.writeTo(out);
39+
out.writeString(requestID);
40+
out.writeBoolean(dispatchTask);
41+
}
42+
43+
@Override
44+
public ActionRequestValidationException validate() {
45+
return null;
46+
}
47+
}

common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskRequest.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.opensearch.ml.common.MLCommonsClassLoader;
2121
import org.opensearch.ml.common.FunctionName;
2222
import org.opensearch.ml.common.input.Input;
23+
import org.opensearch.ml.common.transport.MLTaskRequest;
2324

2425
import java.io.ByteArrayInputStream;
2526
import java.io.ByteArrayOutputStream;
@@ -31,17 +32,22 @@
3132
@Getter
3233
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
3334
@ToString
34-
public class MLExecuteTaskRequest extends ActionRequest {
35+
public class MLExecuteTaskRequest extends MLTaskRequest {
3536

3637
FunctionName functionName;
3738
Input input;
3839

3940
@Builder
40-
public MLExecuteTaskRequest(@NonNull FunctionName functionName, Input input) {
41+
public MLExecuteTaskRequest(@NonNull FunctionName functionName, Input input, boolean dispatchTask) {
42+
super(dispatchTask);
4143
this.functionName = functionName;
4244
this.input = input;
4345
}
4446

47+
public MLExecuteTaskRequest(@NonNull FunctionName functionName, Input input) {
48+
this(functionName, input, true);
49+
}
50+
4551
public MLExecuteTaskRequest(StreamInput in) throws IOException {
4652
super(in);
4753
this.functionName = in.readEnum(FunctionName.class);

common/src/main/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequest.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,23 +23,29 @@
2323
import lombok.Getter;
2424
import lombok.ToString;
2525
import lombok.experimental.FieldDefaults;
26+
import org.opensearch.ml.common.transport.MLTaskRequest;
2627

2728
import static org.opensearch.action.ValidateActions.addValidationError;
2829

2930
@Getter
3031
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
3132
@ToString
32-
public class MLPredictionTaskRequest extends ActionRequest {
33+
public class MLPredictionTaskRequest extends MLTaskRequest {
3334

3435
String modelId;
3536
MLInput mlInput;
3637

3738
@Builder
38-
public MLPredictionTaskRequest(String modelId, MLInput mlInput) {
39+
public MLPredictionTaskRequest(String modelId, MLInput mlInput, boolean dispatchTask) {
40+
super(dispatchTask);
3941
this.mlInput = mlInput;
4042
this.modelId = modelId;
4143
}
4244

45+
public MLPredictionTaskRequest(String modelId, MLInput mlInput) {
46+
this(modelId, mlInput, true);
47+
}
48+
4349
public MLPredictionTaskRequest(StreamInput in) throws IOException {
4450
super(in);
4551
this.modelId = in.readOptionalString();

common/src/main/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequest.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import org.opensearch.common.io.stream.StreamInput;
1818
import org.opensearch.common.io.stream.StreamOutput;
1919
import org.opensearch.ml.common.input.MLInput;
20+
import org.opensearch.ml.common.transport.MLTaskRequest;
2021

2122
import java.io.ByteArrayInputStream;
2223
import java.io.ByteArrayOutputStream;
@@ -29,7 +30,7 @@
2930
@Getter
3031
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
3132
@ToString
32-
public class MLTrainingTaskRequest extends ActionRequest {
33+
public class MLTrainingTaskRequest extends MLTaskRequest {
3334

3435
/**
3536
* the name of algorithm
@@ -38,11 +39,16 @@ public class MLTrainingTaskRequest extends ActionRequest {
3839
boolean async;
3940

4041
@Builder
41-
public MLTrainingTaskRequest(MLInput mlInput, boolean async) {
42+
public MLTrainingTaskRequest(MLInput mlInput, boolean async, boolean dispatchTask) {
43+
super(dispatchTask);
4244
this.mlInput = mlInput;
4345
this.async = async;
4446
}
4547

48+
public MLTrainingTaskRequest(MLInput mlInput, boolean async) {
49+
this(mlInput, async, true);
50+
}
51+
4652
public MLTrainingTaskRequest(StreamInput in) throws IOException {
4753
super(in);
4854
this.mlInput = new MLInput(in);

plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,8 @@ public Collection<Object> createComponents(
189189
mlStats,
190190
mlInputDatasetHandler,
191191
mlTaskDispatcher,
192-
mlCircuitBreakerService
192+
mlCircuitBreakerService,
193+
xContentRegistry
193194
);
194195
mlTrainAndPredictTaskRunner = new MLTrainAndPredictTaskRunner(
195196
threadPool,

plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,21 @@
1010
import lombok.extern.log4j.Log4j2;
1111

1212
import org.opensearch.action.ActionListener;
13+
import org.opensearch.action.ActionListenerResponseHandler;
1314
import org.opensearch.client.Client;
1415
import org.opensearch.cluster.service.ClusterService;
1516
import org.opensearch.ml.common.FunctionName;
1617
import org.opensearch.ml.common.breaker.MLCircuitBreakerService;
1718
import org.opensearch.ml.common.input.Input;
1819
import org.opensearch.ml.common.output.Output;
20+
import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction;
1921
import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest;
2022
import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse;
2123
import org.opensearch.ml.engine.MLEngine;
2224
import org.opensearch.ml.indices.MLInputDatasetHandler;
2325
import org.opensearch.ml.stats.MLStats;
2426
import org.opensearch.threadpool.ThreadPool;
25-
import org.opensearch.transport.TransportService;
27+
import org.opensearch.transport.TransportResponseHandler;
2628

2729
/**
2830
* MLExecuteTaskRunner is responsible for running execute tasks.
@@ -44,26 +46,30 @@ public MLExecuteTaskRunner(
4446
MLTaskDispatcher mlTaskDispatcher,
4547
MLCircuitBreakerService mlCircuitBreakerService
4648
) {
47-
super(mlTaskManager, mlStats, mlTaskDispatcher, mlCircuitBreakerService);
49+
super(mlTaskManager, mlStats, mlTaskDispatcher, mlCircuitBreakerService, clusterService);
4850
this.threadPool = threadPool;
4951
this.clusterService = clusterService;
5052
this.client = client;
5153
this.mlInputDatasetHandler = mlInputDatasetHandler;
5254
}
5355

56+
@Override
57+
protected String getTransportActionName() {
58+
return MLExecuteTaskAction.NAME;
59+
}
60+
61+
@Override
62+
protected TransportResponseHandler<MLExecuteTaskResponse> getResponseHandler(ActionListener<MLExecuteTaskResponse> listener) {
63+
return new ActionListenerResponseHandler<>(listener, MLExecuteTaskResponse::new);
64+
}
65+
5466
/**
5567
* Execute algorithm and return result.
56-
* TODO: 1. support backend task run; 2. support dispatch task to remote node
5768
* @param request MLExecuteTaskRequest
58-
* @param transportService transport service
5969
* @param listener Action listener
6070
*/
6171
@Override
62-
public void executeTask(
63-
MLExecuteTaskRequest request,
64-
TransportService transportService,
65-
ActionListener<MLExecuteTaskResponse> listener
66-
) {
72+
protected void executeTask(MLExecuteTaskRequest request, ActionListener<MLExecuteTaskResponse> listener) {
6773
threadPool.executor(TASK_THREAD_POOL).execute(() -> {
6874
try {
6975
Input input = request.getInput();

plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java

Lines changed: 60 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
package org.opensearch.ml.task;
77

8+
import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
89
import static org.opensearch.ml.indices.MLIndicesHandler.ML_MODEL_INDEX;
910
import static org.opensearch.ml.permission.AccessController.checkUserPermissions;
1011
import static org.opensearch.ml.permission.AccessController.getUserContext;
@@ -17,7 +18,6 @@
1718

1819
import java.time.Instant;
1920
import java.util.Base64;
20-
import java.util.Map;
2121
import java.util.UUID;
2222

2323
import lombok.extern.log4j.Log4j2;
@@ -32,6 +32,10 @@
3232
import org.opensearch.client.Client;
3333
import org.opensearch.cluster.service.ClusterService;
3434
import org.opensearch.common.util.concurrent.ThreadContext;
35+
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
36+
import org.opensearch.common.xcontent.NamedXContentRegistry;
37+
import org.opensearch.common.xcontent.XContentParser;
38+
import org.opensearch.common.xcontent.XContentType;
3539
import org.opensearch.commons.authuser.User;
3640
import org.opensearch.ml.common.MLModel;
3741
import org.opensearch.ml.common.MLTask;
@@ -53,7 +57,7 @@
5357
import org.opensearch.ml.stats.ActionName;
5458
import org.opensearch.ml.stats.MLStats;
5559
import org.opensearch.threadpool.ThreadPool;
56-
import org.opensearch.transport.TransportService;
60+
import org.opensearch.transport.TransportResponseHandler;
5761

5862
/**
5963
* MLPredictTaskRunner is responsible for running predict tasks.
@@ -64,6 +68,7 @@ public class MLPredictTaskRunner extends MLTaskRunner<MLPredictionTaskRequest, M
6468
private final ClusterService clusterService;
6569
private final Client client;
6670
private final MLInputDatasetHandler mlInputDatasetHandler;
71+
private final NamedXContentRegistry xContentRegistry;
6772

6873
public MLPredictTaskRunner(
6974
ThreadPool threadPool,
@@ -73,42 +78,34 @@ public MLPredictTaskRunner(
7378
MLStats mlStats,
7479
MLInputDatasetHandler mlInputDatasetHandler,
7580
MLTaskDispatcher mlTaskDispatcher,
76-
MLCircuitBreakerService mlCircuitBreakerService
81+
MLCircuitBreakerService mlCircuitBreakerService,
82+
NamedXContentRegistry xContentRegistry
7783
) {
78-
super(mlTaskManager, mlStats, mlTaskDispatcher, mlCircuitBreakerService);
84+
super(mlTaskManager, mlStats, mlTaskDispatcher, mlCircuitBreakerService, clusterService);
7985
this.threadPool = threadPool;
8086
this.clusterService = clusterService;
8187
this.client = client;
8288
this.mlInputDatasetHandler = mlInputDatasetHandler;
89+
this.xContentRegistry = xContentRegistry;
8390
}
8491

8592
@Override
86-
public void executeTask(MLPredictionTaskRequest request, TransportService transportService, ActionListener<MLTaskResponse> listener) {
87-
mlTaskDispatcher.dispatchTask(ActionListener.wrap(node -> {
88-
if (clusterService.localNode().getId().equals(node.getId())) {
89-
// Execute prediction task locally
90-
log.info("execute ML prediction request {} locally on node {}", request.toString(), node.getId());
91-
startPredictionTask(request, listener);
92-
} else {
93-
// Execute batch task remotely
94-
log.info("execute ML prediction request {} remotely on node {}", request.toString(), node.getId());
95-
transportService
96-
.sendRequest(
97-
node,
98-
MLPredictionTaskAction.NAME,
99-
request,
100-
new ActionListenerResponseHandler<>(listener, MLTaskResponse::new)
101-
);
102-
}
103-
}, e -> listener.onFailure(e)));
93+
protected String getTransportActionName() {
94+
return MLPredictionTaskAction.NAME;
95+
}
96+
97+
@Override
98+
protected TransportResponseHandler<MLTaskResponse> getResponseHandler(ActionListener<MLTaskResponse> listener) {
99+
return new ActionListenerResponseHandler<>(listener, MLTaskResponse::new);
104100
}
105101

106102
/**
107103
* Start prediction task
108104
* @param request MLPredictionTaskRequest
109105
* @param listener Action listener
110106
*/
111-
public void startPredictionTask(MLPredictionTaskRequest request, ActionListener<MLTaskResponse> listener) {
107+
@Override
108+
protected void executeTask(MLPredictionTaskRequest request, ActionListener<MLTaskResponse> listener) {
112109
MLInputDataType inputDataType = request.getMlInput().getInputDataset().getInputDataType();
113110
Instant now = Instant.now();
114111
MLTask mlTask = MLTask
@@ -166,36 +163,49 @@ private void predict(
166163
internalListener.onFailure(new ResourceNotFoundException("No model found, please check the modelId."));
167164
return;
168165
}
169-
Map<String, Object> source = r.getSourceAsMap();
170-
User requestUser = getUserContext(client);
171-
User resourceUser = User.parse((String) source.get(USER));
172-
if (!checkUserPermissions(requestUser, resourceUser, request.getModelId())) {
173-
// The backend roles of request user and resource user doesn't have intersection
174-
OpenSearchException e = new OpenSearchException(
175-
"User: " + requestUser.getName() + " does not have permissions to run predict by model: " + request.getModelId()
176-
);
177-
handlePredictFailure(mlTask, internalListener, e, false);
178-
return;
179-
}
166+
try (
167+
XContentParser xContentParser = XContentType.JSON
168+
.xContent()
169+
.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, r.getSourceAsString())
170+
) {
171+
ensureExpectedToken(XContentParser.Token.START_OBJECT, xContentParser.nextToken(), xContentParser);
172+
MLModel mlModel = MLModel.parse(xContentParser);
173+
User resourceUser = mlModel.getUser();
174+
User requestUser = getUserContext(client);
175+
if (!checkUserPermissions(requestUser, resourceUser, request.getModelId())) {
176+
// The backend roles of request user and resource user doesn't have intersection
177+
OpenSearchException e = new OpenSearchException(
178+
"User: "
179+
+ requestUser.getName()
180+
+ " does not have permissions to run predict by model: "
181+
+ request.getModelId()
182+
);
183+
handlePredictFailure(mlTask, internalListener, e, false);
184+
return;
185+
}
186+
Model model = new Model();
187+
model.setName(mlModel.getName());
188+
model.setVersion(mlModel.getVersion());
189+
byte[] decoded = Base64.getDecoder().decode(mlModel.getContent());
190+
model.setContent(decoded);
191+
192+
// run predict
193+
mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING, mlTask.isAsync());
194+
MLOutput output = MLEngine
195+
.predict(mlInput.toBuilder().inputDataset(new DataFrameInputDataset(inputDataFrame)).build(), model);
196+
if (output instanceof MLPredictionOutput) {
197+
((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name());
198+
}
180199

181-
Model model = new Model();
182-
model.setName((String) source.get(MLModel.MODEL_NAME));
183-
model.setVersion((Integer) source.get(MLModel.MODEL_VERSION));
184-
byte[] decoded = Base64.getDecoder().decode((String) source.get(MLModel.MODEL_CONTENT));
185-
model.setContent(decoded);
186-
187-
// run predict
188-
mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING, mlTask.isAsync());
189-
MLOutput output = MLEngine
190-
.predict(mlInput.toBuilder().inputDataset(new DataFrameInputDataset(inputDataFrame)).build(), model);
191-
if (output instanceof MLPredictionOutput) {
192-
((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name());
200+
// Once prediction complete, reduce ML_EXECUTING_TASK_COUNT and update task state
201+
handleAsyncMLTaskComplete(mlTask);
202+
MLTaskResponse response = MLTaskResponse.builder().output(output).build();
203+
internalListener.onResponse(response);
204+
} catch (Exception e) {
205+
log.error("Failed to predict model " + request.getModelId(), e);
206+
internalListener.onFailure(e);
193207
}
194208

195-
// Once prediction complete, reduce ML_EXECUTING_TASK_COUNT and update task state
196-
handleAsyncMLTaskComplete(mlTask);
197-
MLTaskResponse response = MLTaskResponse.builder().output(output).build();
198-
internalListener.onResponse(response);
199209
}, e -> {
200210
log.error("Failed to predict " + mlInput.getAlgorithm() + ", modelId: " + mlTask.getModelId(), e);
201211
handlePredictFailure(mlTask, internalListener, e, true);

0 commit comments

Comments
 (0)