Skip to content

Commit 8c544c7

Browse files
add more UT to client module (#203) (#205)
Signed-off-by: Yaliang Wu <[email protected]> (cherry picked from commit a4305a5) Co-authored-by: Yaliang Wu <[email protected]>
1 parent 2ba85af commit 8c544c7

File tree

5 files changed

+327
-57
lines changed

5 files changed

+327
-57
lines changed

client/build.gradle

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,11 @@ jacocoTestCoverageVerification {
3535
rule {
3636
limit {
3737
counter = 'LINE'
38-
minimum = 0.3 //TODO: add more test to increase coverage to 0.9
38+
minimum = 0.8
3939
}
4040
limit {
4141
counter = 'BRANCH'
42-
minimum = 0.3 //TODO: add more test to increase coverage to 0.9
42+
minimum = 0.8
4343
}
4444
}
4545
}

client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -78,24 +78,6 @@ default ActionFuture<MLOutput> train(MLInput mlInput, boolean asyncTask) {
7878
*/
7979
void train(MLInput mlInput, boolean asyncTask, ActionListener<MLOutput> listener);
8080

81-
/**
82-
* Execute function and return ActionFuture.
83-
* @param input input data
84-
* @return ActionFuture of output
85-
*/
86-
default ActionFuture<Output> execute(Input input) {
87-
PlainActionFuture<Output> actionFuture = PlainActionFuture.newFuture();
88-
execute(input, actionFuture);
89-
return actionFuture;
90-
}
91-
92-
/**
93-
* Execute function and return output in listener
94-
* @param input input data
95-
* @param listener action listener
96-
*/
97-
void execute(Input input, ActionListener<Output> listener);
98-
9981
/**
10082
* Get MLModel and return ActionFuture.
10183
* @param modelId id of the model

client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,6 @@
1616
import org.opensearch.client.node.NodeClient;
1717
import org.opensearch.ml.common.parameter.*;
1818
import org.opensearch.ml.common.transport.MLTaskResponse;
19-
import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction;
20-
import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest;
21-
import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse;
2219
import org.opensearch.ml.common.transport.model.MLModelGetRequest;
2320
import org.opensearch.ml.common.transport.model.MLModelGetResponse;
2421
import org.opensearch.ml.common.transport.model.MLModelGetAction;
@@ -74,17 +71,6 @@ public void train(MLInput mlInput, boolean asyncTask, ActionListener<MLOutput> l
7471
client.execute(MLTrainingTaskAction.INSTANCE, trainingTaskRequest, getMlPredictionTaskResponseActionListener(listener));
7572
}
7673

77-
@Override
78-
public void execute(Input input, ActionListener<Output> listener) {
79-
MLExecuteTaskRequest executeTaskRequest = MLExecuteTaskRequest.builder()
80-
.input(input)
81-
.build();
82-
83-
client.execute(MLExecuteTaskAction.INSTANCE, executeTaskRequest, ActionListener.wrap(response -> {
84-
listener.onResponse(MLExecuteTaskResponse.fromActionResponse(response).getOutput());
85-
}, listener::onFailure));
86-
}
87-
8874
@Override
8975
public void getModel(String modelId, ActionListener<MLModel> listener) {
9076
MLModelGetRequest mlModelGetRequest = MLModelGetRequest.builder()

client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java

Lines changed: 61 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,26 @@ public class MachineLearningClientTest {
5050
SearchResponse searchResponse;
5151

5252
private String modekId = "test_model_id";
53+
private MLModel mlModel;
54+
private MLTask mlTask;
55+
5356
@Before
54-
public void setUp() throws Exception {
57+
public void setUp() {
5558
MockitoAnnotations.openMocks(this);
59+
String taskId = "taskId";
60+
String modelId = "modelId";
61+
mlTask = MLTask.builder()
62+
.taskId(taskId)
63+
.modelId(modelId)
64+
.functionName(FunctionName.KMEANS)
65+
.build();
66+
67+
String modelContent = "test content";
68+
mlModel = MLModel.builder()
69+
.algorithm(FunctionName.KMEANS)
70+
.name("test")
71+
.content(modelContent)
72+
.build();
5673

5774
machineLearningClient = new MachineLearningClient() {
5875
@Override
@@ -72,27 +89,9 @@ public void train(MLInput mlInput, boolean asyncTask, ActionListener<MLOutput> l
7289
listener.onResponse(MLTrainingOutput.builder().modelId(modekId).build());
7390
}
7491

75-
@Override
76-
public void execute(Input input, ActionListener<Output> listener) {
77-
listener.onResponse(new Output() {
78-
@Override
79-
public void writeTo(StreamOutput out) {
80-
81-
}
82-
83-
@Override
84-
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
85-
builder.startObject();
86-
builder.field("test", "test_value");
87-
builder.endObject();
88-
return builder;
89-
}
90-
});
91-
}
92-
9392
@Override
9493
public void getModel(String modelId, ActionListener<MLModel> listener) {
95-
listener.onResponse(MLModel.builder().build());
94+
listener.onResponse(mlModel);
9695
}
9796

9897
@Override
@@ -107,7 +106,7 @@ public void searchModel(SearchRequest searchRequest, ActionListener<SearchRespon
107106

108107
@Override
109108
public void getTask(String taskId, ActionListener<MLTask> listener) {
110-
listener.onResponse(MLTask.builder().build());
109+
listener.onResponse(mlTask);
111110
}
112111

113112
@Override
@@ -185,4 +184,44 @@ public void train() {
185184
.build();
186185
assertEquals(modekId, ((MLTrainingOutput)machineLearningClient.train(mlInput, false).actionGet()).getModelId());
187186
}
188-
}
187+
188+
@Test
189+
public void trainAndPredict() {
190+
MLInput mlInput = MLInput.builder()
191+
.algorithm(FunctionName.KMEANS)
192+
.parameters(mlParameters)
193+
.inputDataset(new DataFrameInputDataset(input))
194+
.build();
195+
assertEquals(output, machineLearningClient.trainAndPredict(mlInput).actionGet());
196+
}
197+
198+
@Test
199+
public void getModel() {
200+
assertEquals(mlModel, machineLearningClient.getModel("modelId").actionGet());
201+
}
202+
203+
@Test
204+
public void deleteModel() {
205+
assertEquals(deleteResponse, machineLearningClient.deleteModel("modelId").actionGet());
206+
}
207+
208+
@Test
209+
public void searchModel() {
210+
assertEquals(searchResponse, machineLearningClient.searchModel(new SearchRequest()).actionGet());
211+
}
212+
213+
@Test
214+
public void getTask() {
215+
assertEquals(mlTask, machineLearningClient.getTask("taskId").actionGet());
216+
}
217+
218+
@Test
219+
public void deleteTask() {
220+
assertEquals(deleteResponse, machineLearningClient.deleteTask("taskId").actionGet());
221+
}
222+
223+
@Test
224+
public void searchTask() {
225+
assertEquals(searchResponse, machineLearningClient.searchTask(new SearchRequest()).actionGet());
226+
}
227+
}

0 commit comments

Comments
 (0)