package org.apache.spark.ml.clustering

+ import scala.collection.mutable
+
import org.apache.hadoop.fs.Path

import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since}
- import org.apache.spark.ml.{Estimator, Model}
+ import org.apache.spark.ml.{Estimator, Model, PipelineStage}
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
@@ -30,7 +32,7 @@ import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans
import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.rdd.RDD
- import org.apache.spark.sql.{DataFrame, Dataset, Row}
+ import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.storage.StorageLevel
@@ -103,8 +105,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
@Since("1.5.0")
class KMeansModel private[ml] (
    @Since("1.5.0") override val uid: String,
-     private val parentModel: MLlibKMeansModel)
-   extends Model[KMeansModel] with KMeansParams with MLWritable {
+     private[clustering] val parentModel: MLlibKMeansModel)
+   extends Model[KMeansModel] with KMeansParams with GeneralMLWritable {

  @Since("1.5.0")
  override def copy(extra: ParamMap): KMeansModel = {
@@ -152,14 +154,14 @@ class KMeansModel private[ml] (
  }

  /**
-    * Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance.
+    * Returns a [[org.apache.spark.ml.util.GeneralMLWriter]] instance for this ML instance.
   *
   * For [[KMeansModel]], this does NOT currently save the training [[summary]].
   * An option to save [[summary]] may be added in the future.
   *
   */
  @Since("1.6.0")
-   override def write: MLWriter = new KMeansModel.KMeansModelWriter(this)
+   override def write: GeneralMLWriter = new GeneralMLWriter(this)

  private var trainingSummary: Option[KMeansSummary] = None

@@ -185,6 +187,47 @@ class KMeansModel private[ml] (
  }
}

+ /** Helper class for storing model data */
+ private case class ClusterData(clusterIdx: Int, clusterCenter: Vector)
+
+
+ /** A writer for KMeans that handles the "internal" (or default) format */
+ private class InternalKMeansModelWriter extends MLWriterFormat with MLFormatRegister {
+
+   override def format(): String = "internal"
+   override def stageName(): String = "org.apache.spark.ml.clustering.KMeansModel"
+
+   override def write(path: String, sparkSession: SparkSession,
+       optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
+     val instance = stage.asInstanceOf[KMeansModel]
+     val sc = sparkSession.sparkContext
+     // Save metadata and Params
+     DefaultParamsWriter.saveMetadata(instance, path, sc)
+     // Save model data: cluster centers
+     val data: Array[ClusterData] = instance.clusterCenters.zipWithIndex.map {
+       case (center, idx) =>
+         ClusterData(idx, center)
+     }
+     val dataPath = new Path(path, "data").toString
+     sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath)
+   }
+ }
+
+ /** A writer for KMeans that handles the "pmml" format */
+ private class PMMLKMeansModelWriter extends MLWriterFormat with MLFormatRegister {
+
+   override def format(): String = "pmml"
+   override def stageName(): String = "org.apache.spark.ml.clustering.KMeansModel"
+
+   override def write(path: String, sparkSession: SparkSession,
+       optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
+     val instance = stage.asInstanceOf[KMeansModel]
+     val sc = sparkSession.sparkContext
+     instance.parentModel.toPMML(sc, path)
+   }
+ }
+
+
@Since("1.6.0")
object KMeansModel extends MLReadable[KMeansModel] {
@@ -194,30 +237,12 @@ object KMeansModel extends MLReadable[KMeansModel] {
  @Since("1.6.0")
  override def load(path: String): KMeansModel = super.load(path)

-   /** Helper class for storing model data */
-   private case class Data(clusterIdx: Int, clusterCenter: Vector)
-
  /**
   * We store all cluster centers in a single row and use this class to store model data by
   * Spark 1.6 and earlier. A model can be loaded from such older data for backward compatibility.
   */
  private case class OldData(clusterCenters: Array[OldVector])

-   /** [[MLWriter]] instance for [[KMeansModel]] */
-   private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter {
-
-     override protected def saveImpl(path: String): Unit = {
-       // Save metadata and Params
-       DefaultParamsWriter.saveMetadata(instance, path, sc)
-       // Save model data: cluster centers
-       val data: Array[Data] = instance.clusterCenters.zipWithIndex.map { case (center, idx) =>
-         Data(idx, center)
-       }
-       val dataPath = new Path(path, "data").toString
-       sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath)
-     }
-   }
-
  private class KMeansModelReader extends MLReader[KMeansModel] {

    /** Checked against metadata when loading model */
@@ -232,7 +257,7 @@ object KMeansModel extends MLReadable[KMeansModel] {
      val dataPath = new Path(path, "data").toString

      val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) {
-         val data: Dataset[Data] = sparkSession.read.parquet(dataPath).as[Data]
+         val data: Dataset[ClusterData] = sparkSession.read.parquet(dataPath).as[ClusterData]
        data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML)
      } else {
        // Loads KMeansModel stored with the old format used by Spark 1.6 and earlier.
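With this change, `KMeansModel.write` returns a `GeneralMLWriter` that dispatches to whichever registered `MLWriterFormat` matches the requested format name: "internal" keeps the existing metadata-plus-Parquet layout, while "pmml" routes to the new `PMMLKMeansModelWriter`. A minimal usage sketch, assuming a fitted `KMeansModel` named `model` and placeholder output paths (neither is part of the patch):

// Default ("internal") format, same on-disk layout as the old KMeansModelWriter:
model.write.save("/tmp/kmeans-internal")
// PMML export, selected by format name and handled by the registered "pmml" writer:
model.write.format("pmml").save("/tmp/kmeans-pmml")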