
Commit a36b78b

zhengruifeng authored and srowen committed
[SPARK-22450][CORE][MLLIB][FOLLOWUP] safely register class for mllib - LabeledPoint/VectorWithNorm/TreePoint
## What changes were proposed in this pull request?

Register the following classes with Kryo:

`org.apache.spark.mllib.regression.LabeledPoint`
`org.apache.spark.mllib.clustering.VectorWithNorm`
`org.apache.spark.ml.feature.LabeledPoint`
`org.apache.spark.ml.tree.impl.TreePoint`

`org.apache.spark.ml.tree.impl.BaggedPoint` also seems to need registration, but I don't know how to do it in this safe way. WeichenXu123 cloud-fan

## How was this patch tested?

Added tests.

Author: Zheng RuiFeng <[email protected]>

Closes #19950 from zhengruifeng/labeled_kryo.
1 parent c6f01ca commit a36b78b
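
These registrations only come into play when an application actually runs with the Kryo serializer and strict registration enabled. As a point of reference, a minimal user-side configuration might look like the sketch below; the config keys and classes are real Spark APIs, while the app name and overall setup are a hypothetical example:

import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoSerializer

// Hypothetical application config: with registrationRequired enabled, Kryo fails fast
// on any unregistered class, which is why KryoSerializer pre-registers the ML/MLlib
// classes added by this commit.
val conf = new SparkConf()
  .setAppName("kryo-example")
  .set("spark.serializer", classOf[KryoSerializer].getName)
  .set("spark.kryo.registrationRequired", "true")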

File tree

6 files changed (+135 −26 lines)


core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala

Lines changed: 16 additions & 11 deletions
@@ -181,20 +181,25 @@ class KryoSerializer(conf: SparkConf)
 
     // We can't load those class directly in order to avoid unnecessary jar dependencies.
     // We load them safely, ignore it if the class not found.
-    Seq("org.apache.spark.mllib.linalg.Vector",
-      "org.apache.spark.mllib.linalg.DenseVector",
-      "org.apache.spark.mllib.linalg.SparseVector",
-      "org.apache.spark.mllib.linalg.Matrix",
-      "org.apache.spark.mllib.linalg.DenseMatrix",
-      "org.apache.spark.mllib.linalg.SparseMatrix",
-      "org.apache.spark.ml.linalg.Vector",
+    Seq(
+      "org.apache.spark.ml.feature.Instance",
+      "org.apache.spark.ml.feature.LabeledPoint",
+      "org.apache.spark.ml.feature.OffsetInstance",
+      "org.apache.spark.ml.linalg.DenseMatrix",
       "org.apache.spark.ml.linalg.DenseVector",
-      "org.apache.spark.ml.linalg.SparseVector",
       "org.apache.spark.ml.linalg.Matrix",
-      "org.apache.spark.ml.linalg.DenseMatrix",
       "org.apache.spark.ml.linalg.SparseMatrix",
-      "org.apache.spark.ml.feature.Instance",
-      "org.apache.spark.ml.feature.OffsetInstance"
+      "org.apache.spark.ml.linalg.SparseVector",
+      "org.apache.spark.ml.linalg.Vector",
+      "org.apache.spark.ml.tree.impl.TreePoint",
+      "org.apache.spark.mllib.clustering.VectorWithNorm",
+      "org.apache.spark.mllib.linalg.DenseMatrix",
+      "org.apache.spark.mllib.linalg.DenseVector",
+      "org.apache.spark.mllib.linalg.Matrix",
+      "org.apache.spark.mllib.linalg.SparseMatrix",
+      "org.apache.spark.mllib.linalg.SparseVector",
+      "org.apache.spark.mllib.linalg.Vector",
+      "org.apache.spark.mllib.regression.LabeledPoint"
     ).foreach { name =>
       try {
         val clazz = Utils.classForName(name)
mllib/src/test/scala/org/apache/spark/ml/feature/InstanceSuit.scala renamed to mllib/src/test/scala/org/apache/spark/ml/feature/InstanceSuite.scala

Lines changed: 11 additions & 13 deletions
@@ -17,31 +17,29 @@
 
 package org.apache.spark.ml.feature
 
-import scala.reflect.ClassTag
-
 import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.serializer.KryoSerializer
 
-class InstanceSuit extends SparkFunSuite{
+class InstanceSuite extends SparkFunSuite{
   test("Kryo class register") {
     val conf = new SparkConf(false)
     conf.set("spark.kryo.registrationRequired", "true")
 
-    val ser = new KryoSerializer(conf)
-    val serInstance = new KryoSerializer(conf).newInstance()
-
-    def check[T: ClassTag](t: T) {
-      assert(serInstance.deserialize[T](serInstance.serialize(t)) === t)
-    }
+    val ser = new KryoSerializer(conf).newInstance()
 
     val instance1 = Instance(19.0, 2.0, Vectors.dense(1.0, 7.0))
     val instance2 = Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse)
+    Seq(instance1, instance2).foreach { i =>
+      val i2 = ser.deserialize[Instance](ser.serialize(i))
+      assert(i === i2)
+    }
+
     val oInstance1 = OffsetInstance(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0))
     val oInstance2 = OffsetInstance(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0).toSparse)
-    check(instance1)
-    check(instance2)
-    check(oInstance1)
-    check(oInstance2)
+    Seq(oInstance1, oInstance2).foreach { o =>
+      val o2 = ser.deserialize[OffsetInstance](ser.serialize(o))
+      assert(o === o2)
+    }
   }
 }
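
The rename also replaces the generic check[T] helper with explicit per-type loops. A self-contained version of that roundtrip idiom, kept here only as an illustration (the object and method names are hypothetical, and it relies on the classes already registered by KryoSerializer), could look like:

import scala.reflect.ClassTag

import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoSerializer

// Sketch of a reusable Kryo roundtrip check: serialize a value, deserialize it,
// and require the copy to equal the original.
object KryoRoundtrip {
  private val conf = new SparkConf(false).set("spark.kryo.registrationRequired", "true")
  private val ser = new KryoSerializer(conf).newInstance()

  def check[T: ClassTag](t: T): Unit = {
    val copy = ser.deserialize[T](ser.serialize(t))
    require(copy == t, s"Kryo roundtrip changed the value: $t vs $copy")
  }
}

For example, KryoRoundtrip.check(Instance(19.0, 2.0, Vectors.dense(1.0, 7.0))) performs the same check as the loop in the test above.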
mllib/src/test/scala/org/apache/spark/ml/feature/LabeledPointSuite.scala (new file)

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.ml.linalg.Vectors
+import org.apache.spark.serializer.KryoSerializer
+
+class LabeledPointSuite extends SparkFunSuite {
+  test("Kryo class register") {
+    val conf = new SparkConf(false)
+    conf.set("spark.kryo.registrationRequired", "true")
+
+    val ser = new KryoSerializer(conf).newInstance()
+
+    val labeled1 = LabeledPoint(1.0, Vectors.dense(Array(1.0, 2.0)))
+    val labeled2 = LabeledPoint(1.0, Vectors.sparse(10, Array(5, 7), Array(1.0, 2.0)))
+
+    Seq(labeled1, labeled2).foreach { l =>
+      val l2 = ser.deserialize[LabeledPoint](ser.serialize(l))
+      assert(l === l2)
+    }
+  }
+}
mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreePointSuite.scala (new file)

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.serializer.KryoSerializer
+
+class TreePointSuite extends SparkFunSuite {
+  test("Kryo class register") {
+    val conf = new SparkConf(false)
+    conf.set("spark.kryo.registrationRequired", "true")
+
+    val ser = new KryoSerializer(conf).newInstance()
+
+    val point = new TreePoint(1.0, Array(1, 2, 3))
+    val point2 = ser.deserialize[TreePoint](ser.serialize(point))
+    assert(point.label === point2.label)
+    assert(point.binnedFeatures === point2.binnedFeatures)
+  }
+}

mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala

Lines changed: 17 additions & 1 deletion
@@ -19,10 +19,11 @@ package org.apache.spark.mllib.clustering
 
 import scala.util.Random
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
 import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.serializer.KryoSerializer
 import org.apache.spark.util.Utils
 
 class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -311,6 +312,21 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(returnModel.clusterCenters(1) === initialModel.clusterCenters(1))
   }
 
+  test("Kryo class register") {
+    val conf = new SparkConf(false)
+    conf.set("spark.kryo.registrationRequired", "true")
+
+    val ser = new KryoSerializer(conf).newInstance()
+
+    val vec1 = new VectorWithNorm(Vectors.dense(Array(1.0, 2.0)))
+    val vec2 = new VectorWithNorm(Vectors.sparse(10, Array(5, 8), Array(1.0, 2.0)))
+
+    Seq(vec1, vec2).foreach { v =>
+      val v2 = ser.deserialize[VectorWithNorm](ser.serialize(v))
+      assert(v2.norm === v.norm)
+      assert(v2.vector === v.vector)
+    }
+  }
 }
 
 object KMeansSuite extends SparkFunSuite {

mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala

Lines changed: 17 additions & 1 deletion
@@ -17,9 +17,10 @@
 
 package org.apache.spark.mllib.regression
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.ml.feature.{LabeledPoint => NewLabeledPoint}
 import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.serializer.KryoSerializer
 
 class LabeledPointSuite extends SparkFunSuite {
 
@@ -53,4 +54,19 @@ class LabeledPointSuite extends SparkFunSuite {
       assert(p1 === LabeledPoint.fromML(p2))
     }
   }
+
+  test("Kryo class register") {
+    val conf = new SparkConf(false)
+    conf.set("spark.kryo.registrationRequired", "true")
+
+    val ser = new KryoSerializer(conf).newInstance()
+
+    val labeled1 = LabeledPoint(1.0, Vectors.dense(Array(1.0, 2.0)))
+    val labeled2 = LabeledPoint(1.0, Vectors.sparse(10, Array(5, 7), Array(1.0, 2.0)))
+
+    Seq(labeled1, labeled2).foreach { l =>
+      val l2 = ser.deserialize[LabeledPoint](ser.serialize(l))
+      assert(l === l2)
+    }
+  }
 }
