Skip to content

Commit c1728a1

Browse files
authored
add more UT for task manager/runner (#206) (#207)
Signed-off-by: Yaliang Wu <[email protected]>
1 parent 8c544c7 commit c1728a1

File tree

7 files changed

+846
-13
lines changed

7 files changed

+846
-13
lines changed

common/src/main/java/org/opensearch/ml/common/parameter/MLTask.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ public class MLTask implements ToXContentObject, Writeable {
6161
private User user; // TODO: support document level access control later
6262
private boolean async;
6363

64-
@Builder
64+
@Builder(toBuilder = true)
6565
public MLTask(
6666
String taskId,
6767
String modelId,

plugin/build.gradle

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -209,18 +209,20 @@ List<String> jacocoExclusions = [
209209
'org.opensearch.ml.indices.MLInputDatasetHandler',
210210
'org.opensearch.ml.plugin.*',
211211
'org.opensearch.ml.task.MLTaskDispatcher',
212-
'org.opensearch.ml.task.MLTaskRunner',
213-
'org.opensearch.ml.task.MLTrainingTaskRunner',
214212
'org.opensearch.ml.task.MLPredictTaskRunner',
215213
'org.opensearch.ml.rest.RestMLTrainingAction',
216214
'org.opensearch.ml.rest.RestMLPredictionAction',
217215
'org.opensearch.ml.utils.RestActionUtils',
218216
'org.opensearch.ml.task.MLTaskCache',
219-
'org.opensearch.ml.task.MLTaskManager',
220-
'org.opensearch.ml.task.MLTrainAndPredictTaskRunner',
221-
'org.opensearch.ml.rest.RestMLGetModelAction',
222-
'org.opensearch.ml.rest.RestMLDeleteModelAction',
223-
'org.opensearch.ml.rest.*'
217+
'org.opensearch.ml.rest.AbstractMLSearchAction*',
218+
'org.opensearch.ml.utils.MLNodeUtils', //0.5
219+
'org.opensearch.ml.task.MLExecuteTaskRunner', //0.5
220+
'org.opensearch.ml.rest.RestMLDeleteTaskAction', //0.5
221+
'org.opensearch.ml.rest.RestMLGetModelAction', //0.5
222+
'org.opensearch.ml.rest.RestMLExecuteAction', //0.3
223+
'org.opensearch.ml.rest.RestMLDeleteModelAction', //0.5
224+
'org.opensearch.ml.rest.RestMLTrainAndPredictAction', //0.3
225+
'org.opensearch.ml.rest.RestMLGetTaskAction' //0.5
224226
]
225227

226228
jacocoTestCoverageVerification {
@@ -239,7 +241,7 @@ jacocoTestCoverageVerification {
239241
limit {
240242
counter = 'LINE'
241243
value = 'COVEREDRATIO'
242-
minimum = 0.3 //TODO: add more test to meet the coverage bar 0.7
244+
minimum = 0.7
243245
}
244246
}
245247
}

plugin/src/test/java/org/opensearch/ml/task/MLTaskManagerTests.java

Lines changed: 133 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,29 @@
1717
import org.junit.rules.ExpectedException;
1818
import org.mockito.ArgumentCaptor;
1919
import org.opensearch.action.ActionListener;
20+
import org.opensearch.action.DocWriteResponse;
21+
import org.opensearch.action.update.UpdateRequest;
2022
import org.opensearch.action.update.UpdateResponse;
2123
import org.opensearch.client.Client;
24+
import org.opensearch.common.settings.Settings;
25+
import org.opensearch.common.util.concurrent.ThreadContext;
26+
import org.opensearch.index.Index;
27+
import org.opensearch.index.shard.ShardId;
2228
import org.opensearch.ml.common.parameter.MLTask;
2329
import org.opensearch.ml.common.parameter.MLTaskState;
2430
import org.opensearch.ml.common.parameter.MLTaskType;
2531
import org.opensearch.ml.indices.MLIndicesHandler;
2632
import org.opensearch.test.OpenSearchTestCase;
33+
import org.opensearch.threadpool.ThreadPool;
34+
35+
import com.google.common.collect.ImmutableMap;
2736

2837
public class MLTaskManagerTests extends OpenSearchTestCase {
2938
MLTaskManager mlTaskManager;
3039
MLTask mlTask;
3140
Client client;
41+
ThreadPool threadPool;
42+
ThreadContext threadContext;
3243
MLIndicesHandler mlIndicesHandler;
3344

3445
@Rule
@@ -37,6 +48,12 @@ public class MLTaskManagerTests extends OpenSearchTestCase {
3748
@Before
3849
public void setup() {
3950
this.client = mock(Client.class);
51+
this.threadPool = mock(ThreadPool.class);
52+
Settings settings = Settings.builder().build();
53+
threadContext = new ThreadContext(settings);
54+
when(client.threadPool()).thenReturn(threadPool);
55+
when(threadPool.getThreadContext()).thenReturn(threadContext);
56+
4057
this.mlIndicesHandler = mock(MLIndicesHandler.class);
4158
this.mlTaskManager = spy(new MLTaskManager(client, mlIndicesHandler));
4259
this.mlTask = MLTask
@@ -89,34 +106,126 @@ public void testUpdateTaskStateAndError() {
89106
Assert.assertEquals(0, value.longValue());
90107
}
91108

109+
public void testUpdateTaskStateAndError_SyncTask() {
110+
mlTaskManager.add(mlTask);
111+
mlTaskManager.updateTaskStateAndError(mlTask.getTaskId(), MLTaskState.FAILED, "test error", false);
112+
verify(mlTaskManager, never()).updateMLTask(eq(mlTask.getTaskId()), any(), anyLong());
113+
}
114+
92115
public void testUpdateMLTaskWithNullOrEmptyMap() {
93116
mlTaskManager.add(mlTask);
94117
ActionListener<UpdateResponse> listener = mock(ActionListener.class);
95118
mlTaskManager.updateMLTask(mlTask.getTaskId(), null, listener, 0);
96-
verify(client, never()).index(any());
119+
verify(client, never()).update(any(), any());
97120
verify(listener, times(1)).onFailure(any());
98121

99122
mlTaskManager.updateMLTask(mlTask.getTaskId(), new HashMap<>(), listener, 0);
100-
verify(client, never()).index(any());
123+
verify(client, never()).update(any(), any());
101124
verify(listener, times(2)).onFailure(any());
102125
}
103126

127+
public void testUpdateMLTask_NonExistingTask() {
128+
ActionListener<UpdateResponse> listener = mock(ActionListener.class);
129+
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
130+
mlTaskManager.updateMLTask(mlTask.getTaskId(), null, listener, 0);
131+
verify(client, never()).update(any(), any());
132+
verify(listener, times(1)).onFailure(argumentCaptor.capture());
133+
assertEquals("Can't find task", argumentCaptor.getValue().getMessage());
134+
}
135+
136+
public void testUpdateMLTask_NoSemaphore() {
137+
MLTask asyncMlTask = mlTask.toBuilder().async(true).build();
138+
mlTaskManager.add(asyncMlTask);
139+
140+
doAnswer(invocation -> {
141+
ActionListener<UpdateResponse> actionListener = invocation.getArgument(1);
142+
ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1);
143+
UpdateResponse output = new UpdateResponse(shardId, "_doc", "taskId", 1, 1, 1, DocWriteResponse.Result.CREATED);
144+
actionListener.onResponse(output);
145+
return null;
146+
}).when(client).update(any(UpdateRequest.class), any());
147+
148+
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
149+
mlTaskManager.updateMLTask(asyncMlTask.getTaskId(), ImmutableMap.of(MLTask.ERROR_FIELD, "test error"), ActionListener.wrap(r -> {
150+
ActionListener<UpdateResponse> listener = mock(ActionListener.class);
151+
mlTaskManager.updateMLTask(asyncMlTask.getTaskId(), null, listener, 0);
152+
verify(client, times(1)).update(any(), any());
153+
verify(listener, times(1)).onFailure(argumentCaptor.capture());
154+
assertEquals("Other updating request not finished yet", argumentCaptor.getValue().getMessage());
155+
}, e -> { assertNull(e); }), 0);
156+
}
157+
158+
public void testUpdateMLTask_FailedToUpdate() {
159+
MLTask asyncMlTask = mlTask.toBuilder().async(true).build();
160+
mlTaskManager.add(asyncMlTask);
161+
162+
String errorMessage = "test error message";
163+
doAnswer(invocation -> {
164+
ActionListener<UpdateResponse> actionListener = invocation.getArgument(1);
165+
actionListener.onFailure(new RuntimeException(errorMessage));
166+
return null;
167+
}).when(client).update(any(UpdateRequest.class), any());
168+
169+
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
170+
ActionListener<UpdateResponse> listener = mock(ActionListener.class);
171+
mlTaskManager.updateMLTask(asyncMlTask.getTaskId(), ImmutableMap.of(MLTask.ERROR_FIELD, "test error"), listener, 0);
172+
verify(client, times(1)).update(any(), any());
173+
verify(listener, times(1)).onFailure(argumentCaptor.capture());
174+
assertEquals(errorMessage, argumentCaptor.getValue().getMessage());
175+
}
176+
177+
public void testUpdateMLTask_ThrowException() {
178+
MLTask asyncMlTask = mlTask.toBuilder().async(true).build();
179+
mlTaskManager.add(asyncMlTask);
180+
181+
String errorMessage = "test error message";
182+
doThrow(new RuntimeException(errorMessage)).when(client).update(any(UpdateRequest.class), any());
183+
184+
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
185+
ActionListener<UpdateResponse> listener = mock(ActionListener.class);
186+
mlTaskManager.updateMLTask(asyncMlTask.getTaskId(), ImmutableMap.of(MLTask.ERROR_FIELD, "test error"), listener, 0);
187+
verify(client, times(1)).update(any(), any());
188+
verify(listener, times(1)).onFailure(argumentCaptor.capture());
189+
assertEquals(errorMessage, argumentCaptor.getValue().getMessage());
190+
}
191+
104192
public void testRemove() {
105193
mlTaskManager.add(mlTask);
106194
Assert.assertTrue(mlTaskManager.contains(mlTask.getTaskId()));
107195
mlTaskManager.remove(mlTask.getTaskId());
108196
Assert.assertFalse(mlTaskManager.contains(mlTask.getTaskId()));
109197
}
110198

199+
public void testRemove_NonExistingTask() {
200+
Assert.assertFalse(mlTaskManager.contains(mlTask.getTaskId()));
201+
mlTaskManager.remove(mlTask.getTaskId());
202+
Assert.assertFalse(mlTaskManager.contains(mlTask.getTaskId()));
203+
}
204+
205+
public void testGetTask() {
206+
mlTaskManager.add(mlTask);
207+
Assert.assertTrue(mlTaskManager.contains(mlTask.getTaskId()));
208+
MLTask task = mlTaskManager.get(this.mlTask.getTaskId());
209+
Assert.assertEquals(mlTask, task);
210+
}
211+
212+
public void testGetTask_NonExisting() {
213+
Assert.assertFalse(mlTaskManager.contains(mlTask.getTaskId()));
214+
MLTask task = mlTaskManager.get(this.mlTask.getTaskId());
215+
Assert.assertNull(task);
216+
}
217+
111218
public void testGetRunningTaskCount() {
112219
MLTask task1 = MLTask.builder().taskId("1").state(MLTaskState.CREATED).build();
113220
MLTask task2 = MLTask.builder().taskId("2").state(MLTaskState.RUNNING).build();
114221
MLTask task3 = MLTask.builder().taskId("3").state(MLTaskState.FAILED).build();
115222
MLTask task4 = MLTask.builder().taskId("4").state(MLTaskState.COMPLETED).build();
223+
MLTask task5 = MLTask.builder().taskId("5").state(null).build();
116224
mlTaskManager.add(task1);
117225
mlTaskManager.add(task2);
118226
mlTaskManager.add(task3);
119227
mlTaskManager.add(task4);
228+
mlTaskManager.add(task5);
120229
Assert.assertEquals(mlTaskManager.getRunningTaskCount(), 1);
121230
}
122231

@@ -155,9 +264,29 @@ public void testCreateMlTask_IndexException() {
155264
return null;
156265
}).when(mlIndicesHandler).initMLTaskIndex(any(ActionListener.class));
157266

158-
doThrow(new RuntimeException("test")).when(client).index(any(), any());
267+
String errorMessage = "test error message";
268+
doThrow(new RuntimeException(errorMessage)).when(client).index(any(), any());
159269
ActionListener listener = mock(ActionListener.class);
270+
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
160271
mlTaskManager.createMLTask(mlTask, listener);
161-
verify(listener).onFailure(any());
272+
verify(listener).onFailure(argumentCaptor.capture());
273+
assertEquals(errorMessage, argumentCaptor.getValue().getMessage());
274+
}
275+
276+
public void testCreateMlTask_FailToGetThreadPool() {
277+
doAnswer(invocation -> {
278+
ActionListener<Boolean> listener = invocation.getArgument(0);
279+
listener.onResponse(true);
280+
return null;
281+
}).when(mlIndicesHandler).initMLTaskIndex(any(ActionListener.class));
282+
283+
String errorMessage = "test error message";
284+
doThrow(new RuntimeException(errorMessage)).when(threadPool).getThreadContext();
285+
ActionListener listener = mock(ActionListener.class);
286+
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
287+
mlTaskManager.createMLTask(mlTask, listener);
288+
verify(listener).onFailure(argumentCaptor.capture());
289+
assertEquals(errorMessage, argumentCaptor.getValue().getMessage());
162290
}
291+
163292
}

0 commit comments

Comments
 (0)