Skip to content

Commit 96778c0

Browse files
bug fix: to exclude remote model from deployment check (#4114) (#4116)
1 parent ccf9d6b commit 96778c0

File tree

2 files changed

+55
-2
lines changed

2 files changed

+55
-2
lines changed

plugin/src/main/java/org/opensearch/ml/helper/MemoryEmbeddingHelper.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,8 @@ public void validateEmbeddingModelState(String modelId, FunctionName modelType,
177177
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
178178
ActionListener<MLModel> wrappedListener = ActionListener.runBefore(ActionListener.wrap(model -> {
179179
MLModelState modelState = model.getModelState();
180-
if (model.getAlgorithm() == FunctionName.REMOTE
181-
|| (modelState != MLModelState.DEPLOYED && modelState != MLModelState.PARTIALLY_DEPLOYED)) {
180+
if (model.getAlgorithm() != FunctionName.REMOTE
181+
&& (modelState != MLModelState.DEPLOYED && modelState != MLModelState.PARTIALLY_DEPLOYED)) {
182182
listener
183183
.onFailure(
184184
new IllegalStateException(

plugin/src/test/java/org/opensearch/ml/helper/MemoryEmbeddingHelperTests.java

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,21 @@ public void testValidateEmbeddingModelStateDeployed() {
293293
verify(booleanListener).onResponse(true);
294294
}
295295

296+
@Test
297+
public void testValidateEmbeddingModelStatePartiallyDeployed() {
298+
when(mlModel.getModelState()).thenReturn(MLModelState.PARTIALLY_DEPLOYED);
299+
300+
doAnswer(invocation -> {
301+
ActionListener<MLModel> listener = invocation.getArgument(1);
302+
listener.onResponse(mlModel);
303+
return null;
304+
}).when(mlModelManager).getModel(eq("model-123"), any());
305+
306+
helper.validateEmbeddingModelState("model-123", FunctionName.TEXT_EMBEDDING, booleanListener);
307+
308+
verify(booleanListener).onResponse(true);
309+
}
310+
296311
@Test
297312
public void testValidateEmbeddingModelStateNotDeployed() {
298313
when(mlModel.getModelState()).thenReturn(MLModelState.REGISTERED);
@@ -311,8 +326,46 @@ public void testValidateEmbeddingModelStateNotDeployed() {
311326
assertTrue(exceptionCaptor.getValue().getMessage().contains("DEPLOYED"));
312327
}
313328

329+
@Test
330+
public void testValidateEmbeddingModelStateDeploying() {
331+
when(mlModel.getModelState()).thenReturn(MLModelState.DEPLOYING);
332+
333+
doAnswer(invocation -> {
334+
ActionListener<MLModel> listener = invocation.getArgument(1);
335+
listener.onResponse(mlModel);
336+
return null;
337+
}).when(mlModelManager).getModel(eq("model-123"), any());
338+
339+
helper.validateEmbeddingModelState("model-123", FunctionName.TEXT_EMBEDDING, booleanListener);
340+
341+
ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
342+
verify(booleanListener).onFailure(exceptionCaptor.capture());
343+
assertTrue(exceptionCaptor.getValue() instanceof IllegalStateException);
344+
assertTrue(exceptionCaptor.getValue().getMessage().contains("DEPLOYED"));
345+
}
346+
314347
@Test
315348
public void testValidateEmbeddingModelStateRemoteModel() {
349+
// Set up a model that has REMOTE algorithm but any state (e.g., REGISTERED)
350+
when(mlModel.getAlgorithm()).thenReturn(FunctionName.REMOTE);
351+
when(mlModel.getModelState()).thenReturn(MLModelState.REGISTERED);
352+
353+
doAnswer(invocation -> {
354+
ActionListener<MLModel> listener = invocation.getArgument(1);
355+
listener.onResponse(mlModel);
356+
return null;
357+
}).when(mlModelManager).getModel(eq("model-123"), any());
358+
359+
// Pass TEXT_EMBEDDING as modelType to avoid early return, but model itself is REMOTE
360+
helper.validateEmbeddingModelState("model-123", FunctionName.TEXT_EMBEDDING, booleanListener);
361+
362+
verify(booleanListener).onResponse(true);
363+
verify(mlModelManager).getModel(eq("model-123"), any());
364+
}
365+
366+
@Test
367+
public void testValidateEmbeddingModelStateRemoteModelTypeEarlyReturn() {
368+
// Test the early return path when modelType itself is REMOTE
316369
helper.validateEmbeddingModelState("remote-model", FunctionName.REMOTE, booleanListener);
317370
verify(booleanListener).onResponse(true);
318371
verify(mlModelManager, never()).getModel(any(), any());

0 commit comments

Comments
 (0)