2525import org .opensearch .ml .common .input .parameter .regression .LinearRegressionParams ;
2626import org .opensearch .ml .common .FunctionName ;
2727import org .opensearch .ml .common .input .execute .samplecalculator .LocalSampleCalculatorInput ;
28+ import org .opensearch .ml .common .model .MLModelFormat ;
2829import org .opensearch .ml .common .output .execute .samplecalculator .LocalSampleCalculatorOutput ;
2930import org .opensearch .ml .common .input .parameter .MLAlgoParams ;
3031import org .opensearch .ml .common .input .MLInput ;
3637import java .util .Arrays ;
3738import java .util .UUID ;
3839
40+ import static org .junit .Assert .assertEquals ;
3941import static org .opensearch .ml .engine .helper .LinearRegressionHelper .constructLinearRegressionPredictionDataFrame ;
4042import static org .opensearch .ml .engine .helper .LinearRegressionHelper .constructLinearRegressionTrainDataFrame ;
4143import static org .opensearch .ml .engine .helper .MLTestHelper .constructTestDataFrame ;
@@ -51,6 +53,17 @@ public void setUp() {
5153 mlEngine = new MLEngine (Path .of ("/tmp/test" + UUID .randomUUID ()));
5254 }
5355
56+ @ Test
57+ public void testPrebuiltModelPath () {
58+ String modelName = "huggingface/sentence-transformers/msmarco-distilbert-base-tas-b" ;
59+ String version = "1.0.1" ;
60+ MLModelFormat modelFormat = MLModelFormat .TORCH_SCRIPT ;
61+ String prebuiltModelPath = mlEngine .getPrebuiltModelPath (modelName , version , modelFormat );
62+ String prebuiltModelConfigPath = mlEngine .getPrebuiltModelConfigPath (modelName , version , modelFormat );
63+ assertEquals ("https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/msmarco-distilbert-base-tas-b/1.0.1/torch_script/sentence-transformers_msmarco-distilbert-base-tas-b-1.0.1-torch_script.zip" , prebuiltModelPath );
64+ assertEquals ("https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/msmarco-distilbert-base-tas-b/1.0.1/torch_script/config.json" , prebuiltModelConfigPath );
65+ }
66+
5467 @ Test
5568 public void predictKMeans () {
5669 MLModel model = trainKMeansModel ();
@@ -59,7 +72,7 @@ public void predictKMeans() {
5972 Input mlInput = MLInput .builder ().algorithm (FunctionName .KMEANS ).inputDataset (inputDataset ).build ();
6073 MLPredictionOutput output = (MLPredictionOutput )mlEngine .predict (mlInput , model );
6174 DataFrame predictions = output .getPredictionResult ();
62- Assert . assertEquals (10 , predictions .size ());
75+ assertEquals (10 , predictions .size ());
6376 predictions .forEach (row -> Assert .assertTrue (row .getValue (0 ).intValue () == 0 || row .getValue (0 ).intValue () == 1 ));
6477 }
6578
@@ -71,7 +84,7 @@ public void predictLinearRegression() {
7184 Input mlInput = MLInput .builder ().algorithm (FunctionName .LINEAR_REGRESSION ).inputDataset (inputDataset ).build ();
7285 MLPredictionOutput output = (MLPredictionOutput )mlEngine .predict (mlInput , model );
7386 DataFrame predictions = output .getPredictionResult ();
74- Assert . assertEquals (2 , predictions .size ());
87+ assertEquals (2 , predictions .size ());
7588 }
7689
7790
@@ -83,7 +96,7 @@ public void loadLinearRegressionModel() {
8396 MLInputDataset inputDataset = DataFrameInputDataset .builder ().dataFrame (predictionDataFrame ).build ();
8497 MLPredictionOutput output = (MLPredictionOutput )predictor .predict (MLInput .builder ().algorithm (FunctionName .LINEAR_REGRESSION ).inputDataset (inputDataset ).build ());
8598 DataFrame predictions = output .getPredictionResult ();
86- Assert . assertEquals (2 , predictions .size ());
99+ assertEquals (2 , predictions .size ());
87100 }
88101
89102 @ Test
@@ -99,16 +112,16 @@ public void loadLinearRegressionModel_NullModel() {
99112 @ Test
100113 public void trainKMeans () {
101114 MLModel model = trainKMeansModel ();
102- Assert . assertEquals (FunctionName .KMEANS .name (), model .getName ());
103- Assert . assertEquals ("1.0.0" , model .getVersion ());
115+ assertEquals (FunctionName .KMEANS .name (), model .getName ());
116+ assertEquals ("1.0.0" , model .getVersion ());
104117 Assert .assertNotNull (model .getContent ());
105118 }
106119
107120 @ Test
108121 public void trainLinearRegression () {
109122 MLModel model = trainLinearRegressionModel ();
110- Assert . assertEquals (FunctionName .LINEAR_REGRESSION .name (), model .getName ());
111- Assert . assertEquals ("1.0.0" , model .getVersion ());
123+ assertEquals (FunctionName .LINEAR_REGRESSION .name (), model .getName ());
124+ assertEquals ("1.0.0" , model .getVersion ());
112125 Assert .assertNotNull (model .getContent ());
113126 }
114127
@@ -216,7 +229,7 @@ public void trainAndPredictWithKmeans() {
216229 MLInputDataset inputData = new DataFrameInputDataset (dataFrame );
217230 Input input = new MLInput (FunctionName .KMEANS , parameters , inputData );
218231 MLPredictionOutput output = (MLPredictionOutput ) mlEngine .trainAndPredict (input );
219- Assert . assertEquals (dataSize , output .getPredictionResult ().size ());
232+ assertEquals (dataSize , output .getPredictionResult ().size ());
220233 }
221234
222235 @ Test
@@ -231,7 +244,7 @@ public void trainAndPredictWithInvalidInput() {
231244 public void executeLocalSampleCalculator () {
232245 Input input = new LocalSampleCalculatorInput ("sum" , Arrays .asList (1.0 , 2.0 ));
233246 LocalSampleCalculatorOutput output = (LocalSampleCalculatorOutput ) mlEngine .execute (input );
234- Assert . assertEquals (3.0 , output .getResult (), 1e-5 );
247+ assertEquals (3.0 , output .getResult (), 1e-5 );
235248 }
236249
237250 @ Test
0 commit comments