Skip to content

Commit dc7a0b7

Browse files
authored
enable prebuilt model (#729)
* enable prebuilt model

Signed-off-by: Yaliang Wu <[email protected]>

* address comments

Signed-off-by: Yaliang Wu <[email protected]>

* add unit test for prebuilt model url

Signed-off-by: Yaliang Wu <[email protected]>

---------

Signed-off-by: Yaliang Wu <[email protected]>
1 parent ffb8a4e commit dc7a0b7

File tree

7 files changed: +44 additions, -39 deletions

build.gradle

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -71,10 +71,6 @@ ext {
7171
noticeFile = rootProject.file('NOTICE.txt')
7272
}
7373

74-
dependencies {
75-
implementation 'junit:junit:${versions.junit}'
76-
}
77-
7874
// updateVersion: Task to auto increment to the next development iteration
7975
task updateVersion {
8076
onlyIf { System.getProperty('newVersion') }

common/src/main/java/org/opensearch/ml/common/model/TextEmbeddingModelConfig.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -184,7 +184,7 @@ public String getName() {
184184

185185
public static PoolingMode from(String value) {
186186
try {
187-
return PoolingMode.valueOf(value);
187+
return PoolingMode.valueOf(value.toUpperCase(Locale.ROOT));
188188
} catch (Exception e) {
189189
throw new IllegalArgumentException("Wrong pooling method");
190190
}
@@ -197,7 +197,7 @@ public enum FrameworkType {
197197

198198
public static FrameworkType from(String value) {
199199
try {
200-
return FrameworkType.valueOf(value);
200+
return FrameworkType.valueOf(value.toUpperCase(Locale.ROOT));
201201
} catch (Exception e) {
202202
throw new IllegalArgumentException("Wrong framework type");
203203
}

common/src/main/java/org/opensearch/ml/common/transport/upload/MLUploadInput.java

Lines changed: 1 addition & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -73,24 +73,12 @@ public MLUploadInput(FunctionName functionName,
7373
if (version == null) {
7474
throw new IllegalArgumentException("model version is null");
7575
}
76-
//TODO: enable prebuilt model in 2.6
77-
// if (url != null) {
78-
// if (modelFormat == null) {
79-
// throw new IllegalArgumentException("model format is null");
80-
// }
81-
// if (modelConfig == null) {
82-
// throw new IllegalArgumentException("model config is null");
83-
// }
84-
// }
8576
if (modelFormat == null) {
8677
throw new IllegalArgumentException("model format is null");
8778
}
88-
if (modelConfig == null) {
79+
if (url != null && modelConfig == null) {
8980
throw new IllegalArgumentException("model config is null");
9081
}
91-
if (url == null) {
92-
throw new IllegalArgumentException("model file url is null");
93-
}
9482
this.modelName = modelName;
9583
this.version = version;
9684
this.description = description;

ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java

Lines changed: 13 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
import org.opensearch.ml.common.input.Input;
1414
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
1515
import org.opensearch.ml.common.input.MLInput;
16+
import org.opensearch.ml.common.model.MLModelFormat;
1617
import org.opensearch.ml.common.output.MLOutput;
1718
import org.opensearch.ml.common.output.Output;
1819

@@ -25,6 +26,8 @@
2526
*/
2627
public class MLEngine {
2728

29+
private final String MODEL_REPO = "https://artifacts.opensearch.org/models/ml-models";
30+
2831
@Getter
2932
private final Path djlCachePath;
3033
private final Path djlModelsCachePath;
@@ -34,13 +37,18 @@ public MLEngine(Path opensearchDataFolder) {
3437
djlModelsCachePath = djlCachePath.resolve("models_cache");
3538
}
3639

37-
public String getCIPrebuiltModelConfigPath(String modelName, String version) {
38-
return String.format("https://ci.opensearch.org/ci/dbc/models/ml-models/%s/%s/config.json", modelName, version, Locale.ROOT);
40+
public String getPrebuiltModelConfigPath(String modelName, String version, MLModelFormat modelFormat) {
41+
String format = modelFormat.name().toLowerCase(Locale.ROOT);
42+
return String.format("%s/%s/%s/%s/config.json", MODEL_REPO, modelName, version, format, Locale.ROOT);
3943
}
4044

41-
public String getCIPrebuiltModelPath(String modelName, String version) {
42-
int index = modelName.lastIndexOf("/") + 1;
43-
return String.format("https://ci.opensearch.org/ci/dbc/models/ml-models/%s/%s/%s.zip", modelName, version, modelName.substring(index), Locale.ROOT);
45+
public String getPrebuiltModelPath(String modelName, String version, MLModelFormat modelFormat) {
46+
int index = modelName.indexOf("/") + 1;
47+
// /huggingface/sentence-transformers/msmarco-distilbert-base-tas-b/1.0.0/onnx/sentence-transformers_msmarco-distilbert-base-tas-b-1.0.0-onnx.zip
48+
// /huggingface/sentence-transformers/msmarco-distilbert-base-tas-b/1.0.0/onnx/config.json
49+
String format = modelFormat.name().toLowerCase(Locale.ROOT);
50+
String modelZipFileName = modelName.substring(index).replace("/", "_") + "-" + version + "-" + format;
51+
return String.format("%s/%s/%s/%s/%s.zip", MODEL_REPO, modelName, version, format, modelZipFileName, Locale.ROOT);
4452
}
4553

4654
public Path getUploadModelPath(String modelId, String modelName, String version) {

ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
import java.util.Enumeration;
2626
import java.util.HashMap;
2727
import java.util.List;
28+
import java.util.Locale;
2829
import java.util.Map;
2930
import java.util.zip.ZipEntry;
3031
import java.util.zip.ZipFile;
@@ -55,6 +56,7 @@ public ModelHelper(MLEngine mlEngine) {
5556
public void downloadPrebuiltModelConfig(String taskId, MLUploadInput uploadInput, ActionListener<MLUploadInput> listener) {
5657
String modelName = uploadInput.getModelName();
5758
String version = uploadInput.getVersion();
59+
MLModelFormat modelFormat = uploadInput.getModelFormat();
5860
boolean loadModel = uploadInput.isLoadModel();
5961
String[] modelNodeIds = uploadInput.getModelNodeIds();
6062
try {
@@ -63,8 +65,8 @@ public void downloadPrebuiltModelConfig(String taskId, MLUploadInput uploadInput
6365
Path modelUploadPath = mlEngine.getUploadModelPath(taskId, modelName, version);
6466
String configCacheFilePath = modelUploadPath.resolve("config.json").toString();
6567

66-
String configFileUrl = mlEngine.getCIPrebuiltModelConfigPath(modelName, version);
67-
String modelZipFileUrl = mlEngine.getCIPrebuiltModelPath(modelName, version);
68+
String configFileUrl = mlEngine.getPrebuiltModelConfigPath(modelName, version, modelFormat);
69+
String modelZipFileUrl = mlEngine.getPrebuiltModelPath(modelName, version, modelFormat);
6870
DownloadUtils.download(configFileUrl, configCacheFilePath, new ProgressBar());
6971

7072
Map<?, ?> config = null;
@@ -103,7 +105,7 @@ public void downloadPrebuiltModelConfig(String taskId, MLUploadInput uploadInput
103105
configBuilder.frameworkType(TextEmbeddingModelConfig.FrameworkType.from(configEntry.getValue().toString()));
104106
break;
105107
case TextEmbeddingModelConfig.POOLING_MODE_FIELD:
106-
configBuilder.poolingMode(TextEmbeddingModelConfig.PoolingMode.from(configEntry.getValue().toString()));
108+
configBuilder.poolingMode(TextEmbeddingModelConfig.PoolingMode.from(configEntry.getValue().toString().toUpperCase(Locale.ROOT)));
107109
break;
108110
case TextEmbeddingModelConfig.NORMALIZE_RESULT_FIELD:
109111
configBuilder.normalizeResult(Boolean.parseBoolean(configEntry.getValue().toString()));

ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java

Lines changed: 22 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
import org.opensearch.ml.common.input.parameter.regression.LinearRegressionParams;
2626
import org.opensearch.ml.common.FunctionName;
2727
import org.opensearch.ml.common.input.execute.samplecalculator.LocalSampleCalculatorInput;
28+
import org.opensearch.ml.common.model.MLModelFormat;
2829
import org.opensearch.ml.common.output.execute.samplecalculator.LocalSampleCalculatorOutput;
2930
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
3031
import org.opensearch.ml.common.input.MLInput;
@@ -36,6 +37,7 @@
3637
import java.util.Arrays;
3738
import java.util.UUID;
3839

40+
import static org.junit.Assert.assertEquals;
3941
import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame;
4042
import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame;
4143
import static org.opensearch.ml.engine.helper.MLTestHelper.constructTestDataFrame;
@@ -51,6 +53,17 @@ public void setUp() {
5153
mlEngine = new MLEngine(Path.of("/tmp/test" + UUID.randomUUID()));
5254
}
5355

56+
@Test
57+
public void testPrebuiltModelPath() {
58+
String modelName = "huggingface/sentence-transformers/msmarco-distilbert-base-tas-b";
59+
String version = "1.0.1";
60+
MLModelFormat modelFormat = MLModelFormat.TORCH_SCRIPT;
61+
String prebuiltModelPath = mlEngine.getPrebuiltModelPath(modelName, version, modelFormat);
62+
String prebuiltModelConfigPath = mlEngine.getPrebuiltModelConfigPath(modelName, version, modelFormat);
63+
assertEquals("https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/msmarco-distilbert-base-tas-b/1.0.1/torch_script/sentence-transformers_msmarco-distilbert-base-tas-b-1.0.1-torch_script.zip", prebuiltModelPath);
64+
assertEquals("https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/msmarco-distilbert-base-tas-b/1.0.1/torch_script/config.json", prebuiltModelConfigPath);
65+
}
66+
5467
@Test
5568
public void predictKMeans() {
5669
MLModel model = trainKMeansModel();
@@ -59,7 +72,7 @@ public void predictKMeans() {
5972
Input mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(inputDataset).build();
6073
MLPredictionOutput output = (MLPredictionOutput)mlEngine.predict(mlInput, model);
6174
DataFrame predictions = output.getPredictionResult();
62-
Assert.assertEquals(10, predictions.size());
75+
assertEquals(10, predictions.size());
6376
predictions.forEach(row -> Assert.assertTrue(row.getValue(0).intValue() == 0 || row.getValue(0).intValue() == 1));
6477
}
6578

@@ -71,7 +84,7 @@ public void predictLinearRegression() {
7184
Input mlInput = MLInput.builder().algorithm(FunctionName.LINEAR_REGRESSION).inputDataset(inputDataset).build();
7285
MLPredictionOutput output = (MLPredictionOutput)mlEngine.predict(mlInput, model);
7386
DataFrame predictions = output.getPredictionResult();
74-
Assert.assertEquals(2, predictions.size());
87+
assertEquals(2, predictions.size());
7588
}
7689

7790

@@ -83,7 +96,7 @@ public void loadLinearRegressionModel() {
8396
MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(predictionDataFrame).build();
8497
MLPredictionOutput output = (MLPredictionOutput)predictor.predict(MLInput.builder().algorithm(FunctionName.LINEAR_REGRESSION).inputDataset(inputDataset).build());
8598
DataFrame predictions = output.getPredictionResult();
86-
Assert.assertEquals(2, predictions.size());
99+
assertEquals(2, predictions.size());
87100
}
88101

89102
@Test
@@ -99,16 +112,16 @@ public void loadLinearRegressionModel_NullModel() {
99112
@Test
100113
public void trainKMeans() {
101114
MLModel model = trainKMeansModel();
102-
Assert.assertEquals(FunctionName.KMEANS.name(), model.getName());
103-
Assert.assertEquals("1.0.0", model.getVersion());
115+
assertEquals(FunctionName.KMEANS.name(), model.getName());
116+
assertEquals("1.0.0", model.getVersion());
104117
Assert.assertNotNull(model.getContent());
105118
}
106119

107120
@Test
108121
public void trainLinearRegression() {
109122
MLModel model = trainLinearRegressionModel();
110-
Assert.assertEquals(FunctionName.LINEAR_REGRESSION.name(), model.getName());
111-
Assert.assertEquals("1.0.0", model.getVersion());
123+
assertEquals(FunctionName.LINEAR_REGRESSION.name(), model.getName());
124+
assertEquals("1.0.0", model.getVersion());
112125
Assert.assertNotNull(model.getContent());
113126
}
114127

@@ -216,7 +229,7 @@ public void trainAndPredictWithKmeans() {
216229
MLInputDataset inputData = new DataFrameInputDataset(dataFrame);
217230
Input input = new MLInput(FunctionName.KMEANS, parameters, inputData);
218231
MLPredictionOutput output = (MLPredictionOutput) mlEngine.trainAndPredict(input);
219-
Assert.assertEquals(dataSize, output.getPredictionResult().size());
232+
assertEquals(dataSize, output.getPredictionResult().size());
220233
}
221234

222235
@Test
@@ -231,7 +244,7 @@ public void trainAndPredictWithInvalidInput() {
231244
public void executeLocalSampleCalculator() {
232245
Input input = new LocalSampleCalculatorInput("sum", Arrays.asList(1.0, 2.0));
233246
LocalSampleCalculatorOutput output = (LocalSampleCalculatorOutput) mlEngine.execute(input);
234-
Assert.assertEquals(3.0, output.getResult(), 1e-5);
247+
assertEquals(3.0, output.getResult(), 1e-5);
235248
}
236249

237250
@Test

plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -195,9 +195,7 @@ public void uploadMLModel(MLUploadInput uploadInput, MLTask mlTask) {
195195
if (uploadInput.getUrl() != null) {
196196
uploadModelFromUrl(uploadInput, mlTask);
197197
} else {
198-
throw new IllegalArgumentException("model file URL is null");
199-
// TODO: support prebuilt model later
200-
// uploadPrebuiltModel(uploadInput, mlTask);
198+
uploadPrebuiltModel(uploadInput, mlTask);
201199
}
202200
} catch (Exception e) {
203201
mlStats.createCounterStatIfAbsent(mlTask.getFunctionName(), UPLOAD, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment();

0 commit comments

Comments
 (0)