| 
75 | 75 | import org.mockito.ArgumentCaptor;  | 
76 | 76 | import org.mockito.Mock;  | 
77 | 77 | import org.mockito.MockitoAnnotations;  | 
 | 78 | +import org.opensearch.OpenSearchStatusException;  | 
78 | 79 | import org.opensearch.action.get.GetRequest;  | 
79 | 80 | import org.opensearch.action.get.GetResponse;  | 
80 | 81 | import org.opensearch.action.index.IndexResponse;  | 
 | 
92 | 93 | import org.opensearch.core.common.breaker.CircuitBreakingException;  | 
93 | 94 | import org.opensearch.core.common.bytes.BytesReference;  | 
94 | 95 | import org.opensearch.core.index.shard.ShardId;  | 
 | 96 | +import org.opensearch.core.rest.RestStatus;  | 
95 | 97 | import org.opensearch.core.xcontent.NamedXContentRegistry;  | 
96 | 98 | import org.opensearch.core.xcontent.ToXContent;  | 
97 | 99 | import org.opensearch.core.xcontent.XContentBuilder;  | 
 | 100 | +import org.opensearch.index.IndexNotFoundException;  | 
98 | 101 | import org.opensearch.index.get.GetResult;  | 
99 | 102 | import org.opensearch.ml.breaker.MLCircuitBreakerService;  | 
100 | 103 | import org.opensearch.ml.breaker.ThresholdCircuitBreaker;  | 
@@ -492,6 +495,46 @@ public void testRegisterMLRemoteModel() throws PrivilegedActionException, IOExce  | 
492 | 495 |         verify(mlTaskManager).updateMLTask(anyString(), any(), anyMap(), anyLong(), anyBoolean());  | 
493 | 496 |     }  | 
494 | 497 | 
 
  | 
 | 498 | +    @Test  | 
 | 499 | +    public void testRegisterMLRemoteModelModelGroupNotFoundException() throws PrivilegedActionException, IOException {  | 
 | 500 | +        // Create listener and capture the failure  | 
 | 501 | +        ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);  | 
 | 502 | +        ActionListener<MLRegisterModelResponse> listener = mock(ActionListener.class);  | 
 | 503 | + | 
 | 504 | +        // Setup mocks  | 
 | 505 | +        doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any());  | 
 | 506 | +        when(mlCircuitBreakerService.checkOpenCB()).thenReturn(null);  | 
 | 507 | +        when(threadPool.executor(REGISTER_THREAD_POOL)).thenReturn(taskExecutorService);  | 
 | 508 | +        when(modelHelper.downloadPrebuiltModelMetaList(any(), any())).thenReturn(Collections.singletonList("demo"));  | 
 | 509 | +        when(modelHelper.isModelAllowed(any(), any())).thenReturn(true);  | 
 | 510 | + | 
 | 511 | +        // Create test inputs  | 
 | 512 | +        MLRegisterModelInput pretrainedInput = mockRemoteModelInput(true);  | 
 | 513 | +        MLTask pretrainedTask = MLTask.builder().taskId("pretrained").modelId("pretrained").functionName(FunctionName.REMOTE).build();  | 
 | 514 | + | 
 | 515 | +        // Mock index handler  | 
 | 516 | +        mock_MLIndicesHandler_initModelIndex(mlIndicesHandler, true);  | 
 | 517 | + | 
 | 518 | +        // Mock client.get() to throw IndexNotFoundException  | 
 | 519 | +        doAnswer(invocation -> {  | 
 | 520 | +            ActionListener<GetResponse> getModelGroupListener = invocation.getArgument(1);  | 
 | 521 | +            getModelGroupListener.onFailure(new IndexNotFoundException("Test", "test"));  | 
 | 522 | +            return null;  | 
 | 523 | +        }).when(client).get(any(), any());  | 
 | 524 | + | 
 | 525 | +        // Execute method under test  | 
 | 526 | +        modelManager.registerMLRemoteModel(sdkClient, pretrainedInput, pretrainedTask, listener);  | 
 | 527 | + | 
 | 528 | +        // Verify the listener's onFailure was called with correct exception  | 
 | 529 | +        verify(listener).onFailure(exceptionCaptor.capture());  | 
 | 530 | +        Exception exception = exceptionCaptor.getValue();  | 
 | 531 | + | 
 | 532 | +        // Verify exception type and message  | 
 | 533 | +        assertTrue(exception instanceof OpenSearchStatusException);  | 
 | 534 | +        assertEquals("Model group not found", exception.getMessage());  | 
 | 535 | +        assertEquals(RestStatus.NOT_FOUND, ((OpenSearchStatusException) exception).status());  | 
 | 536 | +    }  | 
 | 537 | + | 
495 | 538 |     public void testRegisterMLRemoteModel_SkipMemoryCBOpen() throws IOException {  | 
496 | 539 |         ActionListener<MLRegisterModelResponse> listener = mock(ActionListener.class);  | 
497 | 540 |         doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any());  | 
 | 
0 commit comments