Skip to content

Commit 17181b7

Browse files
authored
backport to 2.11 (#1639)
Signed-off-by: xinyual <[email protected]>
1 parent 6ef198f commit 17181b7

File tree

4 files changed

+67
-7
lines changed

4 files changed

+67
-7
lines changed

common/src/main/java/org/opensearch/ml/common/MLTask.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ public class MLTask implements ToXContentObject, Writeable {
5050
@Setter
5151
private String modelId;
5252
private final MLTaskType taskType;
53-
private final FunctionName functionName;
53+
@Setter
54+
private FunctionName functionName;
5455
@Setter
5556
private MLTaskState state;
5657
private final MLInputDataType inputType;

ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,15 @@ public void downloadPrebuiltModelConfig(String taskId, MLRegisterModelInput regi
8282

8383
MLRegisterModelInput.MLRegisterModelInputBuilder builder = MLRegisterModelInput.builder();
8484

85-
builder.modelName(modelName)
86-
.version(version)
87-
.url(modelZipFileUrl)
88-
.deployModel(deployModel)
89-
.modelNodeIds(modelNodeIds)
90-
.modelGroupId(modelGroupId);
85+
builder
86+
.modelName(modelName)
87+
.version(version)
88+
.url(modelZipFileUrl)
89+
.deployModel(deployModel)
90+
.modelNodeIds(modelNodeIds)
91+
.modelGroupId(modelGroupId)
92+
.functionName(FunctionName.from((String) config.get("model_task_type")));
93+
9194
config.entrySet().forEach(entry -> {
9295
switch (entry.getKey().toString()) {
9396
case MLRegisterModelInput.MODEL_FORMAT_FIELD:

plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import static org.opensearch.ml.common.CommonValue.UNDEPLOYED;
1616
import static org.opensearch.ml.common.MLModel.ALGORITHM_FIELD;
1717
import static org.opensearch.ml.common.MLTask.ERROR_FIELD;
18+
import static org.opensearch.ml.common.MLTask.FUNCTION_NAME_FIELD;
1819
import static org.opensearch.ml.common.MLTask.MODEL_ID_FIELD;
1920
import static org.opensearch.ml.common.MLTask.STATE_FIELD;
2021
import static org.opensearch.ml.common.MLTaskState.COMPLETED;
@@ -755,6 +756,14 @@ private void registerPrebuiltModel(MLRegisterModelInput registerModelInput, MLTa
755756
throw new IllegalArgumentException("This model is not in the pre-trained model list, please check your parameters.");
756757
}
757758
modelHelper.downloadPrebuiltModelConfig(taskId, registerModelInput, ActionListener.wrap(mlRegisterModelInput -> {
759+
mlTask.setFunctionName(mlRegisterModelInput.getFunctionName());
760+
mlTaskManager
761+
.updateMLTask(
762+
taskId,
763+
ImmutableMap.of(FUNCTION_NAME_FIELD, mlRegisterModelInput.getFunctionName()),
764+
TIMEOUT_IN_MILLIS,
765+
false
766+
);
758767
registerModelFromUrl(mlRegisterModelInput, mlTask, modelVersion);
759768
}, e -> {
760769
log.error("Failed to register prebuilt model", e);

plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,11 @@
2121
import static org.mockito.Mockito.times;
2222
import static org.mockito.Mockito.verify;
2323
import static org.mockito.Mockito.when;
24+
import static org.opensearch.ml.common.MLTask.FUNCTION_NAME_FIELD;
2425
import static org.opensearch.ml.engine.ModelHelper.CHUNK_FILES;
2526
import static org.opensearch.ml.engine.ModelHelper.MODEL_FILE_HASH;
2627
import static org.opensearch.ml.engine.ModelHelper.MODEL_SIZE_IN_BYTES;
28+
import static org.opensearch.ml.model.MLModelManager.TIMEOUT_IN_MILLIS;
2729
import static org.opensearch.ml.plugin.MachineLearningPlugin.DEPLOY_THREAD_POOL;
2830
import static org.opensearch.ml.plugin.MachineLearningPlugin.REGISTER_THREAD_POOL;
2931
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE;
@@ -49,8 +51,10 @@
4951
import java.net.URISyntaxException;
5052
import java.nio.charset.StandardCharsets;
5153
import java.nio.file.Path;
54+
import java.security.PrivilegedActionException;
5255
import java.util.Arrays;
5356
import java.util.Base64;
57+
import java.util.Collections;
5458
import java.util.HashMap;
5559
import java.util.List;
5660
import java.util.Map;
@@ -168,6 +172,9 @@ public class MLModelManagerTests extends OpenSearchTestCase {
168172
@Mock
169173
private ScriptService scriptService;
170174

175+
@Mock
176+
private MLTask pretrainedMLTask;
177+
171178
@Before
172179
public void setup() throws URISyntaxException {
173180
String masterKey = "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=";
@@ -373,6 +380,35 @@ public void testRegisterMLModel_DownloadModelFileFailure() {
373380
verify(modelHelper).downloadAndSplit(eq(modelFormat), eq(modelId), eq(modelName), eq(version), eq(url), any(), any(), any());
374381
}
375382

383+
public void testRegisterMLModel_RegisterPreBuildModel() throws PrivilegedActionException {
384+
doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any());
385+
when(mlCircuitBreakerService.checkOpenCB()).thenReturn(null);
386+
when(threadPool.executor(REGISTER_THREAD_POOL)).thenReturn(taskExecutorService);
387+
when(modelHelper.downloadPrebuiltModelMetaList(any(), any())).thenReturn(Collections.singletonList("demo"));
388+
when(modelHelper.isModelAllowed(any(), any())).thenReturn(true);
389+
MLRegisterModelInput pretrainedInput = mockPretrainedInput();
390+
doAnswer(invocation -> {
391+
ActionListener<MLRegisterModelInput> listener = (ActionListener<MLRegisterModelInput>) invocation.getArguments()[2];
392+
listener.onResponse(pretrainedInput);
393+
return null;
394+
}).when(modelHelper).downloadPrebuiltModelConfig(any(), any(), any());
395+
MLTask pretrainedTask = MLTask
396+
.builder()
397+
.taskId("pretrained")
398+
.modelId("pretrained")
399+
.functionName(FunctionName.TEXT_EMBEDDING)
400+
.build();
401+
modelManager.registerMLModel(pretrainedInput, pretrainedTask);
402+
assertEquals(pretrainedTask.getFunctionName(), FunctionName.SPARSE_ENCODING);
403+
verify(mlTaskManager)
404+
.updateMLTask(
405+
eq("pretrained"),
406+
eq(ImmutableMap.of(FUNCTION_NAME_FIELD, FunctionName.SPARSE_ENCODING)),
407+
eq((long) TIMEOUT_IN_MILLIS),
408+
eq(false)
409+
);
410+
}
411+
376412
@Ignore
377413
public void testRegisterMLModel_DownloadModelFile() throws IOException {
378414
doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any());
@@ -916,4 +952,15 @@ private MLRegisterModelMetaInput prepareRequest() {
916952
.build();
917953
return input;
918954
}
955+
956+
private MLRegisterModelInput mockPretrainedInput() {
957+
return MLRegisterModelInput
958+
.builder()
959+
.modelName(modelName)
960+
.version(version)
961+
.modelGroupId("modelGroupId")
962+
.modelFormat(modelFormat)
963+
.functionName(FunctionName.SPARSE_ENCODING)
964+
.build();
965+
}
919966
}

0 commit comments

Comments
 (0)