| 
18 | 18 | import static org.mockito.Mockito.verify;  | 
19 | 19 | import static org.mockito.Mockito.when;  | 
20 | 20 | import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;  | 
 | 21 | +import static org.opensearch.ml.common.CommonValue.NOT_FOUND;  | 
21 | 22 | import static org.opensearch.ml.task.MLPredictTaskRunnerTests.USER_STRING;  | 
22 | 23 | 
 
  | 
23 | 24 | import java.io.IOException;  | 
24 | 25 | import java.util.ArrayList;  | 
 | 26 | +import java.util.HashMap;  | 
25 | 27 | import java.util.List;  | 
26 | 28 | import java.util.Map;  | 
27 | 29 | 
 
  | 
@@ -348,6 +350,63 @@ public void testHiddenModelSuccess() {  | 
348 | 350 |         verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));  | 
349 | 351 |     }  | 
350 | 352 | 
 
  | 
 | 353 | +    public void testDoExecute_bulkRequestFired_WhenModelNotFoundInAllNodes() {  | 
 | 354 | +        MLModel mlModel = MLModel  | 
 | 355 | +            .builder()  | 
 | 356 | +            .user(User.parse(USER_STRING))  | 
 | 357 | +            .modelGroupId("111")  | 
 | 358 | +            .version("111")  | 
 | 359 | +            .name(this.modelIds[0])  | 
 | 360 | +            .modelId(this.modelIds[0])  | 
 | 361 | +            .algorithm(FunctionName.BATCH_RCF)  | 
 | 362 | +            .content("content")  | 
 | 363 | +            .totalChunks(2)  | 
 | 364 | +            .isHidden(true)  | 
 | 365 | +            .build();  | 
 | 366 | + | 
 | 367 | +        // Mock MLModel manager response  | 
 | 368 | +        doAnswer(invocation -> {  | 
 | 369 | +            ActionListener<MLModel> listener = invocation.getArgument(4);  | 
 | 370 | +            listener.onResponse(mlModel);  | 
 | 371 | +            return null;  | 
 | 372 | +        }).when(mlModelManager).getModel(any(), any(), any(), any(), isA(ActionListener.class));  | 
 | 373 | + | 
 | 374 | +        doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client);  | 
 | 375 | + | 
 | 376 | +        List<MLUndeployModelNodeResponse> responseList = new ArrayList<>();  | 
 | 377 | + | 
 | 378 | +        for (String nodeId : this.nodeIds) {  | 
 | 379 | +            Map<String, String> stats = new HashMap<>();  | 
 | 380 | +            stats.put(this.modelIds[0], NOT_FOUND);  | 
 | 381 | +            MLUndeployModelNodeResponse nodeResponse = mock(MLUndeployModelNodeResponse.class);  | 
 | 382 | +            when(nodeResponse.getModelUndeployStatus()).thenReturn(stats);  | 
 | 383 | +            responseList.add(nodeResponse);  | 
 | 384 | +        }  | 
 | 385 | + | 
 | 386 | +        List<FailedNodeException> failuresList = new ArrayList<>();  | 
 | 387 | +        MLUndeployModelNodesResponse nodesResponse = new MLUndeployModelNodesResponse(clusterName, responseList, failuresList);  | 
 | 388 | + | 
 | 389 | +        doAnswer(invocation -> {  | 
 | 390 | +            ActionListener<MLUndeployModelNodesResponse> listener = invocation.getArgument(2);  | 
 | 391 | +            listener.onResponse(nodesResponse);  | 
 | 392 | +            return null;  | 
 | 393 | +        }).when(client).execute(any(), any(), isA(ActionListener.class));  | 
 | 394 | + | 
 | 395 | +        doAnswer(invocation -> {  | 
 | 396 | +            ActionListener<BulkResponse> listener = invocation.getArgument(1);  | 
 | 397 | +            listener.onResponse(mock(BulkResponse.class));  | 
 | 398 | +            return null;  | 
 | 399 | +        }).when(client).bulk(any(BulkRequest.class), any(ActionListener.class));  | 
 | 400 | + | 
 | 401 | +        MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null);  | 
 | 402 | + | 
 | 403 | +        transportUndeployModelsAction.doExecute(task, request, actionListener);  | 
 | 404 | + | 
 | 405 | +        // Verify that bulk request was fired because all nodes reported "not_found"  | 
 | 406 | +        verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));  | 
 | 407 | +        verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));  | 
 | 408 | +    }  | 
 | 409 | + | 
351 | 410 |     public void testHiddenModelPermissionError() {  | 
352 | 411 |         MLModel mlModel = MLModel  | 
353 | 412 |             .builder()  | 
 | 
0 commit comments