Skip to content

Commit 839785f

Browse files
fix remote register model / circuit breaker 500 (#2264) (#2273)
* move memory CB check into try block to catch exception and hand to listener for register remote model
Signed-off-by: Henry Lindeman <[email protected]>
* add test that memory cb exception is caught by action listener
Signed-off-by: Henry Lindeman <[email protected]>
* unthrow PrivilegedExceptionAction
Signed-off-by: Henry Lindeman <[email protected]>
---------
Signed-off-by: Henry Lindeman <[email protected]>
(cherry picked from commit 30642e6)
Co-authored-by: Henry Lindeman <[email protected]>
1 parent eeba1c3 commit 839785f

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -345,8 +345,8 @@ public void registerMLRemoteModel(
345345
MLTask mlTask,
346346
ActionListener<MLRegisterModelResponse> listener
347347
) {
348-
checkAndAddRunningTask(mlTask, maxRegisterTasksPerNode);
349348
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
349+
checkAndAddRunningTask(mlTask, maxRegisterTasksPerNode);
350350
mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment();
351351
mlStats.createCounterStatIfAbsent(mlTask.getFunctionName(), REGISTER, ML_ACTION_REQUEST_COUNT).increment();
352352
mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment();

plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -82,6 +82,7 @@
8282
import org.opensearch.core.action.ActionListener;
8383
import org.opensearch.core.xcontent.NamedXContentRegistry;
8484
import org.opensearch.ml.breaker.MLCircuitBreakerService;
85+
import org.opensearch.ml.breaker.MemoryCircuitBreaker;
8586
import org.opensearch.ml.breaker.ThresholdCircuitBreaker;
8687
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
8788
import org.opensearch.ml.common.FunctionName;
@@ -112,6 +113,7 @@
112113
import org.opensearch.ml.stats.MLStats;
113114
import org.opensearch.ml.stats.suppliers.CounterSupplier;
114115
import org.opensearch.ml.task.MLTaskManager;
116+
import org.opensearch.monitor.jvm.JvmService;
115117
import org.opensearch.script.ScriptService;
116118
import org.opensearch.test.OpenSearchTestCase;
117119
import org.opensearch.threadpool.ThreadPool;
@@ -449,6 +451,23 @@ public void testRegisterMLRemoteModel() throws PrivilegedActionException {
449451
verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean());
450452
}
451453

454+
/**
 * Checks that a circuit-breaker failure during remote model registration is
 * delivered to the caller's {@code ActionListener} rather than thrown out of
 * {@code registerMLRemoteModel}.
 *
 * The test stubs {@code mlCircuitBreakerService.checkOpenCB()} to throw an
 * {@code MLLimitExceededException}, calls {@code registerMLRemoteModel}, and
 * then asserts that {@code listener.onFailure} was invoked exactly once with
 * an exception of that type carrying the same message.
 */
public void testRegisterMLRemoteModel_WhenMemoryCBOpen_ThenFail() {
455+
ActionListener<MLRegisterModelResponse> listener = mock(ActionListener.class);
456+
// This breaker instance is only used to obtain its name for the expected message;
// it is not wired into the model manager here.
MemoryCircuitBreaker memCB = new MemoryCircuitBreaker(mock(JvmService.class));
457+
String memCBIsOpenMessage = memCB.getName() + " is open, please check your resources!";
458+
// Make the circuit-breaker check fail with the expected message.
// Presumably registerMLRemoteModel reaches checkOpenCB() through
// checkAndAddRunningTask - confirm against MLModelManager.
when(mlCircuitBreakerService.checkOpenCB()).thenThrow(new MLLimitExceededException(memCBIsOpenMessage));
459+
460+
// NOTE(review): the locals are named "pretrained", but the input comes from
// mockRemoteModelInput and the task uses FunctionName.REMOTE.
MLRegisterModelInput pretrainedInput = mockRemoteModelInput(true);
461+
MLTask pretrainedTask = MLTask.builder().taskId("pretrained").modelId("pretrained").functionName(FunctionName.REMOTE).build();
462+
// If this call threw instead of notifying the listener, the test would fail here.
modelManager.registerMLRemoteModel(pretrainedInput, pretrainedTask, listener);
463+
464+
// Capture what was passed to onFailure so its type and message can be checked.
ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
465+
verify(listener, times(1)).onFailure(argCaptor.capture());
466+
Exception e = argCaptor.getValue();
467+
assertTrue(e instanceof MLLimitExceededException);
468+
assertEquals(memCBIsOpenMessage, e.getMessage());
469+
}
470+
452471
public void testIndexRemoteModel() throws PrivilegedActionException {
453472
ActionListener<MLRegisterModelResponse> listener = mock(ActionListener.class);
454473
doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any());

0 commit comments

Comments
 (0)