1313import static org .mockito .Mockito .doReturn ;
1414import static org .mockito .Mockito .doThrow ;
1515import static org .mockito .Mockito .mock ;
16+ import static org .mockito .Mockito .never ;
1617import static org .mockito .Mockito .spy ;
1718import static org .mockito .Mockito .verify ;
1819import static org .mockito .Mockito .when ;
20+ import static org .opensearch .ml .common .CommonValue .ML_MODEL_INDEX ;
1921import static org .opensearch .ml .task .MLPredictTaskRunnerTests .USER_STRING ;
2022
2123import java .io .IOException ;
2224import java .util .ArrayList ;
2325import java .util .List ;
26+ import java .util .Map ;
2427
2528import org .junit .Before ;
2629import org .junit .Rule ;
2932import org .mockito .Mock ;
3033import org .mockito .MockitoAnnotations ;
3134import org .opensearch .action .FailedNodeException ;
35+ import org .opensearch .action .bulk .BulkRequest ;
36+ import org .opensearch .action .bulk .BulkResponse ;
3237import org .opensearch .action .support .ActionFilters ;
38+ import org .opensearch .action .update .UpdateRequest ;
3339import org .opensearch .client .Client ;
3440import org .opensearch .cluster .ClusterName ;
3541import org .opensearch .cluster .service .ClusterService ;
4248import org .opensearch .ml .cluster .DiscoveryNodeHelper ;
4349import org .opensearch .ml .common .FunctionName ;
4450import org .opensearch .ml .common .MLModel ;
51+ import org .opensearch .ml .common .model .MLModelState ;
4552import org .opensearch .ml .common .transport .undeploy .MLUndeployModelNodeResponse ;
4653import org .opensearch .ml .common .transport .undeploy .MLUndeployModelNodesResponse ;
4754import org .opensearch .ml .common .transport .undeploy .MLUndeployModelsRequest ;
@@ -172,6 +179,129 @@ public void setup() throws IOException {
172179 }).when (mlModelManager ).getModel (any (), any (), any (), any (), isA (ActionListener .class ));
173180 }
174181
182+ public void testDoExecute_undeployModelIndex_WhenNoNodesServiceModel () {
183+ String modelId = "someModelId" ;
184+ MLModel mlModel = MLModel
185+ .builder ()
186+ .user (User .parse (USER_STRING ))
187+ .modelGroupId ("111" )
188+ .version ("111" )
189+ .name ("Test Model" )
190+ .modelId (modelId )
191+ .algorithm (FunctionName .BATCH_RCF )
192+ .content ("content" )
193+ .totalChunks (2 )
194+ .isHidden (true )
195+ .build ();
196+
197+ doAnswer (invocation -> {
198+ ActionListener <MLModel > listener = invocation .getArgument (4 );
199+ listener .onResponse (mlModel );
200+ return null ;
201+ }).when (mlModelManager ).getModel (any (), any (), any (), any (), isA (ActionListener .class ));
202+
203+ doReturn (true ).when (transportUndeployModelsAction ).isSuperAdminUserWrapper (clusterService , client );
204+
205+ List <MLUndeployModelNodeResponse > responseList = new ArrayList <>();
206+ List <FailedNodeException > failuresList = new ArrayList <>();
207+ MLUndeployModelNodesResponse nodesResponse = new MLUndeployModelNodesResponse (clusterName , responseList , failuresList );
208+
209+ // Send back a response with no nodes associated to the model. Thus, will write back to the model index that its UNDEPLOYED
210+ doAnswer (invocation -> {
211+ ActionListener <MLUndeployModelNodesResponse > listener = invocation .getArgument (2 );
212+ listener .onResponse (nodesResponse );
213+ return null ;
214+ }).when (client ).execute (any (), any (), isA (ActionListener .class ));
215+
216+ ArgumentCaptor <BulkRequest > bulkRequestCaptor = ArgumentCaptor .forClass (BulkRequest .class );
217+
218+ // mock the bulk response that can be captured for inspecting the contents of the write to index
219+ doAnswer (invocation -> {
220+ ActionListener <BulkResponse > listener = invocation .getArgument (1 );
221+ BulkResponse bulkResponse = mock (BulkResponse .class );
222+ when (bulkResponse .hasFailures ()).thenReturn (false );
223+ listener .onResponse (bulkResponse );
224+ return null ;
225+ }).when (client ).bulk (bulkRequestCaptor .capture (), any (ActionListener .class ));
226+
227+ String [] modelIds = new String [] { modelId };
228+ String [] nodeIds = new String [] { "test_node_id1" , "test_node_id2" };
229+ MLUndeployModelsRequest request = new MLUndeployModelsRequest (modelIds , nodeIds , null );
230+
231+ transportUndeployModelsAction .doExecute (task , request , actionListener );
232+
233+ BulkRequest capturedBulkRequest = bulkRequestCaptor .getValue ();
234+ assertEquals (1 , capturedBulkRequest .numberOfActions ());
235+ UpdateRequest updateRequest = (UpdateRequest ) capturedBulkRequest .requests ().get (0 );
236+
237+ @ SuppressWarnings ("unchecked" )
238+ Map <String , Object > updateDoc = updateRequest .doc ().sourceAsMap ();
239+ String modelIdFromBulkRequest = updateRequest .id ();
240+ String indexNameFromBulkRequest = updateRequest .index ();
241+
242+ assertEquals ("Check that the write happened at the model index" , ML_MODEL_INDEX , indexNameFromBulkRequest );
243+ assertEquals ("Check that the result bulk write hit this specific modelId" , modelId , modelIdFromBulkRequest );
244+
245+ assertEquals (MLModelState .UNDEPLOYED .name (), updateDoc .get (MLModel .MODEL_STATE_FIELD ));
246+ assertEquals (0 , updateDoc .get (MLModel .CURRENT_WORKER_NODE_COUNT_FIELD ));
247+ assertEquals (0 , updateDoc .get (MLModel .PLANNING_WORKER_NODE_COUNT_FIELD ));
248+ assertEquals (List .of (), updateDoc .get (MLModel .PLANNING_WORKER_NODES_FIELD ));
249+ assertTrue (updateDoc .containsKey (MLModel .LAST_UPDATED_TIME_FIELD ));
250+
251+ verify (actionListener ).onResponse (any (MLUndeployModelsResponse .class ));
252+ verify (client ).bulk (any (BulkRequest .class ), any (ActionListener .class ));
253+ }
254+
255+ public void testDoExecute_noBulkRequestFired_WhenSomeNodesServiceModel () {
256+ String modelId = "someModelId" ;
257+ MLModel mlModel = MLModel
258+ .builder ()
259+ .user (User .parse (USER_STRING ))
260+ .modelGroupId ("111" )
261+ .version ("111" )
262+ .name ("Test Model" )
263+ .modelId (modelId )
264+ .algorithm (FunctionName .BATCH_RCF )
265+ .content ("content" )
266+ .totalChunks (2 )
267+ .isHidden (true )
268+ .build ();
269+
270+ doAnswer (invocation -> {
271+ ActionListener <MLModel > listener = invocation .getArgument (4 );
272+ listener .onResponse (mlModel );
273+ return null ;
274+ }).when (mlModelManager ).getModel (any (), any (), any (), any (), isA (ActionListener .class ));
275+
276+ doReturn (true ).when (transportUndeployModelsAction ).isSuperAdminUserWrapper (clusterService , client );
277+
278+ List <MLUndeployModelNodeResponse > responseList = new ArrayList <>();
279+ responseList .add (mock (MLUndeployModelNodeResponse .class ));
280+ responseList .add (mock (MLUndeployModelNodeResponse .class ));
281+ List <FailedNodeException > failuresList = new ArrayList <>();
282+ failuresList .add (mock (FailedNodeException .class ));
283+ failuresList .add (mock (FailedNodeException .class ));
284+
285+ MLUndeployModelNodesResponse nodesResponse = new MLUndeployModelNodesResponse (clusterName , responseList , failuresList );
286+
287+ // Send back a response with nodes associated to the model
288+ doAnswer (invocation -> {
289+ ActionListener <MLUndeployModelNodesResponse > listener = invocation .getArgument (2 );
290+ listener .onResponse (nodesResponse );
291+ return null ;
292+ }).when (client ).execute (any (), any (), isA (ActionListener .class ));
293+
294+ String [] modelIds = new String [] { modelId };
295+ String [] nodeIds = new String [] { "test_node_id1" , "test_node_id2" };
296+ MLUndeployModelsRequest request = new MLUndeployModelsRequest (modelIds , nodeIds , null );
297+
298+ transportUndeployModelsAction .doExecute (task , request , actionListener );
299+
300+ verify (actionListener ).onResponse (any (MLUndeployModelsResponse .class ));
301+ // Check that no bulk write occurred Since there were nodes servicing the model
302+ verify (client , never ()).bulk (any (BulkRequest .class ), any (ActionListener .class ));
303+ }
304+
175305 public void testHiddenModelSuccess () {
176306 MLModel mlModel = MLModel
177307 .builder ()
@@ -194,16 +324,28 @@ public void testHiddenModelSuccess() {
194324 List <MLUndeployModelNodeResponse > responseList = new ArrayList <>();
195325 List <FailedNodeException > failuresList = new ArrayList <>();
196326 MLUndeployModelNodesResponse response = new MLUndeployModelNodesResponse (clusterName , responseList , failuresList );
327+
197328 doAnswer (invocation -> {
198329 ActionListener <MLUndeployModelNodesResponse > listener = invocation .getArgument (2 );
199330 listener .onResponse (response );
200331 return null ;
201332 }).when (client ).execute (any (), any (), isA (ActionListener .class ));
202333
334+ // Mock the client.bulk call
335+ doAnswer (invocation -> {
336+ ActionListener <BulkResponse > listener = invocation .getArgument (1 );
337+ BulkResponse bulkResponse = mock (BulkResponse .class );
338+ when (bulkResponse .hasFailures ()).thenReturn (false );
339+ listener .onResponse (bulkResponse );
340+ return null ;
341+ }).when (client ).bulk (any (BulkRequest .class ), any (ActionListener .class ));
342+
203343 doReturn (true ).when (transportUndeployModelsAction ).isSuperAdminUserWrapper (clusterService , client );
204344 MLUndeployModelsRequest request = new MLUndeployModelsRequest (modelIds , nodeIds , null );
205345 transportUndeployModelsAction .doExecute (task , request , actionListener );
346+
206347 verify (actionListener ).onResponse (any (MLUndeployModelsResponse .class ));
348+ verify (client ).bulk (any (BulkRequest .class ), any (ActionListener .class ));
207349 }
208350
209351 public void testHiddenModelPermissionError () {
@@ -257,9 +399,19 @@ public void testDoExecute() {
257399 listener .onResponse (response );
258400 return null ;
259401 }).when (client ).execute (any (), any (), isA (ActionListener .class ));
402+ // Mock the client.bulk call
403+ doAnswer (invocation -> {
404+ ActionListener <BulkResponse > listener = invocation .getArgument (1 );
405+ BulkResponse bulkResponse = mock (BulkResponse .class );
406+ when (bulkResponse .hasFailures ()).thenReturn (false );
407+ listener .onResponse (bulkResponse );
408+ return null ;
409+ }).when (client ).bulk (any (BulkRequest .class ), any (ActionListener .class ));
410+
260411 MLUndeployModelsRequest request = new MLUndeployModelsRequest (modelIds , nodeIds , null );
261412 transportUndeployModelsAction .doExecute (task , request , actionListener );
262413 verify (actionListener ).onResponse (any (MLUndeployModelsResponse .class ));
414+ verify (client ).bulk (any (BulkRequest .class ), any (ActionListener .class ));
263415 }
264416
265417 public void testDoExecute_modelAccessControl_notEnabled () {
0 commit comments