1717import org .junit .rules .ExpectedException ;
1818import org .mockito .ArgumentCaptor ;
1919import org .opensearch .action .ActionListener ;
20+ import org .opensearch .action .DocWriteResponse ;
21+ import org .opensearch .action .update .UpdateRequest ;
2022import org .opensearch .action .update .UpdateResponse ;
2123import org .opensearch .client .Client ;
24+ import org .opensearch .common .settings .Settings ;
25+ import org .opensearch .common .util .concurrent .ThreadContext ;
26+ import org .opensearch .index .Index ;
27+ import org .opensearch .index .shard .ShardId ;
2228import org .opensearch .ml .common .parameter .MLTask ;
2329import org .opensearch .ml .common .parameter .MLTaskState ;
2430import org .opensearch .ml .common .parameter .MLTaskType ;
2531import org .opensearch .ml .indices .MLIndicesHandler ;
2632import org .opensearch .test .OpenSearchTestCase ;
33+ import org .opensearch .threadpool .ThreadPool ;
34+
35+ import com .google .common .collect .ImmutableMap ;
2736
2837public class MLTaskManagerTests extends OpenSearchTestCase {
2938 MLTaskManager mlTaskManager ;
3039 MLTask mlTask ;
3140 Client client ;
41+ ThreadPool threadPool ;
42+ ThreadContext threadContext ;
3243 MLIndicesHandler mlIndicesHandler ;
3344
3445 @ Rule
@@ -37,6 +48,12 @@ public class MLTaskManagerTests extends OpenSearchTestCase {
3748 @ Before
3849 public void setup () {
3950 this .client = mock (Client .class );
51+ this .threadPool = mock (ThreadPool .class );
52+ Settings settings = Settings .builder ().build ();
53+ threadContext = new ThreadContext (settings );
54+ when (client .threadPool ()).thenReturn (threadPool );
55+ when (threadPool .getThreadContext ()).thenReturn (threadContext );
56+
4057 this .mlIndicesHandler = mock (MLIndicesHandler .class );
4158 this .mlTaskManager = spy (new MLTaskManager (client , mlIndicesHandler ));
4259 this .mlTask = MLTask
@@ -89,34 +106,126 @@ public void testUpdateTaskStateAndError() {
89106 Assert .assertEquals (0 , value .longValue ());
90107 }
91108
109+ public void testUpdateTaskStateAndError_SyncTask () {
110+ mlTaskManager .add (mlTask );
111+ mlTaskManager .updateTaskStateAndError (mlTask .getTaskId (), MLTaskState .FAILED , "test error" , false );
112+ verify (mlTaskManager , never ()).updateMLTask (eq (mlTask .getTaskId ()), any (), anyLong ());
113+ }
114+
92115 public void testUpdateMLTaskWithNullOrEmptyMap () {
93116 mlTaskManager .add (mlTask );
94117 ActionListener <UpdateResponse > listener = mock (ActionListener .class );
95118 mlTaskManager .updateMLTask (mlTask .getTaskId (), null , listener , 0 );
96- verify (client , never ()).index ( any ());
119+ verify (client , never ()).update ( any (), any ());
97120 verify (listener , times (1 )).onFailure (any ());
98121
99122 mlTaskManager .updateMLTask (mlTask .getTaskId (), new HashMap <>(), listener , 0 );
100- verify (client , never ()).index ( any ());
123+ verify (client , never ()).update ( any (), any ());
101124 verify (listener , times (2 )).onFailure (any ());
102125 }
103126
127+ public void testUpdateMLTask_NonExistingTask () {
128+ ActionListener <UpdateResponse > listener = mock (ActionListener .class );
129+ ArgumentCaptor <Exception > argumentCaptor = ArgumentCaptor .forClass (Exception .class );
130+ mlTaskManager .updateMLTask (mlTask .getTaskId (), null , listener , 0 );
131+ verify (client , never ()).update (any (), any ());
132+ verify (listener , times (1 )).onFailure (argumentCaptor .capture ());
133+ assertEquals ("Can't find task" , argumentCaptor .getValue ().getMessage ());
134+ }
135+
136+ public void testUpdateMLTask_NoSemaphore () {
137+ MLTask asyncMlTask = mlTask .toBuilder ().async (true ).build ();
138+ mlTaskManager .add (asyncMlTask );
139+
140+ doAnswer (invocation -> {
141+ ActionListener <UpdateResponse > actionListener = invocation .getArgument (1 );
142+ ShardId shardId = new ShardId (new Index ("indexName" , "uuid" ), 1 );
143+ UpdateResponse output = new UpdateResponse (shardId , "_doc" , "taskId" , 1 , 1 , 1 , DocWriteResponse .Result .CREATED );
144+ actionListener .onResponse (output );
145+ return null ;
146+ }).when (client ).update (any (UpdateRequest .class ), any ());
147+
148+ ArgumentCaptor <Exception > argumentCaptor = ArgumentCaptor .forClass (Exception .class );
149+ mlTaskManager .updateMLTask (asyncMlTask .getTaskId (), ImmutableMap .of (MLTask .ERROR_FIELD , "test error" ), ActionListener .wrap (r -> {
150+ ActionListener <UpdateResponse > listener = mock (ActionListener .class );
151+ mlTaskManager .updateMLTask (asyncMlTask .getTaskId (), null , listener , 0 );
152+ verify (client , times (1 )).update (any (), any ());
153+ verify (listener , times (1 )).onFailure (argumentCaptor .capture ());
154+ assertEquals ("Other updating request not finished yet" , argumentCaptor .getValue ().getMessage ());
155+ }, e -> { assertNull (e ); }), 0 );
156+ }
157+
158+ public void testUpdateMLTask_FailedToUpdate () {
159+ MLTask asyncMlTask = mlTask .toBuilder ().async (true ).build ();
160+ mlTaskManager .add (asyncMlTask );
161+
162+ String errorMessage = "test error message" ;
163+ doAnswer (invocation -> {
164+ ActionListener <UpdateResponse > actionListener = invocation .getArgument (1 );
165+ actionListener .onFailure (new RuntimeException (errorMessage ));
166+ return null ;
167+ }).when (client ).update (any (UpdateRequest .class ), any ());
168+
169+ ArgumentCaptor <Exception > argumentCaptor = ArgumentCaptor .forClass (Exception .class );
170+ ActionListener <UpdateResponse > listener = mock (ActionListener .class );
171+ mlTaskManager .updateMLTask (asyncMlTask .getTaskId (), ImmutableMap .of (MLTask .ERROR_FIELD , "test error" ), listener , 0 );
172+ verify (client , times (1 )).update (any (), any ());
173+ verify (listener , times (1 )).onFailure (argumentCaptor .capture ());
174+ assertEquals (errorMessage , argumentCaptor .getValue ().getMessage ());
175+ }
176+
177+ public void testUpdateMLTask_ThrowException () {
178+ MLTask asyncMlTask = mlTask .toBuilder ().async (true ).build ();
179+ mlTaskManager .add (asyncMlTask );
180+
181+ String errorMessage = "test error message" ;
182+ doThrow (new RuntimeException (errorMessage )).when (client ).update (any (UpdateRequest .class ), any ());
183+
184+ ArgumentCaptor <Exception > argumentCaptor = ArgumentCaptor .forClass (Exception .class );
185+ ActionListener <UpdateResponse > listener = mock (ActionListener .class );
186+ mlTaskManager .updateMLTask (asyncMlTask .getTaskId (), ImmutableMap .of (MLTask .ERROR_FIELD , "test error" ), listener , 0 );
187+ verify (client , times (1 )).update (any (), any ());
188+ verify (listener , times (1 )).onFailure (argumentCaptor .capture ());
189+ assertEquals (errorMessage , argumentCaptor .getValue ().getMessage ());
190+ }
191+
104192 public void testRemove () {
105193 mlTaskManager .add (mlTask );
106194 Assert .assertTrue (mlTaskManager .contains (mlTask .getTaskId ()));
107195 mlTaskManager .remove (mlTask .getTaskId ());
108196 Assert .assertFalse (mlTaskManager .contains (mlTask .getTaskId ()));
109197 }
110198
199+ public void testRemove_NonExistingTask () {
200+ Assert .assertFalse (mlTaskManager .contains (mlTask .getTaskId ()));
201+ mlTaskManager .remove (mlTask .getTaskId ());
202+ Assert .assertFalse (mlTaskManager .contains (mlTask .getTaskId ()));
203+ }
204+
205+ public void testGetTask () {
206+ mlTaskManager .add (mlTask );
207+ Assert .assertTrue (mlTaskManager .contains (mlTask .getTaskId ()));
208+ MLTask task = mlTaskManager .get (this .mlTask .getTaskId ());
209+ Assert .assertEquals (mlTask , task );
210+ }
211+
212+ public void testGetTask_NonExisting () {
213+ Assert .assertFalse (mlTaskManager .contains (mlTask .getTaskId ()));
214+ MLTask task = mlTaskManager .get (this .mlTask .getTaskId ());
215+ Assert .assertNull (task );
216+ }
217+
111218 public void testGetRunningTaskCount () {
112219 MLTask task1 = MLTask .builder ().taskId ("1" ).state (MLTaskState .CREATED ).build ();
113220 MLTask task2 = MLTask .builder ().taskId ("2" ).state (MLTaskState .RUNNING ).build ();
114221 MLTask task3 = MLTask .builder ().taskId ("3" ).state (MLTaskState .FAILED ).build ();
115222 MLTask task4 = MLTask .builder ().taskId ("4" ).state (MLTaskState .COMPLETED ).build ();
223+ MLTask task5 = MLTask .builder ().taskId ("5" ).state (null ).build ();
116224 mlTaskManager .add (task1 );
117225 mlTaskManager .add (task2 );
118226 mlTaskManager .add (task3 );
119227 mlTaskManager .add (task4 );
228+ mlTaskManager .add (task5 );
120229 Assert .assertEquals (mlTaskManager .getRunningTaskCount (), 1 );
121230 }
122231
@@ -155,9 +264,29 @@ public void testCreateMlTask_IndexException() {
155264 return null ;
156265 }).when (mlIndicesHandler ).initMLTaskIndex (any (ActionListener .class ));
157266
158- doThrow (new RuntimeException ("test" )).when (client ).index (any (), any ());
267+ String errorMessage = "test error message" ;
268+ doThrow (new RuntimeException (errorMessage )).when (client ).index (any (), any ());
159269 ActionListener listener = mock (ActionListener .class );
270+ ArgumentCaptor <Exception > argumentCaptor = ArgumentCaptor .forClass (Exception .class );
160271 mlTaskManager .createMLTask (mlTask , listener );
161- verify (listener ).onFailure (any ());
272+ verify (listener ).onFailure (argumentCaptor .capture ());
273+ assertEquals (errorMessage , argumentCaptor .getValue ().getMessage ());
274+ }
275+
276+ public void testCreateMlTask_FailToGetThreadPool () {
277+ doAnswer (invocation -> {
278+ ActionListener <Boolean > listener = invocation .getArgument (0 );
279+ listener .onResponse (true );
280+ return null ;
281+ }).when (mlIndicesHandler ).initMLTaskIndex (any (ActionListener .class ));
282+
283+ String errorMessage = "test error message" ;
284+ doThrow (new RuntimeException (errorMessage )).when (threadPool ).getThreadContext ();
285+ ActionListener listener = mock (ActionListener .class );
286+ ArgumentCaptor <Exception > argumentCaptor = ArgumentCaptor .forClass (Exception .class );
287+ mlTaskManager .createMLTask (mlTask , listener );
288+ verify (listener ).onFailure (argumentCaptor .capture ());
289+ assertEquals (errorMessage , argumentCaptor .getValue ().getMessage ());
162290 }
291+
163292}
0 commit comments