Skip to content

Commit ab98f23

Browse files
add more UT for ml-algorithms (#182) (#184)
Signed-off-by: Yaliang Wu <[email protected]> (cherry picked from commit 65a3f39) Co-authored-by: Yaliang Wu <[email protected]>
1 parent bf1d068 commit ab98f23

File tree

14 files changed

+507
-19
lines changed

14 files changed

+507
-19
lines changed

common/src/main/java/org/opensearch/ml/common/parameter/AnomalyDetectionParams.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ public class AnomalyDetectionParams implements MLAlgoParams {
4646
private Integer degree;
4747

4848

49-
@Builder
49+
@Builder(toBuilder = true)
5050
public AnomalyDetectionParams(ADKernelType kernelType, Double gamma, Double nu, Double cost, Double coeff, Double epsilon, Integer degree) {
5151
this.kernelType = kernelType;
5252
this.gamma = gamma;

common/src/main/java/org/opensearch/ml/common/parameter/KMeansParams.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public class KMeansParams implements MLAlgoParams {
4242
private DistanceType distanceType;
4343
//TODO: expose number of thread and seed?
4444

45-
@Builder
45+
@Builder(toBuilder = true)
4646
public KMeansParams(Integer centroids, Integer iterations, DistanceType distanceType) {
4747
this.centroids = centroids;
4848
this.iterations = iterations;

common/src/main/java/org/opensearch/ml/common/parameter/LinearRegressionParams.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ public class LinearRegressionParams implements MLAlgoParams {
5959
private Long seed;
6060
private String target;
6161

62-
@Builder
62+
@Builder(toBuilder = true)
6363
public LinearRegressionParams(ObjectiveType objectiveType, OptimizerType optimizerType, Double learningRate, MomentumType momentumType, Double momentumFactor, Double epsilon, Double beta1, Double beta2, Double decayRate, Integer epochs, Integer batchSize, Long seed, String target) {
6464
this.objectiveType = objectiveType;
6565
this.optimizerType = optimizerType;

ml-algorithms/build.gradle

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,11 @@ jacocoTestCoverageVerification {
4242
rule {
4343
limit {
4444
counter = 'LINE'
45-
minimum = 0.5 //TODO: add more test to meet the coverage bar 0.7
45+
minimum = 0.94
4646
}
4747
limit {
4848
counter = 'BRANCH'
49-
minimum = 0.5 //TODO: add more test to meet the coverage bar 0.7
49+
minimum = 0.87
5050
}
5151
}
5252
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sample/LocalSampleCalculator.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,9 @@
66
package org.opensearch.ml.engine.algorithms.sample;
77

88
import lombok.Data;
9-
import lombok.Getter;
109
import lombok.NoArgsConstructor;
11-
import lombok.Setter;
1210
import org.opensearch.client.Client;
13-
import org.opensearch.common.inject.Inject;
1411
import org.opensearch.common.settings.Settings;
15-
import org.opensearch.ml.common.dataframe.DataFrame;
1612
import org.opensearch.ml.common.parameter.FunctionName;
1713
import org.opensearch.ml.common.parameter.Input;
1814
import org.opensearch.ml.common.parameter.LocalSampleCalculatorInput;
@@ -57,7 +53,7 @@ public Output execute(Input input) {
5753
double min = inputData.stream().min(Comparator.naturalOrder()).get();
5854
return new SampleAlgoOutput(min);
5955
default:
60-
throw new IllegalArgumentException("can't support this operation " + operation);
56+
throw new IllegalArgumentException("can't support this operation");
6157
}
6258
}
6359
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sample/SampleAlgo.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ public SampleAlgo(MLAlgoParams parameters) {
3434
@Override
3535
public MLOutput predict(DataFrame dataFrame, Model model) {
3636
if (model == null) {
37-
throw new IllegalArgumentException("No model found for KMeans prediction.");
37+
throw new IllegalArgumentException("No model found for sample algo.");
3838
}
3939
AtomicReference<Double> sum = new AtomicReference<>((double) 0);
4040
dataFrame.forEach(row -> {

ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,26 @@
1111
import org.junit.rules.ExpectedException;
1212
import org.mockito.MockedStatic;
1313
import org.mockito.Mockito;
14+
import org.opensearch.common.io.stream.StreamOutput;
15+
import org.opensearch.common.xcontent.XContentBuilder;
1416
import org.opensearch.ml.common.dataframe.DataFrame;
1517
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
1618
import org.opensearch.ml.common.dataset.MLInputDataset;
1719
import org.opensearch.ml.common.parameter.Input;
1820
import org.opensearch.ml.common.parameter.KMeansParams;
1921
import org.opensearch.ml.common.parameter.LinearRegressionParams;
2022
import org.opensearch.ml.common.parameter.FunctionName;
23+
import org.opensearch.ml.common.parameter.LocalSampleCalculatorInput;
2124
import org.opensearch.ml.common.parameter.MLAlgoParams;
2225
import org.opensearch.ml.common.parameter.MLInput;
2326
import org.opensearch.ml.common.parameter.Model;
2427
import org.opensearch.ml.common.parameter.MLPredictionOutput;
28+
import org.opensearch.ml.common.parameter.SampleAlgoOutput;
2529

2630

31+
import java.io.IOException;
32+
import java.util.Arrays;
33+
2734
import static org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame;
2835
import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame;
2936
import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame;
@@ -155,6 +162,55 @@ public void predictUnsupportedAlgorithm() {
155162
}
156163
}
157164

165+
@Test
166+
public void trainAndPredictWithKmeans() {
167+
int dataSize = 100;
168+
MLAlgoParams parameters = KMeansParams.builder().build();
169+
DataFrame dataFrame = constructKMeansDataFrame(dataSize);
170+
MLInputDataset inputData = new DataFrameInputDataset(dataFrame);
171+
Input input = new MLInput(FunctionName.KMEANS, parameters, inputData);
172+
MLPredictionOutput output = (MLPredictionOutput) MLEngine.trainAndPredict(input);
173+
Assert.assertEquals(dataSize, output.getPredictionResult().size());
174+
}
175+
176+
@Test
177+
public void trainAndPredictWithInvalidInput() {
178+
exceptionRule.expect(IllegalArgumentException.class);
179+
exceptionRule.expectMessage("Input should be MLInput");
180+
Input input = new LocalSampleCalculatorInput("sum", Arrays.asList(1.0, 2.0));
181+
MLEngine.trainAndPredict(input);
182+
}
183+
184+
@Test
185+
public void executeLocalSampleCalculator() {
186+
Input input = new LocalSampleCalculatorInput("sum", Arrays.asList(1.0, 2.0));
187+
SampleAlgoOutput output = (SampleAlgoOutput) MLEngine.execute(input);
188+
Assert.assertEquals(3.0, output.getSampleResult(), 1e-5);
189+
}
190+
191+
@Test
192+
public void executeWithInvalidInput() {
193+
exceptionRule.expect(IllegalArgumentException.class);
194+
exceptionRule.expectMessage("Function name should not be null");
195+
Input input = new Input() {
196+
@Override
197+
public FunctionName getFunctionName() {
198+
return null;
199+
}
200+
201+
@Override
202+
public void writeTo(StreamOutput streamOutput) throws IOException {
203+
204+
}
205+
206+
@Override
207+
public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
208+
return null;
209+
}
210+
};
211+
MLEngine.execute(input);
212+
}
213+
158214
private Model trainKMeansModel() {
159215
KMeansParams parameters = KMeansParams.builder()
160216
.centroids(2)

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/ad/AnomalyDetectionLibSVMTest.java

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,31 @@ public void train() {
107107
Assert.assertNotNull(model.getContent());
108108
}
109109

110+
@Test
111+
public void trainWithFullParams() {
112+
AnomalyDetectionParams parameters = AnomalyDetectionParams.builder().gamma(gamma).nu(nu).cost(1.0).coeff(0.01).epsilon(0.001).degree(1).kernelType(AnomalyDetectionParams.ADKernelType.LINEAR).build();
113+
AnomalyDetectionLibSVM anomalyDetection = new AnomalyDetectionLibSVM(parameters);
114+
Model model = anomalyDetection.train(trainDataFrame);
115+
Assert.assertEquals(FunctionName.AD_LIBSVM.name(), model.getName());
116+
Assert.assertEquals(AnomalyDetectionLibSVM.VERSION, model.getVersion());
117+
Assert.assertNotNull(model.getContent());
118+
119+
parameters = parameters.toBuilder().kernelType(AnomalyDetectionParams.ADKernelType.POLY).build();
120+
anomalyDetection = new AnomalyDetectionLibSVM(parameters);
121+
model = anomalyDetection.train(trainDataFrame);
122+
Assert.assertEquals(FunctionName.AD_LIBSVM.name(), model.getName());
123+
124+
parameters = parameters.toBuilder().kernelType(AnomalyDetectionParams.ADKernelType.RBF).build();
125+
anomalyDetection = new AnomalyDetectionLibSVM(parameters);
126+
model = anomalyDetection.train(trainDataFrame);
127+
Assert.assertEquals(FunctionName.AD_LIBSVM.name(), model.getName());
128+
129+
parameters = parameters.toBuilder().kernelType(AnomalyDetectionParams.ADKernelType.SIGMOID).build();
130+
anomalyDetection = new AnomalyDetectionLibSVM(parameters);
131+
model = anomalyDetection.train(trainDataFrame);
132+
Assert.assertEquals(FunctionName.AD_LIBSVM.name(), model.getName());
133+
}
134+
110135
@Test
111136
public void predict_NullModel() {
112137
exceptionRule.expect(IllegalArgumentException.class);
@@ -141,4 +166,19 @@ public void predict() {
141166
Assert.assertEquals(1.0, recall, 0.01);
142167
}
143168

169+
@Test
170+
public void constructor_NegativeGamma() {
171+
exceptionRule.expect(IllegalArgumentException.class);
172+
exceptionRule.expectMessage("gamma should be positive");
173+
AnomalyDetectionParams parameters = AnomalyDetectionParams.builder().gamma(-1.0).build();
174+
new AnomalyDetectionLibSVM(parameters);
175+
}
176+
177+
@Test
178+
public void constructor_NegativeNu() {
179+
exceptionRule.expect(IllegalArgumentException.class);
180+
exceptionRule.expectMessage("nu should be positive");
181+
AnomalyDetectionParams parameters = AnomalyDetectionParams.builder().nu(-1.0).build();
182+
new AnomalyDetectionLibSVM(parameters);
183+
}
144184
}

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/clustering/KMeansTest.java

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,21 @@
77

88
import org.junit.Assert;
99
import org.junit.Before;
10+
import org.junit.Rule;
1011
import org.junit.Test;
12+
import org.junit.rules.ExpectedException;
1113
import org.opensearch.ml.common.dataframe.DataFrame;
1214
import org.opensearch.ml.common.parameter.KMeansParams;
1315
import org.opensearch.ml.common.parameter.FunctionName;
1416
import org.opensearch.ml.common.parameter.MLPredictionOutput;
1517
import org.opensearch.ml.common.parameter.Model;
16-
import org.opensearch.ml.engine.algorithms.clustering.KMeans;
1718

1819
import static org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame;
1920

2021

2122
public class KMeansTest {
23+
@Rule
24+
public ExpectedException exceptionRule = ExpectedException.none();
2225
private KMeansParams parameters;
2326
private KMeans kMeans;
2427
private DataFrame trainDataFrame;
@@ -48,6 +51,13 @@ public void predict() {
4851
predictions.forEach(row -> Assert.assertTrue(row.getValue(0).intValue() == 0 || row.getValue(0).intValue() == 1));
4952
}
5053

54+
@Test
55+
public void predictWithNullModel() {
56+
exceptionRule.expect(IllegalArgumentException.class);
57+
exceptionRule.expectMessage("No model found for KMeans prediction");
58+
kMeans.predict(predictionDataFrame, null);
59+
}
60+
5161
@Test
5262
public void train() {
5363
Model model = kMeans.train(trainDataFrame);
@@ -56,6 +66,46 @@ public void train() {
5666
Assert.assertNotNull(model.getContent());
5767
}
5868

69+
@Test
70+
public void trainAndPredict() {
71+
KMeansParams parameters = KMeansParams.builder()
72+
.distanceType(KMeansParams.DistanceType.EUCLIDEAN)
73+
.iterations(10)
74+
.centroids(2)
75+
.build();
76+
KMeans kMeans = new KMeans(parameters);
77+
MLPredictionOutput output = (MLPredictionOutput) kMeans.trainAndPredict(trainDataFrame);
78+
DataFrame predictions = output.getPredictionResult();
79+
Assert.assertEquals(trainSize, predictions.size());
80+
predictions.forEach(row -> Assert.assertTrue(row.getValue(0).intValue() == 0 || row.getValue(0).intValue() == 1));
81+
82+
parameters = parameters.toBuilder().distanceType(KMeansParams.DistanceType.COSINE).build();
83+
kMeans = new KMeans(parameters);
84+
output = (MLPredictionOutput) kMeans.trainAndPredict(trainDataFrame);
85+
predictions = output.getPredictionResult();
86+
Assert.assertEquals(trainSize, predictions.size());
87+
88+
parameters = parameters.toBuilder().distanceType(KMeansParams.DistanceType.L1).build();
89+
kMeans = new KMeans(parameters);
90+
output = (MLPredictionOutput) kMeans.trainAndPredict(trainDataFrame);
91+
predictions = output.getPredictionResult();
92+
Assert.assertEquals(trainSize, predictions.size());
93+
}
94+
95+
@Test
96+
public void constructorWithNegtiveCentroids() {
97+
exceptionRule.expect(IllegalArgumentException.class);
98+
exceptionRule.expectMessage("K should be positive");
99+
new KMeans(KMeansParams.builder().centroids(-1).build());
100+
}
101+
102+
@Test
103+
public void constructorWithNegtiveIterations() {
104+
exceptionRule.expect(IllegalArgumentException.class);
105+
exceptionRule.expectMessage("Iterations should be positive");
106+
new KMeans(KMeansParams.builder().iterations(-1).build());
107+
}
108+
59109
private void constructKMeansPredictionDataFrame() {
60110
predictionDataFrame = constructKMeansDataFrame(predictionSize);
61111
}

0 commit comments

Comments
 (0)