Skip to content

Commit 5eb38f7

Browse files
authored
Add UT for ml rest execute action (#278)
Signed-off-by: Xun Zhang <[email protected]>
1 parent 0c6249a commit 5eb38f7

File tree

3 files changed

+122
-2
lines changed

3 files changed

+122
-2
lines changed

plugin/build.gradle

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,7 @@ jacocoTestReport {
207207
List<String> jacocoExclusions = [
208208
// TODO: add more unit test to meet the minimal test coverage.
209209
'org.opensearch.ml.constant.CommonValue',
210-
'org.opensearch.ml.plugin.MachineLearningPlugin*',
211-
'org.opensearch.ml.rest.RestMLExecuteAction' //0.3
210+
'org.opensearch.ml.plugin.MachineLearningPlugin*'
212211
]
213212

214213
jacocoTestCoverageVerification {
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.rest;
7+
8+
import static org.mockito.ArgumentMatchers.any;
9+
import static org.mockito.ArgumentMatchers.eq;
10+
import static org.mockito.Mockito.doAnswer;
11+
import static org.mockito.Mockito.spy;
12+
import static org.mockito.Mockito.times;
13+
import static org.mockito.Mockito.verify;
14+
import static org.opensearch.ml.utils.TestHelper.getLocalSampleCalculatorRestRequest;
15+
16+
import java.io.IOException;
17+
import java.util.List;
18+
19+
import org.junit.Before;
20+
import org.mockito.ArgumentCaptor;
21+
import org.mockito.Mock;
22+
import org.mockito.MockitoAnnotations;
23+
import org.opensearch.action.ActionListener;
24+
import org.opensearch.client.node.NodeClient;
25+
import org.opensearch.common.Strings;
26+
import org.opensearch.common.settings.Settings;
27+
import org.opensearch.ml.common.FunctionName;
28+
import org.opensearch.ml.common.input.Input;
29+
import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction;
30+
import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest;
31+
import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse;
32+
import org.opensearch.rest.RestChannel;
33+
import org.opensearch.rest.RestHandler;
34+
import org.opensearch.rest.RestRequest;
35+
import org.opensearch.test.OpenSearchTestCase;
36+
import org.opensearch.threadpool.TestThreadPool;
37+
import org.opensearch.threadpool.ThreadPool;
38+
39+
public class RestMLExecuteActionTests extends OpenSearchTestCase {
40+
41+
private RestMLExecuteAction restMLExecuteAction;
42+
43+
NodeClient client;
44+
private ThreadPool threadPool;
45+
46+
@Mock
47+
RestChannel channel;
48+
49+
@Before
50+
public void setup() {
51+
MockitoAnnotations.openMocks(this);
52+
restMLExecuteAction = new RestMLExecuteAction();
53+
54+
threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool");
55+
client = spy(new NodeClient(Settings.EMPTY, threadPool));
56+
57+
doAnswer(invocation -> {
58+
ActionListener<MLExecuteTaskResponse> actionListener = invocation.getArgument(2);
59+
return null;
60+
}).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any());
61+
}
62+
63+
@Override
64+
public void tearDown() throws Exception {
65+
super.tearDown();
66+
threadPool.shutdown();
67+
client.close();
68+
}
69+
70+
public void testConstructor() {
71+
RestMLExecuteAction restMLExecuteAction = new RestMLExecuteAction();
72+
assertNotNull(restMLExecuteAction);
73+
}
74+
75+
public void testGetName() {
76+
String actionName = restMLExecuteAction.getName();
77+
assertFalse(Strings.isNullOrEmpty(actionName));
78+
assertEquals("ml_execute_action", actionName);
79+
}
80+
81+
public void testRoutes() {
82+
List<RestHandler.Route> routes = restMLExecuteAction.routes();
83+
assertNotNull(routes);
84+
assertFalse(routes.isEmpty());
85+
RestHandler.Route route = routes.get(0);
86+
assertEquals(RestRequest.Method.POST, route.getMethod());
87+
assertEquals("/_plugins/_ml/_execute/{algorithm}", route.getPath());
88+
}
89+
90+
public void testGetRequest() throws IOException {
91+
RestRequest request = getLocalSampleCalculatorRestRequest();
92+
MLExecuteTaskRequest executeTaskRequest = restMLExecuteAction.getRequest(request);
93+
94+
Input input = executeTaskRequest.getInput();
95+
assertNotNull(input);
96+
assertEquals(FunctionName.LOCAL_SAMPLE_CALCULATOR, input.getFunctionName());
97+
}
98+
99+
public void testPrepareRequest() throws Exception {
100+
RestRequest request = getLocalSampleCalculatorRestRequest();
101+
restMLExecuteAction.handleRequest(request, channel, client);
102+
103+
ArgumentCaptor<MLExecuteTaskRequest> argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class);
104+
verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any());
105+
Input input = argumentCaptor.getValue().getInput();
106+
assertEquals(FunctionName.LOCAL_SAMPLE_CALCULATOR, input.getFunctionName());
107+
}
108+
}

plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
import org.opensearch.ml.common.dataset.MLInputDataType;
4646
import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
4747
import org.opensearch.ml.common.input.MLInput;
48+
import org.opensearch.ml.common.input.execute.samplecalculator.LocalSampleCalculatorInput;
4849
import org.opensearch.ml.common.input.parameter.clustering.KMeansParams;
4950
import org.opensearch.rest.RestRequest;
5051
import org.opensearch.rest.RestStatus;
@@ -164,6 +165,17 @@ public static RestRequest getKMeansRestRequest() {
164165
return request;
165166
}
166167

168+
public static RestRequest getLocalSampleCalculatorRestRequest() {
169+
Map<String, String> params = new HashMap<>();
170+
params.put(PARAMETER_ALGORITHM, FunctionName.LOCAL_SAMPLE_CALCULATOR.name());
171+
final String requestContent = "{\"operation\": \"max\",\"input_data\":[1.0, 2.0, 3.0]}";
172+
RestRequest request = new FakeRestRequest.Builder(getXContentRegistry())
173+
.withParams(params)
174+
.withContent(new BytesArray(requestContent), XContentType.JSON)
175+
.build();
176+
return request;
177+
}
178+
167179
public static RestRequest getSearchAllRestRequest() {
168180
RestRequest request = new FakeRestRequest.Builder(getXContentRegistry())
169181
.withContent(new BytesArray(TestData.matchAllSearchQuery()), XContentType.JSON)
@@ -186,6 +198,7 @@ private static NamedXContentRegistry getXContentRegistry() {
186198
List<NamedXContentRegistry.Entry> entries = new ArrayList<>();
187199
entries.addAll(searchModule.getNamedXContents());
188200
entries.add(KMeansParams.XCONTENT_REGISTRY);
201+
entries.add(LocalSampleCalculatorInput.XCONTENT_REGISTRY);
189202
return new NamedXContentRegistry(entries);
190203
}
191204
}

0 commit comments

Comments
 (0)