Skip to content

Commit 2ba85af

Browse files
add tasks API in Client (#200) (#202)
Signed-off-by: Xun Zhang <[email protected]> (cherry picked from commit d03208b) Co-authored-by: Xun Zhang <[email protected]>
1 parent e4af194 commit 2ba85af

File tree

5 files changed

+107
-20
lines changed

5 files changed

+107
-20
lines changed

client/build.gradle

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,11 @@ jacocoTestCoverageVerification {
3535
rule {
3636
limit {
3737
counter = 'LINE'
38-
minimum = 0.4 //TODO: add more test to increase coverage to 0.9
38+
minimum = 0.3 //TODO: add more test to increase coverage to 0.9
3939
}
4040
limit {
4141
counter = 'BRANCH'
42-
minimum = 0.4 //TODO: add more test to increase coverage to 0.9
42+
minimum = 0.3 //TODO: add more test to increase coverage to 0.9
4343
}
4444
}
4545
}

client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,7 @@
1212
import org.opensearch.action.search.SearchRequest;
1313
import org.opensearch.action.search.SearchResponse;
1414
import org.opensearch.action.support.PlainActionFuture;
15-
import org.opensearch.ml.common.parameter.Input;
16-
import org.opensearch.ml.common.parameter.MLInput;
17-
import org.opensearch.ml.common.parameter.MLOutput;
18-
import org.opensearch.ml.common.parameter.Output;
19-
import org.opensearch.ml.common.parameter.MLModel;
15+
import org.opensearch.ml.common.parameter.*;
2016

2117
/**
2218
* A client to provide interfaces for machine learning jobs. This will be used by other plugins.
@@ -118,6 +114,24 @@ default ActionFuture<MLModel> getModel(String modelId) {
118114
*/
119115
void getModel(String modelId, ActionListener<MLModel> listener);
120116

117+
/**
118+
* Get MLTask and return ActionFuture.
119+
* @param taskId id of the task
120+
* @return ActionFuture of ml task
121+
*/
122+
default ActionFuture<MLTask> getTask(String taskId) {
123+
PlainActionFuture<MLTask> actionFuture = PlainActionFuture.newFuture();
124+
getTask(taskId, actionFuture);
125+
return actionFuture;
126+
}
127+
128+
/**
129+
* Get MLTask and return task in listener
130+
* @param taskId id of the model
131+
* @param listener action listener
132+
*/
133+
void getTask(String taskId, ActionListener<MLTask> listener);
134+
121135
/**
122136
* Delete the model with modelId.
123137
* @param modelId ML model id
@@ -136,6 +150,24 @@ default ActionFuture<DeleteResponse> deleteModel(String modelId) {
136150
*/
137151
void deleteModel(String modelId, ActionListener<DeleteResponse> listener);
138152

153+
/**
154+
* Delete the task with taskId.
155+
* @param taskId ML task id
156+
* @return the result future
157+
*/
158+
default ActionFuture<DeleteResponse> deleteTask(String taskId) {
159+
PlainActionFuture<DeleteResponse> actionFuture = PlainActionFuture.newFuture();
160+
deleteModel(taskId, actionFuture);
161+
return actionFuture;
162+
}
163+
164+
/**
165+
* Delete MLTask
166+
* @param taskId id of the task
167+
* @param listener action listener
168+
*/
169+
void deleteTask(String taskId, ActionListener<DeleteResponse> listener);
170+
139171
/**
140172
*
141173
* @param searchRequest searchRequest to search the ML Model
@@ -146,10 +178,29 @@ default ActionFuture<SearchResponse> searchModel(SearchRequest searchRequest) {
146178
searchModel(searchRequest, actionFuture);
147179
return actionFuture;
148180
}
181+
149182
/**
150183
*
151184
* @param searchRequest searchRequest to search the ML Model
152185
* @param listener action listener
153186
*/
154187
void searchModel(SearchRequest searchRequest, ActionListener<SearchResponse> listener);
188+
189+
/**
190+
*
191+
* @param searchRequest searchRequest to search the ML Task
192+
* @return Action future of search response
193+
*/
194+
default ActionFuture<SearchResponse> searchTask(SearchRequest searchRequest) {
195+
PlainActionFuture<SearchResponse> actionFuture = PlainActionFuture.newFuture();
196+
searchTask(searchRequest, actionFuture);
197+
return actionFuture;
198+
}
199+
200+
/**
201+
*
202+
* @param searchRequest searchRequest to search the ML Task
203+
* @param listener action listener
204+
*/
205+
void searchTask(SearchRequest searchRequest, ActionListener<SearchResponse> listener);
155206
}

client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,7 @@
1414
import org.opensearch.action.search.SearchRequest;
1515
import org.opensearch.action.search.SearchResponse;
1616
import org.opensearch.client.node.NodeClient;
17-
import org.opensearch.ml.common.parameter.Input;
18-
import org.opensearch.ml.common.parameter.MLInput;
19-
import org.opensearch.ml.common.parameter.MLOutput;
20-
import org.opensearch.ml.common.parameter.Output;
21-
import org.opensearch.ml.common.parameter.MLModel;
17+
import org.opensearch.ml.common.parameter.*;
2218
import org.opensearch.ml.common.transport.MLTaskResponse;
2319
import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction;
2420
import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest;
@@ -31,6 +27,7 @@
3127
import org.opensearch.ml.common.transport.model.MLModelSearchAction;
3228
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
3329
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
30+
import org.opensearch.ml.common.transport.task.*;
3431
import org.opensearch.ml.common.transport.training.MLTrainingTaskAction;
3532
import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest;
3633
import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction;
@@ -117,6 +114,35 @@ public void searchModel(SearchRequest searchRequest, ActionListener<SearchRespon
117114
}, listener::onFailure));
118115
}
119116

117+
@Override
118+
public void getTask(String taskId, ActionListener<MLTask> listener) {
119+
MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest.builder()
120+
.taskId(taskId)
121+
.build();
122+
123+
client.execute(MLTaskGetAction.INSTANCE, mlTaskGetRequest, ActionListener.wrap(response -> {
124+
listener.onResponse(MLTaskGetResponse.fromActionResponse(response).getMlTask());
125+
}, listener::onFailure));
126+
}
127+
128+
@Override
129+
public void deleteTask(String taskId, ActionListener<DeleteResponse> listener) {
130+
MLTaskDeleteRequest mlTaskDeleteRequest = MLTaskDeleteRequest.builder()
131+
.taskId(taskId)
132+
.build();
133+
134+
client.execute(MLTaskDeleteAction.INSTANCE, mlTaskDeleteRequest, ActionListener.wrap(deleteResponse -> {
135+
listener.onResponse(deleteResponse);
136+
}, listener::onFailure));
137+
}
138+
139+
@Override
140+
public void searchTask(SearchRequest searchRequest, ActionListener<SearchResponse> listener) {
141+
client.execute(MLTaskSearchAction.INSTANCE, searchRequest, ActionListener.wrap(searchResponse -> {
142+
listener.onResponse(searchResponse);
143+
}, listener::onFailure));
144+
}
145+
120146
private ActionListener<MLTaskResponse> getMlPredictionTaskResponseActionListener(ActionListener<MLOutput> listener) {
121147
ActionListener<MLTaskResponse> internalListener = ActionListener.wrap(predictionResponse -> {
122148
listener.onResponse(predictionResponse.getOutput());

client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,7 @@
1919
import org.opensearch.common.xcontent.XContentBuilder;
2020
import org.opensearch.ml.common.dataframe.DataFrame;
2121
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
22-
import org.opensearch.ml.common.parameter.Input;
23-
import org.opensearch.ml.common.parameter.FunctionName;
24-
import org.opensearch.ml.common.parameter.MLAlgoParams;
25-
import org.opensearch.ml.common.parameter.MLInput;
26-
import org.opensearch.ml.common.parameter.MLOutput;
27-
import org.opensearch.ml.common.parameter.MLTrainingOutput;
28-
import org.opensearch.ml.common.parameter.Output;
29-
import org.opensearch.ml.common.parameter.MLModel;
22+
import org.opensearch.ml.common.parameter.*;
3023

3124
import java.io.IOException;
3225

@@ -111,6 +104,21 @@ public void deleteModel(String modelId, ActionListener<DeleteResponse> listener)
111104
public void searchModel(SearchRequest searchRequest, ActionListener<SearchResponse> listener) {
112105
listener.onResponse(searchResponse);
113106
}
107+
108+
@Override
109+
public void getTask(String taskId, ActionListener<MLTask> listener) {
110+
listener.onResponse(MLTask.builder().build());
111+
}
112+
113+
@Override
114+
public void deleteTask(String taskId, ActionListener<DeleteResponse> listener) {
115+
listener.onResponse(deleteResponse);
116+
}
117+
118+
@Override
119+
public void searchTask(SearchRequest searchRequest, ActionListener<SearchResponse> listener) {
120+
listener.onResponse(searchResponse);
121+
}
114122
};
115123
}
116124

common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskGetResponse.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package org.opensearch.ml.common.transport.task;
77

88
import lombok.Builder;
9+
import lombok.Getter;
910
import org.opensearch.action.ActionResponse;
1011
import org.opensearch.common.io.stream.InputStreamStreamInput;
1112
import org.opensearch.common.io.stream.OutputStreamStreamOutput;
@@ -20,6 +21,7 @@
2021
import java.io.IOException;
2122
import java.io.UncheckedIOException;
2223

24+
@Getter
2325
public class MLTaskGetResponse extends ActionResponse implements ToXContentObject {
2426
MLTask mlTask;
2527

0 commit comments

Comments
 (0)