1313import static org .mockito .Mockito .doReturn ;
1414import static org .mockito .Mockito .doThrow ;
1515import static org .mockito .Mockito .mock ;
16+ import static org .mockito .Mockito .never ;
1617import static org .mockito .Mockito .spy ;
1718import static org .mockito .Mockito .verify ;
1819import static org .mockito .Mockito .when ;
20+ import static org .opensearch .ml .common .CommonValue .ML_MODEL_INDEX ;
1921import static org .opensearch .ml .task .MLPredictTaskRunnerTests .USER_STRING ;
2022
2123import java .io .IOException ;
2224import java .util .ArrayList ;
2325import java .util .List ;
26+ import java .util .Map ;
2427
2528import org .junit .Before ;
2629import org .junit .Rule ;
2932import org .mockito .Mock ;
3033import org .mockito .MockitoAnnotations ;
3134import org .opensearch .action .FailedNodeException ;
35+ import org .opensearch .action .bulk .BulkRequest ;
36+ import org .opensearch .action .bulk .BulkResponse ;
3237import org .opensearch .action .support .ActionFilters ;
38+ import org .opensearch .action .update .UpdateRequest ;
3339import org .opensearch .client .Client ;
3440import org .opensearch .cluster .ClusterName ;
3541import org .opensearch .cluster .service .ClusterService ;
4248import org .opensearch .ml .cluster .DiscoveryNodeHelper ;
4349import org .opensearch .ml .common .FunctionName ;
4450import org .opensearch .ml .common .MLModel ;
51+ import org .opensearch .ml .common .model .MLModelState ;
4552import org .opensearch .ml .common .transport .undeploy .MLUndeployModelNodeResponse ;
4653import org .opensearch .ml .common .transport .undeploy .MLUndeployModelNodesResponse ;
4754import org .opensearch .ml .common .transport .undeploy .MLUndeployModelsRequest ;
@@ -164,6 +171,129 @@ public void setup() throws IOException {
164171 }).when (mlModelManager ).getModel (any (), any (), any (), isA (ActionListener .class ));
165172 }
166173
174+ public void testDoExecute_undeployModelIndex_WhenNoNodesServiceModel () {
175+ String modelId = "someModelId" ;
176+ MLModel mlModel = MLModel
177+ .builder ()
178+ .user (User .parse (USER_STRING ))
179+ .modelGroupId ("111" )
180+ .version ("111" )
181+ .name ("Test Model" )
182+ .modelId (modelId )
183+ .algorithm (FunctionName .BATCH_RCF )
184+ .content ("content" )
185+ .totalChunks (2 )
186+ .isHidden (true )
187+ .build ();
188+
189+ doAnswer (invocation -> {
190+ ActionListener <MLModel > listener = invocation .getArgument (3 );
191+ listener .onResponse (mlModel );
192+ return null ;
193+ }).when (mlModelManager ).getModel (any (), any (), any (), isA (ActionListener .class ));
194+
195+ doReturn (true ).when (transportUndeployModelsAction ).isSuperAdminUserWrapper (clusterService , client );
196+
197+ List <MLUndeployModelNodeResponse > responseList = new ArrayList <>();
198+ List <FailedNodeException > failuresList = new ArrayList <>();
199+ MLUndeployModelNodesResponse nodesResponse = new MLUndeployModelNodesResponse (clusterName , responseList , failuresList );
200+
201+ // Send back a response with no nodes associated to the model. Thus, will write back to the model index that its UNDEPLOYED
202+ doAnswer (invocation -> {
203+ ActionListener <MLUndeployModelNodesResponse > listener = invocation .getArgument (2 );
204+ listener .onResponse (nodesResponse );
205+ return null ;
206+ }).when (client ).execute (any (), any (), isA (ActionListener .class ));
207+
208+ ArgumentCaptor <BulkRequest > bulkRequestCaptor = ArgumentCaptor .forClass (BulkRequest .class );
209+
210+ // mock the bulk response that can be captured for inspecting the contents of the write to index
211+ doAnswer (invocation -> {
212+ ActionListener <BulkResponse > listener = invocation .getArgument (1 );
213+ BulkResponse bulkResponse = mock (BulkResponse .class );
214+ when (bulkResponse .hasFailures ()).thenReturn (false );
215+ listener .onResponse (bulkResponse );
216+ return null ;
217+ }).when (client ).bulk (bulkRequestCaptor .capture (), any (ActionListener .class ));
218+
219+ String [] modelIds = new String [] { modelId };
220+ String [] nodeIds = new String [] { "test_node_id1" , "test_node_id2" };
221+ MLUndeployModelsRequest request = new MLUndeployModelsRequest (modelIds , nodeIds );
222+
223+ transportUndeployModelsAction .doExecute (task , request , actionListener );
224+
225+ BulkRequest capturedBulkRequest = bulkRequestCaptor .getValue ();
226+ assertEquals (1 , capturedBulkRequest .numberOfActions ());
227+ UpdateRequest updateRequest = (UpdateRequest ) capturedBulkRequest .requests ().get (0 );
228+
229+ @ SuppressWarnings ("unchecked" )
230+ Map <String , Object > updateDoc = updateRequest .doc ().sourceAsMap ();
231+ String modelIdFromBulkRequest = updateRequest .id ();
232+ String indexNameFromBulkRequest = updateRequest .index ();
233+
234+ assertEquals ("Check that the write happened at the model index" , ML_MODEL_INDEX , indexNameFromBulkRequest );
235+ assertEquals ("Check that the result bulk write hit this specific modelId" , modelId , modelIdFromBulkRequest );
236+
237+ assertEquals (MLModelState .UNDEPLOYED .name (), updateDoc .get (MLModel .MODEL_STATE_FIELD ));
238+ assertEquals (0 , updateDoc .get (MLModel .CURRENT_WORKER_NODE_COUNT_FIELD ));
239+ assertEquals (0 , updateDoc .get (MLModel .PLANNING_WORKER_NODE_COUNT_FIELD ));
240+ assertEquals (List .of (), updateDoc .get (MLModel .PLANNING_WORKER_NODES_FIELD ));
241+ assertTrue (updateDoc .containsKey (MLModel .LAST_UPDATED_TIME_FIELD ));
242+
243+ verify (actionListener ).onResponse (any (MLUndeployModelsResponse .class ));
244+ verify (client ).bulk (any (BulkRequest .class ), any (ActionListener .class ));
245+ }
246+
247+ public void testDoExecute_noBulkRequestFired_WhenSomeNodesServiceModel () {
248+ String modelId = "someModelId" ;
249+ MLModel mlModel = MLModel
250+ .builder ()
251+ .user (User .parse (USER_STRING ))
252+ .modelGroupId ("111" )
253+ .version ("111" )
254+ .name ("Test Model" )
255+ .modelId (modelId )
256+ .algorithm (FunctionName .BATCH_RCF )
257+ .content ("content" )
258+ .totalChunks (2 )
259+ .isHidden (true )
260+ .build ();
261+
262+ doAnswer (invocation -> {
263+ ActionListener <MLModel > listener = invocation .getArgument (3 );
264+ listener .onResponse (mlModel );
265+ return null ;
266+ }).when (mlModelManager ).getModel (any (), any (), any (), isA (ActionListener .class ));
267+
268+ doReturn (true ).when (transportUndeployModelsAction ).isSuperAdminUserWrapper (clusterService , client );
269+
270+ List <MLUndeployModelNodeResponse > responseList = new ArrayList <>();
271+ responseList .add (mock (MLUndeployModelNodeResponse .class ));
272+ responseList .add (mock (MLUndeployModelNodeResponse .class ));
273+ List <FailedNodeException > failuresList = new ArrayList <>();
274+ failuresList .add (mock (FailedNodeException .class ));
275+ failuresList .add (mock (FailedNodeException .class ));
276+
277+ MLUndeployModelNodesResponse nodesResponse = new MLUndeployModelNodesResponse (clusterName , responseList , failuresList );
278+
279+ // Send back a response with nodes associated to the model
280+ doAnswer (invocation -> {
281+ ActionListener <MLUndeployModelNodesResponse > listener = invocation .getArgument (2 );
282+ listener .onResponse (nodesResponse );
283+ return null ;
284+ }).when (client ).execute (any (), any (), isA (ActionListener .class ));
285+
286+ String [] modelIds = new String [] { modelId };
287+ String [] nodeIds = new String [] { "test_node_id1" , "test_node_id2" };
288+ MLUndeployModelsRequest request = new MLUndeployModelsRequest (modelIds , nodeIds );
289+
290+ transportUndeployModelsAction .doExecute (task , request , actionListener );
291+
292+ verify (actionListener ).onResponse (any (MLUndeployModelsResponse .class ));
293+ // Check that no bulk write occurred Since there were nodes servicing the model
294+ verify (client , never ()).bulk (any (BulkRequest .class ), any (ActionListener .class ));
295+ }
296+
167297 public void testHiddenModelSuccess () {
168298 MLModel mlModel = MLModel
169299 .builder ()
@@ -186,16 +316,28 @@ public void testHiddenModelSuccess() {
186316 List <MLUndeployModelNodeResponse > responseList = new ArrayList <>();
187317 List <FailedNodeException > failuresList = new ArrayList <>();
188318 MLUndeployModelNodesResponse response = new MLUndeployModelNodesResponse (clusterName , responseList , failuresList );
319+
189320 doAnswer (invocation -> {
190321 ActionListener <MLUndeployModelNodesResponse > listener = invocation .getArgument (2 );
191322 listener .onResponse (response );
192323 return null ;
193324 }).when (client ).execute (any (), any (), isA (ActionListener .class ));
194325
326+ // Mock the client.bulk call
327+ doAnswer (invocation -> {
328+ ActionListener <BulkResponse > listener = invocation .getArgument (1 );
329+ BulkResponse bulkResponse = mock (BulkResponse .class );
330+ when (bulkResponse .hasFailures ()).thenReturn (false );
331+ listener .onResponse (bulkResponse );
332+ return null ;
333+ }).when (client ).bulk (any (BulkRequest .class ), any (ActionListener .class ));
334+
195335 doReturn (true ).when (transportUndeployModelsAction ).isSuperAdminUserWrapper (clusterService , client );
196336 MLUndeployModelsRequest request = new MLUndeployModelsRequest (modelIds , nodeIds );
197337 transportUndeployModelsAction .doExecute (task , request , actionListener );
338+
198339 verify (actionListener ).onResponse (any (MLUndeployModelsResponse .class ));
340+ verify (client ).bulk (any (BulkRequest .class ), any (ActionListener .class ));
199341 }
200342
201343 public void testHiddenModelPermissionError () {
@@ -249,9 +391,19 @@ public void testDoExecute() {
249391 listener .onResponse (response );
250392 return null ;
251393 }).when (client ).execute (any (), any (), isA (ActionListener .class ));
394+ // Mock the client.bulk call
395+ doAnswer (invocation -> {
396+ ActionListener <BulkResponse > listener = invocation .getArgument (1 );
397+ BulkResponse bulkResponse = mock (BulkResponse .class );
398+ when (bulkResponse .hasFailures ()).thenReturn (false );
399+ listener .onResponse (bulkResponse );
400+ return null ;
401+ }).when (client ).bulk (any (BulkRequest .class ), any (ActionListener .class ));
402+
252403 MLUndeployModelsRequest request = new MLUndeployModelsRequest (modelIds , nodeIds );
253404 transportUndeployModelsAction .doExecute (task , request , actionListener );
254405 verify (actionListener ).onResponse (any (MLUndeployModelsResponse .class ));
406+ verify (client ).bulk (any (BulkRequest .class ), any (ActionListener .class ));
255407 }
256408
257409 public void testDoExecute_modelAccessControl_notEnabled () {
0 commit comments