Skip to content

Commit 3e33f0d

Browse files
Add an array filter to our serialize/deserialize methods and narrow down previous filter (#849) (#857)
* Add an array filter to our serialize/deserialize methods and narrow down previous filter
  Signed-off-by: Sicheng Song <[email protected]>
* Further narrowing down accept list
  Signed-off-by: Sicheng Song <[email protected]>
* Keep narrowing down accept list
  Signed-off-by: Sicheng Song <[email protected]>
* Add test for deserialization methods in all built-in models
  Signed-off-by: Sicheng Song <[email protected]>
---------
Signed-off-by: Sicheng Song <[email protected]>
(cherry picked from commit 0997d6c)
Co-authored-by: Sicheng Song <[email protected]>
1 parent c5beb06 commit 3e33f0d

File tree

5 files changed

+76
-6
lines changed

5 files changed

+76
-6
lines changed

ml-algorithms/build.gradle

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ repositories {
1818
dependencies {
1919
compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
2020
implementation project(':opensearch-ml-common')
21+
implementation "org.opensearch.client:opensearch-rest-client:${opensearch_version}"
22+
implementation "org.opensearch:common-utils:${common_utils_version}"
2123
implementation group: 'org.apache.commons', name: 'commons-text', version: '1.10.0'
2224
implementation group: 'org.reflections', name: 'reflections', version: '0.9.12'
2325
implementation group: 'org.tribuo', name: 'tribuo-clustering-kmeans', version: '4.2.1'
@@ -35,6 +37,7 @@ dependencies {
3537
testImplementation group: 'org.mockito', name: 'mockito-core', version: '4.4.0'
3638
testImplementation group: 'org.mockito', name: 'mockito-inline', version: '4.4.0'
3739
implementation group: 'com.google.guava', name: 'guava', version: '31.0.1-jre'
40+
implementation group: 'com.google.code.gson', name: 'gson', version: '2.9.1'
3841
implementation platform("ai.djl:bom:0.19.0")
3942
implementation group: 'ai.djl.pytorch', name: 'pytorch-model-zoo'
4043
implementation group: 'ai.djl', name: 'api'

ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ModelSerDeSer.java

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,56 @@
1414
import java.io.ByteArrayOutputStream;
1515
import java.io.IOException;
1616
import java.io.ObjectOutputStream;
17+
import java.io.ObjectInputFilter;
1718
import java.util.Base64;
1819

1920
@UtilityClass
2021
public class ModelSerDeSer {
21-
// Welcome list includes OpenSearch ml plugin classes, JDK common classes and Tribuo libraries.
22+
// Accept list covers OpenSearch ML plugin classes, common JDK classes, the Tribuo, OLCUT and libsvm libraries, and arrays of primitives, java.lang, Tribuo and libsvm types.
2223
public static final String[] ACCEPT_CLASS_PATTERNS = {
2324
"java.lang.*",
2425
"java.util.*",
2526
"java.time.*",
27+
"org.tribuo.*",
28+
"com.oracle.labs.mlrg.olcut.provenance.*",
29+
"com.oracle.labs.mlrg.olcut.util.*",
30+
"[I",
31+
"[Z",
32+
"[J",
33+
"[C",
34+
"[D",
35+
"[F",
36+
"[Ljava.lang.*",
37+
"[Lorg.tribuo.*",
38+
"[Llibsvm.*",
39+
"[[I",
40+
"[[Z",
41+
"[[J",
42+
"[[C",
43+
"[[D",
44+
"[[F",
45+
"[[Ljava.lang.*",
46+
"[[Lorg.tribuo.*",
47+
"[[Llibsvm.*",
2648
"org.opensearch.ml.*",
27-
"*org.tribuo.*",
2849
"libsvm.*",
29-
"com.oracle.labs.*",
30-
"[*",
31-
"com.amazon.randomcutforest.*"
50+
};
51+
52+
public static final String[] REJECT_CLASS_PATTERNS = {
53+
"java.util.logging.*",
54+
"java.util.zip.*",
55+
"java.util.jar.*",
56+
"java.util.random.*",
57+
"java.util.spi.*",
58+
"java.util.stream.*",
59+
"java.util.regex.*",
60+
"java.util.concurrent.*",
61+
"java.util.function.*",
62+
"java.util.prefs.*",
63+
"java.time.zone.*",
64+
"java.time.format.*",
65+
"java.time.temporal.*",
66+
"java.time.chrono.*",
3267
};
3368

3469
public static String serializeToBase64(Object model) {
@@ -47,11 +82,15 @@ public static byte[] serialize(Object model) {
4782
}
4883
}
4984

85+
// Deserialization through this method has been tested and passes for the K-means, linear regression, logistic regression, anomaly detection and Random Cut Forest summarization models.
5086
public static Object deserialize(byte[] modelBin) {
5187
try (ByteArrayInputStream inputStream = new ByteArrayInputStream(modelBin);
5288
ValidatingObjectInputStream validatingObjectInputStream = new ValidatingObjectInputStream(inputStream)){
5389
// Validate the model class type to avoid deserialization attack.
54-
validatingObjectInputStream.accept(ACCEPT_CLASS_PATTERNS);
90+
validatingObjectInputStream
91+
.accept(ACCEPT_CLASS_PATTERNS)
92+
.reject(REJECT_CLASS_PATTERNS)
93+
.setObjectInputFilter(ObjectInputFilter.Config.createFilter("maxdepth=20;maxrefs=5000;maxbytes=10000000;maxarray=100000"));
5594
return validatingObjectInputStream.readObject();
5695
} catch (IOException | ClassNotFoundException e) {
5796
throw new ModelSerDeSerException("Failed to deserialize model.", e.getCause());

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/ad/AnomalyDetectionLibSVMTest.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,13 @@
2424
import org.opensearch.ml.common.input.parameter.ad.AnomalyDetectionLibSVMParams;
2525
import org.opensearch.ml.common.FunctionName;
2626
import org.opensearch.ml.common.output.MLPredictionOutput;
27+
import org.opensearch.ml.engine.utils.ModelSerDeSer;
2728
import org.tribuo.Dataset;
2829
import org.tribuo.Example;
2930
import org.tribuo.Feature;
3031
import org.tribuo.anomaly.Event;
3132
import org.tribuo.anomaly.example.AnomalyDataGenerator;
33+
import org.tribuo.common.libsvm.LibSVMModel;
3234

3335
import java.util.ArrayList;
3436
import java.util.Iterator;
@@ -117,6 +119,13 @@ public void train() {
117119
Assert.assertNotNull(model.getContent());
118120
}
119121

122+
@Test
123+
public void testModelSerDeSer() {
124+
MLModel model = anomalyDetection.train(trainDataFrameInput);
125+
LibSVMModel deserializedModel = (LibSVMModel) ModelSerDeSer.deserialize(model);
126+
Assert.assertNotNull(deserializedModel);
127+
}
128+
120129
@Test
121130
public void trainWithFullParams() {
122131
AnomalyDetectionLibSVMParams parameters = AnomalyDetectionLibSVMParams.builder().gamma(gamma).nu(nu).cost(1.0).coeff(0.01).epsilon(0.001).degree(1).kernelType(AnomalyDetectionLibSVMParams.ADKernelType.LINEAR).build();

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/clustering/RCFSummarizeTest.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
package org.opensearch.ml.engine.algorithms.clustering;
77

8+
import com.amazon.randomcutforest.returntypes.SampleSummary;
89
import org.junit.Assert;
910
import org.junit.Before;
1011
import org.junit.Rule;
@@ -17,6 +18,7 @@
1718
import org.opensearch.ml.common.input.parameter.clustering.RCFSummarizeParams;
1819
import org.opensearch.ml.common.FunctionName;
1920
import org.opensearch.ml.common.output.MLPredictionOutput;
21+
import org.opensearch.ml.engine.utils.ModelSerDeSer;
2022

2123
import static org.opensearch.ml.engine.helper.MLTestHelper.constructTestDataFrame;
2224

@@ -61,6 +63,13 @@ public void predictWithTrivalModelExpectBoNorminalOutput() {
6163
predictions.forEach(row -> Assert.assertTrue(row.getValue(0).intValue() == 0 || row.getValue(0).intValue() == 1));
6264
}
6365

66+
@Test
67+
public void testModelSerDeSer() {
68+
MLModel model = rcfSummarize.train(trainDataFrameInput);
69+
SampleSummary deserializedModel = ((SerializableSummary) ModelSerDeSer.deserialize(model)).getSummary();
70+
Assert.assertNotNull(deserializedModel);
71+
}
72+
6473
@Test
6574
public void trainAndPredictWithRegularInputExpectNotNullOutput() {
6675
RCFSummarizeParams parameters = RCFSummarizeParams.builder()

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/regression/LogisticRegressionTest.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import org.opensearch.ml.common.input.MLInput;
1818
import org.opensearch.ml.common.input.parameter.regression.LogisticRegressionParams;
1919
import org.opensearch.ml.common.output.MLPredictionOutput;
20+
import org.opensearch.ml.engine.utils.ModelSerDeSer;
21+
import org.tribuo.classification.Label;
2022

2123
import static org.opensearch.ml.engine.helper.LogisticRegressionHelper.constructLogisticRegressionPredictionDataFrame;
2224
import static org.opensearch.ml.engine.helper.LogisticRegressionHelper.constructLogisticRegressionTrainDataFrame;
@@ -109,6 +111,14 @@ public void predict() {
109111
Assert.assertEquals(2, predictions.size());
110112
}
111113

114+
@Test
115+
public void testModelSerDeSer() {
116+
LogisticRegression classification = new LogisticRegression(parameters);
117+
MLModel model = classification.train(trainDataFrameInput);
118+
org.tribuo.Model<Label> deserializedModel = (org.tribuo.Model<Label>) ModelSerDeSer.deserialize(model);
119+
Assert.assertNotNull(deserializedModel);
120+
}
121+
112122
@Test
113123
public void predictWithoutModel() {
114124
exceptionRule.expect(IllegalArgumentException.class);

0 commit comments

Comments
 (0)