Skip to content

Commit 0bc8740

Browse files
authored
fix verbose error message thrown by invalid enum (#167)
Signed-off-by: Yaliang Wu <[email protected]>
1 parent 313c448 commit 0bc8740

File tree

15 files changed

+139
-107
lines changed

15 files changed

+139
-107
lines changed

client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,11 @@ public class MachineLearningNodeClientTest {
6060
public ExpectedException exceptionRule = ExpectedException.none();
6161

6262
@Before
63-
public void setUp() throws Exception {
63+
public void setUp() {
6464
MockitoAnnotations.openMocks(this);
6565
}
6666

67+
@SuppressWarnings("unchecked")
6768
@Test
6869
public void predict() {
6970
doAnswer(invocation -> {
@@ -112,6 +113,7 @@ public void predict_Exception_WithNullDataSet() {
112113
machineLearningNodeClient.predict(null, mlInput, dataFrameActionListener);
113114
}
114115

116+
@SuppressWarnings("unchecked")
115117
@Test
116118
public void train() {
117119
String modelId = "test_model_id";

common/src/main/java/org/opensearch/ml/common/parameter/AnomalyDetectionParams.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ public static MLAlgoParams parse(XContentParser parser) throws IOException {
8585

8686
switch (fieldName) {
8787
case KERNEL_FIELD:
88-
kernelType = ADKernelType.valueOf(parser.text().toUpperCase(Locale.ROOT));
88+
kernelType = ADKernelType.from(parser.text().toUpperCase(Locale.ROOT));
8989
break;
9090
case GAMMA_FIELD:
9191
gamma = parser.doubleValue();
@@ -171,6 +171,14 @@ public enum ADKernelType {
171171
LINEAR,
172172
POLY,
173173
RBF,
174-
SIGMOID
174+
SIGMOID;
175+
176+
public static ADKernelType from(String value) {
177+
try{
178+
return ADKernelType.valueOf(value);
179+
} catch (Exception e) {
180+
throw new IllegalArgumentException("Wrong AD kernel type");
181+
}
182+
}
175183
}
176184
}

common/src/main/java/org/opensearch/ml/common/parameter/FunctionName.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,13 @@ public enum FunctionName {
1313
LOCAL_SAMPLE_CALCULATOR,
1414
ANOMALY_LOCALIZATION,
1515
FIT_RCF,
16-
BATCH_RCF
16+
BATCH_RCF;
17+
18+
public static FunctionName from(String value) {
19+
try {
20+
return FunctionName.valueOf(value);
21+
} catch (Exception e) {
22+
throw new IllegalArgumentException("Wrong function name");
23+
}
24+
}
1725
}

common/src/main/java/org/opensearch/ml/common/parameter/KMeansParams.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ public static MLAlgoParams parse(XContentParser parser) throws IOException {
7676
iterations = parser.intValue();
7777
break;
7878
case DISTANCE_TYPE_FIELD:
79-
distanceType = DistanceType.valueOf(parser.text());
79+
distanceType = DistanceType.from(parser.text());
8080
break;
8181
default:
8282
parser.skipChildren();
@@ -127,6 +127,14 @@ public int getVersion() {
127127
public enum DistanceType {
128128
EUCLIDEAN,
129129
COSINE,
130-
L1
130+
L1;
131+
132+
public static DistanceType from(String value) {
133+
try {
134+
return DistanceType.valueOf(value);
135+
} catch (Exception e) {
136+
throw new IllegalArgumentException("Wrong distance type");
137+
}
138+
}
131139
}
132140
}

common/src/main/java/org/opensearch/ml/common/parameter/LinearRegressionParams.java

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -257,12 +257,27 @@ public int getVersion() {
257257
public enum ObjectiveType {
258258
SQUARED_LOSS,
259259
ABSOLUTE_LOSS,
260-
HUBER
260+
HUBER;
261+
public static ObjectiveType from(String value) {
262+
try{
263+
return ObjectiveType.valueOf(value);
264+
} catch (Exception e) {
265+
throw new IllegalArgumentException("Wrong objective type");
266+
}
267+
}
261268
}
262269

263270
public enum MomentumType {
264271
STANDARD,
265-
NESTEROV
272+
NESTEROV;
273+
274+
public static MomentumType from(String value) {
275+
try{
276+
return MomentumType.valueOf(value);
277+
} catch (Exception e) {
278+
throw new IllegalArgumentException("Wrong momentum type");
279+
}
280+
}
266281
}
267282

268283
public enum OptimizerType {
@@ -272,6 +287,14 @@ public enum OptimizerType {
272287
ADA_GRAD,
273288
ADA_DELTA,
274289
ADAM,
275-
RMS_PROP
290+
RMS_PROP;
291+
292+
public static OptimizerType from(String value) {
293+
try{
294+
return OptimizerType.valueOf(value);
295+
} catch (Exception e) {
296+
throw new IllegalArgumentException("Wrong optimizer type");
297+
}
298+
}
276299
}
277300
}

common/src/main/java/org/opensearch/ml/common/parameter/MLInput.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
132132

133133
public static MLInput parse(XContentParser parser, String inputAlgoName) throws IOException {
134134
String algorithmName = inputAlgoName.toUpperCase(Locale.ROOT);
135-
FunctionName algorithm = FunctionName.valueOf(algorithmName);
135+
FunctionName algorithm = FunctionName.from(algorithmName);
136136
MLAlgoParams mlParameters = null;
137137
SearchSourceBuilder searchSourceBuilder = null;
138138
List<String> sourceIndices = new ArrayList<>();

common/src/main/java/org/opensearch/ml/common/parameter/MLModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ public static MLModel parse(XContentParser parser) throws IOException {
119119
user = User.parse(parser);
120120
break;
121121
case ALGORITHM:
122-
algorithm = FunctionName.valueOf(parser.text());
122+
algorithm = FunctionName.from(parser.text());
123123
break;
124124
default:
125125
parser.skipChildren();

common/src/main/java/org/opensearch/ml/common/parameter/MLTask.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ public static MLTask parse(XContentParser parser) throws IOException {
220220
taskType = MLTaskType.valueOf(parser.text());
221221
break;
222222
case FUNCTION_NAME_FIELD:
223-
functionName = FunctionName.valueOf(parser.text());
223+
functionName = FunctionName.from(parser.text());
224224
break;
225225
case STATE_FIELD:
226226
state = MLTaskState.valueOf(parser.text());

common/src/test/java/org/opensearch/ml/common/dataframe/DataFrameBuilderTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ public void load_Exception_EmptyInputMapList() {
6161

6262
@Test(expected = IllegalArgumentException.class)
6363
public void load_Exception_NullInputMapList() {
64-
DataFrameBuilder.load((List) null);
64+
DataFrameBuilder.load((List<Map<String, Object>>) null);
6565
}
6666

6767
@Test

ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/TribuoUtil.java

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import lombok.experimental.UtilityClass;
99
import org.apache.commons.lang3.StringUtils;
1010
import org.opensearch.common.collect.Tuple;
11+
import org.opensearch.ml.common.dataframe.ColumnMeta;
12+
import org.opensearch.ml.common.dataframe.ColumnValue;
1113
import org.opensearch.ml.common.dataframe.DataFrame;
1214
import org.opensearch.ml.common.dataframe.Row;
1315
import org.opensearch.ml.engine.contants.TribuoOutputType;
@@ -31,18 +33,18 @@
3133

3234
@UtilityClass
3335
public class TribuoUtil {
34-
public static Tuple transformDataFrame(DataFrame dataFrame) {
35-
String[] featureNames = Arrays.stream(dataFrame.columnMetas()).map(e -> e.getName()).toArray(String[]::new);
36+
public static Tuple<String[], double[][]> transformDataFrame(DataFrame dataFrame) {
37+
String[] featureNames = Arrays.stream(dataFrame.columnMetas()).map(ColumnMeta::getName).toArray(String[]::new);
3638
double[][] featureValues = new double[dataFrame.size()][];
3739
Iterator<Row> itr = dataFrame.iterator();
3840
int i = 0;
3941
while (itr.hasNext()) {
4042
Row row = itr.next();
41-
featureValues[i] = StreamSupport.stream(row.spliterator(), false).mapToDouble(e -> e.doubleValue()).toArray();
43+
featureValues[i] = StreamSupport.stream(row.spliterator(), false).mapToDouble(ColumnValue::doubleValue).toArray();
4244
++i;
4345
}
4446

45-
return new Tuple(featureNames, featureValues);
47+
return new Tuple<>(featureNames, featureValues);
4648
}
4749

4850
/**
@@ -60,19 +62,19 @@ public static <T extends Output<T>> MutableDataset<T> generateDataset(DataFrame
6062
for (int i=0; i<dataFrame.size(); ++i) {
6163
switch (outputType) {
6264
case CLUSTERID:
63-
example = new ArrayExample<T>((T) new ClusterID(ClusterID.UNASSIGNED), featureNamesValues.v1(), featureNamesValues.v2()[i]);
65+
example = new ArrayExample<>((T) new ClusterID(ClusterID.UNASSIGNED), featureNamesValues.v1(), featureNamesValues.v2()[i]);
6466
break;
6567
case REGRESSOR:
6668
//Create single dimension tribuo regressor with name DIM-0 and value double NaN.
67-
example = new ArrayExample<T>((T) new Regressor("DIM-0", Double.NaN), featureNamesValues.v1(), featureNamesValues.v2()[i]);
69+
example = new ArrayExample<>((T) new Regressor("DIM-0", Double.NaN), featureNamesValues.v1(), featureNamesValues.v2()[i]);
6870
break;
6971
case ANOMALY_DETECTION_LIBSVM:
7072
// Why we set default event type as EXPECTED(non-anomalous)
7173
// 1. For training data, Tribuo LibSVMAnomalyTrainer only supports EXPECTED events at training time.
7274
// 2. For prediction data, we treat the data as non-anomalous by default as Tribuo lib don't accept UNKNOWN type.
7375
Event.EventType defaultEventType = Event.EventType.EXPECTED;
7476
// TODO: support anomaly labels to evaluate prediction result
75-
example = new ArrayExample<T>((T) new Event(defaultEventType), featureNamesValues.v1(), featureNamesValues.v2()[i]);
77+
example = new ArrayExample<>((T) new Event(defaultEventType), featureNamesValues.v1(), featureNamesValues.v2()[i]);
7678
break;
7779
default:
7880
throw new IllegalArgumentException("unknown type:" + outputType);
@@ -127,7 +129,7 @@ public static <T extends Output<T>> MutableDataset<T> generateDatasetWithTarget(
127129
filter(e -> e != finalTargetIndex).
128130
mapToDouble(e -> featureNamesValues.v2()[finalI][e]).
129131
toArray();
130-
example = new ArrayExample<T>((T) new Regressor(target, targetValue), featureNames, featureValues);
132+
example = new ArrayExample<>((T) new Regressor(target, targetValue), featureNames, featureValues);
131133
break;
132134
default:
133135
throw new IllegalArgumentException("unknown type:" + outputType);

0 commit comments

Comments
 (0)