
Commit a471880

wangmiao1981 authored and jkbradley committed
[SPARK-24026][ML] Add Power Iteration Clustering to spark.ml
## What changes were proposed in this pull request?

This PR adds PowerIterationClustering as a Transformer to spark.ml. In the transform method, it calls spark.mllib's PowerIterationClustering.run() method and transforms the return value assignments (the k-means output of the pseudo-eigenvector) into a DataFrame (id: LongType, cluster: IntegerType).

This PR is copied and modified from apache#15770. The primary author is wangmiao1981.

## How was this patch tested?

This PR has 2 types of tests:
* Copies of tests from spark.mllib's PIC tests
* New tests specific to the spark.ml APIs

Author: [email protected] <[email protected]>
Author: wangmiao1981 <[email protected]>
Author: Joseph K. Bradley <[email protected]>

Closes apache#21090 from jkbradley/wangmiao1981-pic.
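For context, a minimal usage sketch of the Transformer added by this commit (assuming a SparkSession in scope as `spark`; the toy graph and parameter values are illustrative, and the column names are the defaults `id`, `neighbors`, and `similarities`):

```scala
import org.apache.spark.ml.clustering.PowerIterationClustering

// A small symmetric affinity graph in adjacency-list form: each row holds a
// vertex ID, its neighbor IDs, and the non-negative edge similarities.
val dataset = spark.createDataFrame(Seq(
  (0L, Array(1L), Array(0.9)),
  (1L, Array(0L, 2L), Array(0.9, 0.7)),
  (2L, Array(1L), Array(0.7)),
  (3L, Array(4L), Array(0.5)),
  (4L, Array(3L), Array(0.5))
)).toDF("id", "neighbors", "similarities")

val pic = new PowerIterationClustering()
  .setK(2)
  .setMaxIter(20)

// transform() runs the full iterative PIC algorithm and appends a
// "prediction" column with a cluster assignment in [0, k) per vertex.
val assignments = pic.transform(dataset)
assignments.select("id", "prediction").show()
```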
1 parent 6e19f76 commit a471880

File tree

2 files changed: +494 −0 lines changed
Lines changed: 256 additions & 0 deletions
@@ -0,0 +1,256 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.clustering

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.clustering.{PowerIterationClustering => MLlibPowerIterationClustering}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._

/**
 * Common params for PowerIterationClustering
 */
private[clustering] trait PowerIterationClusteringParams extends Params with HasMaxIter
  with HasPredictionCol {

  /**
   * The number of clusters to create (k). Must be &gt; 1. Default: 2.
   * @group param
   */
  @Since("2.4.0")
  final val k = new IntParam(this, "k", "The number of clusters to create. " +
    "Must be > 1.", ParamValidators.gt(1))

  /** @group getParam */
  @Since("2.4.0")
  def getK: Int = $(k)

  /**
   * Param for the initialization algorithm. This can be either "random" to use a random vector
   * as vertex properties, or "degree" to use a normalized sum of similarities with other vertices.
   * Default: random.
   * @group expertParam
   */
  @Since("2.4.0")
  final val initMode = {
    val allowedParams = ParamValidators.inArray(Array("random", "degree"))
    new Param[String](this, "initMode", "The initialization algorithm. This can be either " +
      "'random' to use a random vector as vertex properties, or 'degree' to use a normalized sum " +
      "of similarities with other vertices. Supported options: 'random' and 'degree'.",
      allowedParams)
  }

  /** @group expertGetParam */
  @Since("2.4.0")
  def getInitMode: String = $(initMode)

  /**
   * Param for the name of the input column for vertex IDs.
   * Default: "id"
   * @group param
   */
  @Since("2.4.0")
  val idCol = new Param[String](this, "idCol", "Name of the input column for vertex IDs.",
    (value: String) => value.nonEmpty)

  setDefault(idCol, "id")

  /** @group getParam */
  @Since("2.4.0")
  def getIdCol: String = getOrDefault(idCol)

  /**
   * Param for the name of the input column for neighbors in the adjacency list representation.
   * Default: "neighbors"
   * @group param
   */
  @Since("2.4.0")
  val neighborsCol = new Param[String](this, "neighborsCol",
    "Name of the input column for neighbors in the adjacency list representation.",
    (value: String) => value.nonEmpty)

  setDefault(neighborsCol, "neighbors")

  /** @group getParam */
  @Since("2.4.0")
  def getNeighborsCol: String = $(neighborsCol)

  /**
   * Param for the name of the input column for non-negative edge weights (similarities) in the
   * adjacency list representation.
   * Default: "similarities"
   * @group param
   */
  @Since("2.4.0")
  val similaritiesCol = new Param[String](this, "similaritiesCol",
    "Name of the input column for non-negative edge weights (similarities) in the adjacency " +
    "list representation.",
    (value: String) => value.nonEmpty)

  setDefault(similaritiesCol, "similarities")

  /** @group getParam */
  @Since("2.4.0")
  def getSimilaritiesCol: String = $(similaritiesCol)

  protected def validateAndTransformSchema(schema: StructType): StructType = {
    SchemaUtils.checkColumnTypes(schema, $(idCol), Seq(IntegerType, LongType))
    SchemaUtils.checkColumnTypes(schema, $(neighborsCol),
      Seq(ArrayType(IntegerType, containsNull = false),
        ArrayType(LongType, containsNull = false)))
    SchemaUtils.checkColumnTypes(schema, $(similaritiesCol),
      Seq(ArrayType(FloatType, containsNull = false),
        ArrayType(DoubleType, containsNull = false)))
    SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
  }
}

/**
 * :: Experimental ::
 * Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by
 * <a href=http://www.icml2010.org/papers/387.pdf>Lin and Cohen</a>. From the abstract:
 * PIC finds a very low-dimensional embedding of a dataset using truncated power
 * iteration on a normalized pair-wise similarity matrix of the data.
 *
 * PIC takes an affinity matrix between items (or vertices) as input. An affinity matrix
 * is a symmetric matrix whose entries are non-negative similarities between items.
 * PIC takes this matrix (or graph) as an adjacency matrix. Specifically, each input row includes:
 *  - `idCol`: vertex ID
 *  - `neighborsCol`: neighbors of vertex in `idCol`
 *  - `similaritiesCol`: non-negative weights (similarities) of edges between the vertex
 *    in `idCol` and each neighbor in `neighborsCol`
 * PIC returns a cluster assignment for each input vertex. It appends a new column `predictionCol`
 * containing the cluster assignment in `[0,k)` for each row (vertex).
 *
 * Notes:
 *  - [[PowerIterationClustering]] is a transformer with an expensive [[transform]] operation.
 *    Transform runs the iterative PIC algorithm to cluster the whole input dataset.
 *  - Input validation: This validates that similarities are non-negative but does NOT validate
 *    that the input matrix is symmetric.
 *
 * @see <a href=http://en.wikipedia.org/wiki/Spectral_clustering>
 *      Spectral clustering (Wikipedia)</a>
 */
@Since("2.4.0")
@Experimental
class PowerIterationClustering private[clustering] (
    @Since("2.4.0") override val uid: String)
  extends Transformer with PowerIterationClusteringParams with DefaultParamsWritable {

  setDefault(
    k -> 2,
    maxIter -> 20,
    initMode -> "random")

  @Since("2.4.0")
  def this() = this(Identifiable.randomUID("PowerIterationClustering"))

  /** @group setParam */
  @Since("2.4.0")
  def setPredictionCol(value: String): this.type = set(predictionCol, value)

  /** @group setParam */
  @Since("2.4.0")
  def setK(value: Int): this.type = set(k, value)

  /** @group expertSetParam */
  @Since("2.4.0")
  def setInitMode(value: String): this.type = set(initMode, value)

  /** @group setParam */
  @Since("2.4.0")
  def setMaxIter(value: Int): this.type = set(maxIter, value)

  /** @group setParam */
  @Since("2.4.0")
  def setIdCol(value: String): this.type = set(idCol, value)

  /** @group setParam */
  @Since("2.4.0")
  def setNeighborsCol(value: String): this.type = set(neighborsCol, value)

  /** @group setParam */
  @Since("2.4.0")
  def setSimilaritiesCol(value: String): this.type = set(similaritiesCol, value)

  @Since("2.4.0")
  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema, logging = true)

    val sparkSession = dataset.sparkSession
    val idColValue = $(idCol)
    // Expand each adjacency-list row into (srcId, dstId, similarity) edge triples,
    // the input format expected by spark.mllib's PowerIterationClustering.
    val rdd: RDD[(Long, Long, Double)] =
      dataset.select(
        col($(idCol)).cast(LongType),
        col($(neighborsCol)).cast(ArrayType(LongType, containsNull = false)),
        col($(similaritiesCol)).cast(ArrayType(DoubleType, containsNull = false))
      ).rdd.flatMap {
        case Row(id: Long, nbrs: Seq[_], sims: Seq[_]) =>
          require(nbrs.size == sims.size, s"The length of the neighbor ID list must be " +
            s"equal to the length of the neighbor similarity list. Row for ID " +
            s"$idColValue=$id has neighbor ID list of length ${nbrs.length} but similarity list " +
            s"of length ${sims.length}.")
          nbrs.asInstanceOf[Seq[Long]].zip(sims.asInstanceOf[Seq[Double]]).map {
            case (nbr, similarity) => (id, nbr, similarity)
          }
      }
    val algorithm = new MLlibPowerIterationClustering()
      .setK($(k))
      .setInitializationMode($(initMode))
      .setMaxIterations($(maxIter))
    val model = algorithm.run(rdd)

    val predictionsRDD: RDD[Row] = model.assignments.map { assignment =>
      Row(assignment.id, assignment.cluster)
    }

    val predictionsSchema = StructType(Seq(
      StructField($(idCol), LongType, nullable = false),
      StructField($(predictionCol), IntegerType, nullable = false)))
    val predictions = {
      val uncastPredictions = sparkSession.createDataFrame(predictionsRDD, predictionsSchema)
      // If the input used a narrower id type, cast the prediction ids back to it
      // (keeping the prediction column) so the join below matches the original column.
      dataset.schema($(idCol)).dataType match {
        case _: LongType =>
          uncastPredictions
        case otherType =>
          uncastPredictions.select(col($(idCol)).cast(otherType).alias($(idCol)),
            col($(predictionCol)))
      }
    }

    dataset.join(predictions, $(idCol))
  }

  @Since("2.4.0")
  override def transformSchema(schema: StructType): StructType = {
    validateAndTransformSchema(schema)
  }

  @Since("2.4.0")
  override def copy(extra: ParamMap): PowerIterationClustering = defaultCopy(extra)
}

@Since("2.4.0")
object PowerIterationClustering extends DefaultParamsReadable[PowerIterationClustering] {

  @Since("2.4.0")
  override def load(path: String): PowerIterationClustering = super.load(path)
}
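Note on input types, as encoded in `validateAndTransformSchema` and `transform` above: vertex IDs may be `IntegerType` or `LongType`, neighbor lists `array<int>` or `array<long>`, and similarities `array<float>` or `array<double>`. Everything is cast to `Long`/`Double` internally for spark.mllib, and prediction IDs are cast back to the original ID type before the final join. A small illustrative sketch (again assuming a SparkSession named `spark`; values are arbitrary):

```scala
// Int vertex IDs are accepted; the joined output keeps the original
// IntegerType id column alongside the appended "prediction" column.
val intIdDataset = spark.createDataFrame(Seq(
  (0, Array(1), Array(0.5)),
  (1, Array(0), Array(0.5))
)).toDF("id", "neighbors", "similarities")

new PowerIterationClustering().setK(2).transform(intIdDataset).printSchema()
```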

Comments (0)