|
82 | 82 | import org.opensearch.core.action.ActionListener; |
83 | 83 | import org.opensearch.core.xcontent.NamedXContentRegistry; |
84 | 84 | import org.opensearch.ml.breaker.MLCircuitBreakerService; |
| 85 | +import org.opensearch.ml.breaker.MemoryCircuitBreaker; |
85 | 86 | import org.opensearch.ml.breaker.ThresholdCircuitBreaker; |
86 | 87 | import org.opensearch.ml.cluster.DiscoveryNodeHelper; |
87 | 88 | import org.opensearch.ml.common.FunctionName; |
|
112 | 113 | import org.opensearch.ml.stats.MLStats; |
113 | 114 | import org.opensearch.ml.stats.suppliers.CounterSupplier; |
114 | 115 | import org.opensearch.ml.task.MLTaskManager; |
| 116 | +import org.opensearch.monitor.jvm.JvmService; |
115 | 117 | import org.opensearch.script.ScriptService; |
116 | 118 | import org.opensearch.test.OpenSearchTestCase; |
117 | 119 | import org.opensearch.threadpool.ThreadPool; |
@@ -449,6 +451,23 @@ public void testRegisterMLRemoteModel() throws PrivilegedActionException { |
449 | 451 | verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean()); |
450 | 452 | } |
451 | 453 |
|
| 454 | + public void testRegisterMLRemoteModel_WhenMemoryCBOpen_ThenFail() { |
| 455 | + ActionListener<MLRegisterModelResponse> listener = mock(ActionListener.class); |
| 456 | + MemoryCircuitBreaker memCB = new MemoryCircuitBreaker(mock(JvmService.class)); |
| 457 | + String memCBIsOpenMessage = memCB.getName() + " is open, please check your resources!"; |
| 458 | + when(mlCircuitBreakerService.checkOpenCB()).thenThrow(new MLLimitExceededException(memCBIsOpenMessage)); |
| 459 | + |
| 460 | + MLRegisterModelInput pretrainedInput = mockRemoteModelInput(true); |
| 461 | + MLTask pretrainedTask = MLTask.builder().taskId("pretrained").modelId("pretrained").functionName(FunctionName.REMOTE).build(); |
| 462 | + modelManager.registerMLRemoteModel(pretrainedInput, pretrainedTask, listener); |
| 463 | + |
| 464 | + ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class); |
| 465 | + verify(listener, times(1)).onFailure(argCaptor.capture()); |
| 466 | + Exception e = argCaptor.getValue(); |
| 467 | + assertTrue(e instanceof MLLimitExceededException); |
| 468 | + assertEquals(memCBIsOpenMessage, e.getMessage()); |
| 469 | + } |
| 470 | + |
452 | 471 | public void testIndexRemoteModel() throws PrivilegedActionException { |
453 | 472 | ActionListener<MLRegisterModelResponse> listener = mock(ActionListener.class); |
454 | 473 | doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any()); |
|
0 commit comments