|
21 | 21 | import static org.mockito.Mockito.times; |
22 | 22 | import static org.mockito.Mockito.verify; |
23 | 23 | import static org.mockito.Mockito.when; |
| 24 | +import static org.opensearch.ml.common.MLTask.FUNCTION_NAME_FIELD; |
24 | 25 | import static org.opensearch.ml.engine.ModelHelper.CHUNK_FILES; |
25 | 26 | import static org.opensearch.ml.engine.ModelHelper.MODEL_FILE_HASH; |
26 | 27 | import static org.opensearch.ml.engine.ModelHelper.MODEL_SIZE_IN_BYTES; |
| 28 | +import static org.opensearch.ml.model.MLModelManager.TIMEOUT_IN_MILLIS; |
27 | 29 | import static org.opensearch.ml.plugin.MachineLearningPlugin.DEPLOY_THREAD_POOL; |
28 | 30 | import static org.opensearch.ml.plugin.MachineLearningPlugin.REGISTER_THREAD_POOL; |
29 | 31 | import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE; |
|
49 | 51 | import java.net.URISyntaxException; |
50 | 52 | import java.nio.charset.StandardCharsets; |
51 | 53 | import java.nio.file.Path; |
| 54 | +import java.security.PrivilegedActionException; |
52 | 55 | import java.util.Arrays; |
53 | 56 | import java.util.Base64; |
| 57 | +import java.util.Collections; |
54 | 58 | import java.util.HashMap; |
55 | 59 | import java.util.List; |
56 | 60 | import java.util.Map; |
@@ -168,6 +172,9 @@ public class MLModelManagerTests extends OpenSearchTestCase { |
168 | 172 | @Mock |
169 | 173 | private ScriptService scriptService; |
170 | 174 |
|
| 175 | + @Mock |
| 176 | + private MLTask pretrainedMLTask; |
| 177 | + |
171 | 178 | @Before |
172 | 179 | public void setup() throws URISyntaxException { |
173 | 180 | String masterKey = "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="; |
@@ -373,6 +380,35 @@ public void testRegisterMLModel_DownloadModelFileFailure() { |
373 | 380 | verify(modelHelper).downloadAndSplit(eq(modelFormat), eq(modelId), eq(modelName), eq(version), eq(url), any(), any(), any()); |
374 | 381 | } |
375 | 382 |
|
| 383 | + public void testRegisterMLModel_RegisterPreBuildModel() throws PrivilegedActionException { |
| 384 | + doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any()); |
| 385 | + when(mlCircuitBreakerService.checkOpenCB()).thenReturn(null); |
| 386 | + when(threadPool.executor(REGISTER_THREAD_POOL)).thenReturn(taskExecutorService); |
| 387 | + when(modelHelper.downloadPrebuiltModelMetaList(any(), any())).thenReturn(Collections.singletonList("demo")); |
| 388 | + when(modelHelper.isModelAllowed(any(), any())).thenReturn(true); |
| 389 | + MLRegisterModelInput pretrainedInput = mockPretrainedInput(); |
| 390 | + doAnswer(invocation -> { |
| 391 | + ActionListener<MLRegisterModelInput> listener = (ActionListener<MLRegisterModelInput>) invocation.getArguments()[2]; |
| 392 | + listener.onResponse(pretrainedInput); |
| 393 | + return null; |
| 394 | + }).when(modelHelper).downloadPrebuiltModelConfig(any(), any(), any()); |
| 395 | + MLTask pretrainedTask = MLTask |
| 396 | + .builder() |
| 397 | + .taskId("pretrained") |
| 398 | + .modelId("pretrained") |
| 399 | + .functionName(FunctionName.TEXT_EMBEDDING) |
| 400 | + .build(); |
| 401 | + modelManager.registerMLModel(pretrainedInput, pretrainedTask); |
| 402 | + assertEquals(pretrainedTask.getFunctionName(), FunctionName.SPARSE_ENCODING); |
| 403 | + verify(mlTaskManager) |
| 404 | + .updateMLTask( |
| 405 | + eq("pretrained"), |
| 406 | + eq(ImmutableMap.of(FUNCTION_NAME_FIELD, FunctionName.SPARSE_ENCODING)), |
| 407 | + eq((long) TIMEOUT_IN_MILLIS), |
| 408 | + eq(false) |
| 409 | + ); |
| 410 | + } |
| 411 | + |
376 | 412 | @Ignore |
377 | 413 | public void testRegisterMLModel_DownloadModelFile() throws IOException { |
378 | 414 | doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any()); |
@@ -916,4 +952,15 @@ private MLRegisterModelMetaInput prepareRequest() { |
916 | 952 | .build(); |
917 | 953 | return input; |
918 | 954 | } |
| 955 | + |
| 956 | + private MLRegisterModelInput mockPretrainedInput() { |
| 957 | + return MLRegisterModelInput |
| 958 | + .builder() |
| 959 | + .modelName(modelName) |
| 960 | + .version(version) |
| 961 | + .modelGroupId("modelGroupId") |
| 962 | + .modelFormat(modelFormat) |
| 963 | + .functionName(FunctionName.SPARSE_ENCODING) |
| 964 | + .build(); |
| 965 | + } |
919 | 966 | } |
0 commit comments