Skip to content

Commit 552ad06

Browse files
more strict check on input parameters by applying non-coerce mode (#173) (#175)
Signed-off-by: Yaliang Wu <[email protected]> (cherry picked from commit c50e0d6) Co-authored-by: Yaliang Wu <[email protected]>
1 parent e6d71f7 commit 552ad06

File tree

12 files changed

+102
-38
lines changed

12 files changed

+102
-38
lines changed

common/src/main/java/org/opensearch/ml/common/parameter/AnomalyDetectionParams.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,22 +88,22 @@ public static MLAlgoParams parse(XContentParser parser) throws IOException {
8888
kernelType = ADKernelType.from(parser.text().toUpperCase(Locale.ROOT));
8989
break;
9090
case GAMMA_FIELD:
91-
gamma = parser.doubleValue();
91+
gamma = parser.doubleValue(false);
9292
break;
9393
case NU_FIELD:
94-
nu = parser.doubleValue();
94+
nu = parser.doubleValue(false);
9595
break;
9696
case COST_FIELD:
97-
cost = parser.doubleValue();
97+
cost = parser.doubleValue(false);
9898
break;
9999
case COEFF_FIELD:
100-
coeff = parser.doubleValue();
100+
coeff = parser.doubleValue(false);
101101
break;
102102
case EPSILON_FIELD:
103-
epsilon = parser.doubleValue();
103+
epsilon = parser.doubleValue(false);
104104
break;
105105
case DEGREE_FIELD:
106-
degree = parser.intValue();
106+
degree = parser.intValue(false);
107107
break;
108108
default:
109109
parser.skipChildren();

common/src/main/java/org/opensearch/ml/common/parameter/BatchRCFParams.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,22 +91,22 @@ public static BatchRCFParams parse(XContentParser parser) throws IOException {
9191

9292
switch (fieldName) {
9393
case NUMBER_OF_TREES:
94-
numberOfTrees = parser.intValue();
94+
numberOfTrees = parser.intValue(false);
9595
break;
9696
case SHINGLE_SIZE:
97-
shingleSize = parser.intValue();
97+
shingleSize = parser.intValue(false);
9898
break;
9999
case SAMPLE_SIZE:
100-
sampleSize = parser.intValue();
100+
sampleSize = parser.intValue(false);
101101
break;
102102
case OUTPUT_AFTER:
103-
outputAfter = parser.intValue();
103+
outputAfter = parser.intValue(false);
104104
break;
105105
case TRAINING_DATA_SIZE:
106-
trainingDataSize = parser.intValue();
106+
trainingDataSize = parser.intValue(false);
107107
break;
108108
case ANOMALY_SCORE_THRESHOLD:
109-
anomalyScoreThreshold = parser.doubleValue();
109+
anomalyScoreThreshold = parser.doubleValue(false);
110110
break;
111111
default:
112112
parser.skipChildren();

common/src/main/java/org/opensearch/ml/common/parameter/FitRCFParams.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -112,22 +112,22 @@ public static FitRCFParams parse(XContentParser parser) throws IOException {
112112

113113
switch (fieldName) {
114114
case NUMBER_OF_TREES:
115-
numberOfTrees = parser.intValue();
115+
numberOfTrees = parser.intValue(false);
116116
break;
117117
case SHINGLE_SIZE:
118-
shingleSize = parser.intValue();
118+
shingleSize = parser.intValue(false);
119119
break;
120120
case SAMPLE_SIZE:
121-
sampleSize = parser.intValue();
121+
sampleSize = parser.intValue(false);
122122
break;
123123
case OUTPUT_AFTER:
124-
outputAfter = parser.intValue();
124+
outputAfter = parser.intValue(false);
125125
break;
126126
case TIME_DECAY:
127-
timeDecay = parser.doubleValue();
127+
timeDecay = parser.doubleValue(false);
128128
break;
129129
case ANOMALY_RATE:
130-
anomalyRate = parser.doubleValue();
130+
anomalyRate = parser.doubleValue(false);
131131
break;
132132
case TIME_FIELD:
133133
timeField = parser.text();

common/src/main/java/org/opensearch/ml/common/parameter/KMeansParams.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import lombok.Builder;
99
import lombok.Data;
10-
import lombok.Getter;
1110
import org.opensearch.common.ParseField;
1211
import org.opensearch.common.io.stream.StreamInput;
1312
import org.opensearch.common.io.stream.StreamOutput;
@@ -70,10 +69,10 @@ public static MLAlgoParams parse(XContentParser parser) throws IOException {
7069

7170
switch (fieldName) {
7271
case CENTROIDS_FIELD:
73-
k = parser.intValue();
72+
k = parser.intValue(false);
7473
break;
7574
case ITERATIONS_FIELD:
76-
iterations = parser.intValue();
75+
iterations = parser.intValue(false);
7776
break;
7877
case DISTANCE_TYPE_FIELD:
7978
distanceType = DistanceType.from(parser.text());

common/src/main/java/org/opensearch/ml/common/parameter/LinearRegressionParams.java

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -126,34 +126,34 @@ public static MLAlgoParams parse(XContentParser parser) throws IOException {
126126
optimizerType = OptimizerType.valueOf(parser.text().toUpperCase(Locale.ROOT));
127127
break;
128128
case LEARNING_RATE_FIELD:
129-
learningRate = parser.doubleValue();
129+
learningRate = parser.doubleValue(false);
130130
break;
131131
case MOMENTUM_TYPE_FIELD:
132132
momentumType = MomentumType.valueOf(parser.text().toUpperCase(Locale.ROOT));
133133
break;
134134
case MOMENTUM_FACTOR_FIELD:
135-
momentumFactor = parser.doubleValue();
135+
momentumFactor = parser.doubleValue(false);
136136
break;
137137
case EPSILON_FIELD:
138-
epsilon = parser.doubleValue();
138+
epsilon = parser.doubleValue(false);
139139
break;
140140
case BETA1_FIELD:
141-
beta1 = parser.doubleValue();
141+
beta1 = parser.doubleValue(false);
142142
break;
143143
case BETA2_FIELD:
144-
beta2 = parser.doubleValue();
144+
beta2 = parser.doubleValue(false);
145145
break;
146146
case DECAY_RATE_FIELD:
147-
decayRate = parser.doubleValue();
147+
decayRate = parser.doubleValue(false);
148148
break;
149149
case EPOCHS_FIELD:
150-
epochs = parser.intValue();
150+
epochs = parser.intValue(false);
151151
break;
152152
case BATCH_SIZE_FIELD:
153-
batchSize = parser.intValue();
153+
batchSize = parser.intValue(false);
154154
break;
155155
case SEED_FIELD:
156-
seed = parser.longValue();
156+
seed = parser.longValue(false);
157157
break;
158158
case TARGET_FIELD:
159159
target = parser.text();

common/src/main/java/org/opensearch/ml/common/parameter/LocalSampleCalculatorInput.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ public static LocalSampleCalculatorInput parse(XContentParser parser) throws IOE
4848
case INPUT_DATA_FIELD:
4949
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
5050
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
51-
inputData.add(parser.doubleValue());
51+
inputData.add(parser.doubleValue(false));
5252
}
5353
break;
5454
default:

common/src/main/java/org/opensearch/ml/common/parameter/MLModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ public static MLModel parse(XContentParser parser) throws IOException {
113113
content = parser.text();
114114
break;
115115
case MODEL_VERSION:
116-
version = parser.intValue();
116+
version = parser.intValue(false);
117117
break;
118118
case USER:
119119
user = User.parse(parser);

common/src/main/java/org/opensearch/ml/common/parameter/SampleAlgoParams.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ public static MLAlgoParams parse(XContentParser parser) throws IOException {
5151

5252
switch (fieldName) {
5353
case SAMPLE_PARAM_FIELD:
54-
sampleParam = parser.intValue();
54+
sampleParam = parser.intValue(false);
5555
break;
5656
default:
5757
parser.skipChildren();

common/src/test/java/org/opensearch/ml/common/TestHelper.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ public static <T> void testParseFromString(ToXContentObject obj, String jsonStr,
4545
obj.equals(parsedObj);
4646
}
4747

48+
public static String contentObjectToString(ToXContentObject obj) throws IOException {
49+
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
50+
obj.toXContent(builder, ToXContent.EMPTY_PARAMS);
51+
return xContentBuilderToString(builder);
52+
}
53+
4854
public static String xContentBuilderToString(XContentBuilder builder) {
4955
return BytesReference.bytes(builder).utf8ToString();
5056
}

common/src/test/java/org/opensearch/ml/common/parameter/KMeansParamsTest.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
package org.opensearch.ml.common.parameter;
77

88
import org.junit.Before;
9+
import org.junit.Rule;
910
import org.junit.Test;
11+
import org.junit.rules.ExpectedException;
1012
import org.opensearch.common.io.stream.BytesStreamOutput;
1113
import org.opensearch.common.io.stream.StreamInput;
1214
import org.opensearch.common.xcontent.XContentParser;
@@ -16,8 +18,12 @@
1618
import java.util.function.Function;
1719

1820
import static org.junit.Assert.assertEquals;
21+
import static org.opensearch.ml.common.TestHelper.contentObjectToString;
22+
import static org.opensearch.ml.common.TestHelper.testParseFromString;
1923

2024
public class KMeansParamsTest {
25+
@Rule
26+
public ExpectedException exceptionRule = ExpectedException.none();
2127

2228
KMeansParams params;
2329
private Function<XContentParser, KMeansParams> function = parser -> {
@@ -42,6 +48,22 @@ public void parse_KMeansParams() throws IOException {
4248
TestHelper.testParse(params, function);
4349
}
4450

51+
@Test
52+
public void parse_KMeansParams_InvalidDoubleValue() throws IOException {
53+
exceptionRule.expect(IllegalArgumentException.class);
54+
exceptionRule.expectMessage("10.01 cannot be converted to Integer without data loss");
55+
String paramsStr = contentObjectToString(params);
56+
testParseFromString(params, paramsStr.replace("\"iterations\":10,", "\"iterations\":10.01,"), function);
57+
}
58+
59+
@Test
60+
public void parse_KMeansParams_InvalidDoubleString() throws IOException {
61+
exceptionRule.expect(IllegalArgumentException.class);
62+
exceptionRule.expectMessage("Integer value passed as String");
63+
String paramsStr = contentObjectToString(params);
64+
testParseFromString(params, paramsStr.replace("\"iterations\":10,", "\"iterations\":\"10.01\","), function);
65+
}
66+
4567
@Test
4668
public void parse_EmptyKMeansParams() throws IOException {
4769
TestHelper.testParse(KMeansParams.builder().build(), function);

0 commit comments

Comments
 (0)