Skip to content

Commit e82cb68

Browse files
holdenkdbtsai
authored and committed
[SPARK-11237][ML] Add pmml export for k-means in Spark ML
## What changes were proposed in this pull request? Adding PMML export to Spark ML's KMeans Model. ## How was this patch tested? New unit test for Spark ML PMML export based on the old Spark MLlib unit test. Author: Holden Karau <[email protected]> Closes apache#20907 from holdenk/SPARK-11237-Add-PMML-Export-for-KMeans.
1 parent 770add8 commit e82cb68

File tree

4 files changed

+83
-30
lines changed

4 files changed

+83
-30
lines changed
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
org.apache.spark.ml.regression.InternalLinearRegressionModelWriter
2-
org.apache.spark.ml.regression.PMMLLinearRegressionModelWriter
2+
org.apache.spark.ml.regression.PMMLLinearRegressionModelWriter
3+
org.apache.spark.ml.clustering.InternalKMeansModelWriter
4+
org.apache.spark.ml.clustering.PMMLKMeansModelWriter

mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala

Lines changed: 50 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@
1717

1818
package org.apache.spark.ml.clustering
1919

20+
import scala.collection.mutable
21+
2022
import org.apache.hadoop.fs.Path
2123

2224
import org.apache.spark.SparkException
2325
import org.apache.spark.annotation.{Experimental, Since}
24-
import org.apache.spark.ml.{Estimator, Model}
26+
import org.apache.spark.ml.{Estimator, Model, PipelineStage}
2527
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
2628
import org.apache.spark.ml.param._
2729
import org.apache.spark.ml.param.shared._
@@ -30,7 +32,7 @@ import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans
3032
import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
3133
import org.apache.spark.mllib.linalg.VectorImplicits._
3234
import org.apache.spark.rdd.RDD
33-
import org.apache.spark.sql.{DataFrame, Dataset, Row}
35+
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
3436
import org.apache.spark.sql.functions.{col, udf}
3537
import org.apache.spark.sql.types.{IntegerType, StructType}
3638
import org.apache.spark.storage.StorageLevel
@@ -103,8 +105,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
103105
@Since("1.5.0")
104106
class KMeansModel private[ml] (
105107
@Since("1.5.0") override val uid: String,
106-
private val parentModel: MLlibKMeansModel)
107-
extends Model[KMeansModel] with KMeansParams with MLWritable {
108+
private[clustering] val parentModel: MLlibKMeansModel)
109+
extends Model[KMeansModel] with KMeansParams with GeneralMLWritable {
108110

109111
@Since("1.5.0")
110112
override def copy(extra: ParamMap): KMeansModel = {
@@ -152,14 +154,14 @@ class KMeansModel private[ml] (
152154
}
153155

154156
/**
155-
* Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance.
157+
* Returns a [[org.apache.spark.ml.util.GeneralMLWriter]] instance for this ML instance.
156158
*
157159
* For [[KMeansModel]], this does NOT currently save the training [[summary]].
158160
* An option to save [[summary]] may be added in the future.
159161
*
160162
*/
161163
@Since("1.6.0")
162-
override def write: MLWriter = new KMeansModel.KMeansModelWriter(this)
164+
override def write: GeneralMLWriter = new GeneralMLWriter(this)
163165

164166
private var trainingSummary: Option[KMeansSummary] = None
165167

@@ -185,6 +187,47 @@ class KMeansModel private[ml] (
185187
}
186188
}
187189

190+
/** Helper class for storing model data */
191+
private case class ClusterData(clusterIdx: Int, clusterCenter: Vector)
192+
193+
194+
/** A writer for KMeans that handles the "internal" (or default) format */
195+
private class InternalKMeansModelWriter extends MLWriterFormat with MLFormatRegister {
196+
197+
override def format(): String = "internal"
198+
override def stageName(): String = "org.apache.spark.ml.clustering.KMeansModel"
199+
200+
override def write(path: String, sparkSession: SparkSession,
201+
optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
202+
val instance = stage.asInstanceOf[KMeansModel]
203+
val sc = sparkSession.sparkContext
204+
// Save metadata and Params
205+
DefaultParamsWriter.saveMetadata(instance, path, sc)
206+
// Save model data: cluster centers
207+
val data: Array[ClusterData] = instance.clusterCenters.zipWithIndex.map {
208+
case (center, idx) =>
209+
ClusterData(idx, center)
210+
}
211+
val dataPath = new Path(path, "data").toString
212+
sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath)
213+
}
214+
}
215+
216+
/** A writer for KMeans that handles the "pmml" format */
217+
private class PMMLKMeansModelWriter extends MLWriterFormat with MLFormatRegister {
218+
219+
override def format(): String = "pmml"
220+
override def stageName(): String = "org.apache.spark.ml.clustering.KMeansModel"
221+
222+
override def write(path: String, sparkSession: SparkSession,
223+
optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
224+
val instance = stage.asInstanceOf[KMeansModel]
225+
val sc = sparkSession.sparkContext
226+
instance.parentModel.toPMML(sc, path)
227+
}
228+
}
229+
230+
188231
@Since("1.6.0")
189232
object KMeansModel extends MLReadable[KMeansModel] {
190233

@@ -194,30 +237,12 @@ object KMeansModel extends MLReadable[KMeansModel] {
194237
@Since("1.6.0")
195238
override def load(path: String): KMeansModel = super.load(path)
196239

197-
/** Helper class for storing model data */
198-
private case class Data(clusterIdx: Int, clusterCenter: Vector)
199-
200240
/**
201241
* Spark 1.6 and earlier stored all cluster centers in a single row; this class describes that
202242
* layout. A model can be loaded from such older data for backward compatibility.
203243
*/
204244
private case class OldData(clusterCenters: Array[OldVector])
205245

206-
/** [[MLWriter]] instance for [[KMeansModel]] */
207-
private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter {
208-
209-
override protected def saveImpl(path: String): Unit = {
210-
// Save metadata and Params
211-
DefaultParamsWriter.saveMetadata(instance, path, sc)
212-
// Save model data: cluster centers
213-
val data: Array[Data] = instance.clusterCenters.zipWithIndex.map { case (center, idx) =>
214-
Data(idx, center)
215-
}
216-
val dataPath = new Path(path, "data").toString
217-
sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath)
218-
}
219-
}
220-
221246
private class KMeansModelReader extends MLReader[KMeansModel] {
222247

223248
/** Checked against metadata when loading model */
@@ -232,7 +257,7 @@ object KMeansModel extends MLReadable[KMeansModel] {
232257
val dataPath = new Path(path, "data").toString
233258

234259
val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) {
235-
val data: Dataset[Data] = sparkSession.read.parquet(dataPath).as[Data]
260+
val data: Dataset[ClusterData] = sparkSession.read.parquet(dataPath).as[ClusterData]
236261
data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML)
237262
} else {
238263
// Loads KMeansModel stored with the old format used by Spark 1.6 and earlier.

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -746,7 +746,7 @@ private class InternalLinearRegressionModelWriter
746746

747747
/** A writer for LinearRegression that handles the "pmml" format */
748748
private class PMMLLinearRegressionModelWriter
749-
extends MLWriterFormat with MLFormatRegister {
749+
extends MLWriterFormat with MLFormatRegister {
750750

751751
override def format(): String = "pmml"
752752

mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,22 @@ package org.apache.spark.ml.clustering
1919

2020
import scala.util.Random
2121

22+
import org.dmg.pmml.{ClusteringModel, PMML}
23+
2224
import org.apache.spark.{SparkException, SparkFunSuite}
2325
import org.apache.spark.ml.linalg.{Vector, Vectors}
2426
import org.apache.spark.ml.param.ParamMap
25-
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
26-
import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans}
27+
import org.apache.spark.ml.util._
28+
import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans,
29+
KMeansModel => MLlibKMeansModel}
30+
import org.apache.spark.mllib.linalg.{Vectors => MLlibVectors}
2731
import org.apache.spark.mllib.util.MLlibTestSparkContext
2832
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
2933

3034
private[clustering] case class TestRow(features: Vector)
3135

32-
class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
36+
class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest
37+
with PMMLReadWriteTest {
3338

3439
final val k = 5
3540
@transient var dataset: Dataset[_] = _
@@ -202,6 +207,27 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
202207
testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings,
203208
KMeansSuite.allParamSettings, checkModelData)
204209
}
210+
211+
test("pmml export") {
212+
val clusterCenters = Array(
213+
MLlibVectors.dense(1.0, 2.0, 6.0),
214+
MLlibVectors.dense(1.0, 3.0, 0.0),
215+
MLlibVectors.dense(1.0, 4.0, 6.0))
216+
val oldKmeansModel = new MLlibKMeansModel(clusterCenters)
217+
val kmeansModel = new KMeansModel("", oldKmeansModel)
218+
def checkModel(pmml: PMML): Unit = {
219+
// Check the header descripiton is what we expect
220+
assert(pmml.getHeader.getDescription === "k-means clustering")
221+
// check that the number of fields match the single vector size
222+
assert(pmml.getDataDictionary.getNumberOfFields === clusterCenters(0).size)
223+
// This verify that there is a model attached to the pmml object and the model is a clustering
224+
// one. It also verifies that the pmml model has the same number of clusters of the spark
225+
// model.
226+
val pmmlClusteringModel = pmml.getModels.get(0).asInstanceOf[ClusteringModel]
227+
assert(pmmlClusteringModel.getNumberOfClusters === clusterCenters.length)
228+
}
229+
testPMMLWrite(sc, kmeansModel, checkModel)
230+
}
205231
}
206232

207233
object KMeansSuite {

0 commit comments

Comments
 (0)