@@ -23,7 +23,15 @@ import org.apache.spark.ml.{ Estimator, Model }
2323import org .apache .spark .ml .linalg .{ Vector , Vectors }
2424import org .apache .spark .ml .param ._
2525import org .apache .spark .ml .param .shared ._
26- import org .apache .spark .ml .util .{ DefaultParamsReadable , DefaultParamsWritable , Identifiable , MLReadable , MLReader , MLWritable , MLWriter }
26+ import org .apache .spark .ml .util .{
27+ DefaultParamsReadable ,
28+ DefaultParamsWritable ,
29+ Identifiable ,
30+ MLReadable ,
31+ MLReader ,
32+ MLWritable ,
33+ MLWriter
34+ }
2735import org .apache .spark .sql .{ DataFrame , Dataset }
2836import org .apache .spark .sql .functions ._
2937import org .apache .spark .sql .types .StructType
@@ -47,8 +55,8 @@ trait AgglomerativeBregmanParams
4755 )
4856 def getNumClusters : Int = $(numClusters)
4957
50- /** Distance threshold for merging (alternative to numClusters).
51- * If set > 0, clustering stops when min merge distance exceeds threshold.
58+ /** Distance threshold for merging (alternative to numClusters). If set > 0, clustering stops when
59+ * min merge distance exceeds threshold.
5260 */
5361 final val distanceThreshold : DoubleParam = new DoubleParam (
5462 this ,
@@ -103,16 +111,14 @@ trait AgglomerativeBregmanParams
103111
104112/** Agglomerative (bottom-up) hierarchical clustering with Bregman divergences.
105113 *
106- * Starts with each point as its own cluster and iteratively merges the
107- * closest pair of clusters until the desired number is reached.
114+ * Starts with each point as its own cluster and iteratively merges the closest pair of clusters
115+ * until the desired number is reached.
108116 *
109117 * ==Algorithm==
110118 *
111- * 1. Initialize: Each point is a singleton cluster
112- * 2. Compute pairwise distances/divergences between all clusters
113- * 3. Find and merge the closest pair
114- * 4. Update distances to the merged cluster
115- * 5. Repeat until numClusters reached or distanceThreshold exceeded
119+ * 1. Initialize: Each point is a singleton cluster 2. Compute pairwise distances/divergences
120+ * between all clusters 3. Find and merge the closest pair 4. Update distances to the merged
121+ * cluster 5. Repeat until numClusters reached or distanceThreshold exceeded
116122 *
117123 * ==Linkage Criteria==
118124 *
@@ -139,9 +145,9 @@ trait AgglomerativeBregmanParams
139145 *
140146 * ==Scalability Note==
141147 *
142- * Standard agglomerative clustering has O(n³) or O(n²log n) complexity.
143- * This implementation is suitable for datasets up to ~10,000 points.
144- * For larger datasets, consider [[ BisectingKMeans ]] (top-down approach).
148+ * Standard agglomerative clustering has O(n³) or O(n²log n) complexity. This implementation is
149+ * suitable for datasets up to ~10,000 points. For larger datasets, consider [[ BisectingKMeans ]]
150+ * (top-down approach).
145151 *
146152 * @see
147153 * [[BisectingKMeans ]] for top-down hierarchical clustering
@@ -155,14 +161,14 @@ class AgglomerativeBregman(override val uid: String)
155161 def this () = this (Identifiable .randomUID(" agglomerative" ))
156162
157163 // Parameter setters
158- def setNumClusters (value : Int ): this .type = set(numClusters, value)
164+ def setNumClusters (value : Int ): this .type = set(numClusters, value)
159165 def setDistanceThreshold (value : Double ): this .type = set(distanceThreshold, value)
160- def setLinkage (value : String ): this .type = set(linkage, value)
161- def setDivergence (value : String ): this .type = set(divergence, value)
162- def setSmoothing (value : Double ): this .type = set(smoothing, value)
163- def setFeaturesCol (value : String ): this .type = set(featuresCol, value)
164- def setPredictionCol (value : String ): this .type = set(predictionCol, value)
165- def setSeed (value : Long ): this .type = set(seed, value)
166+ def setLinkage (value : String ): this .type = set(linkage, value)
167+ def setDivergence (value : String ): this .type = set(divergence, value)
168+ def setSmoothing (value : Double ): this .type = set(smoothing, value)
169+ def setFeaturesCol (value : String ): this .type = set(featuresCol, value)
170+ def setPredictionCol (value : String ): this .type = set(predictionCol, value)
171+ def setSeed (value : Long ): this .type = set(seed, value)
166172
167173 override def fit (dataset : Dataset [_]): AgglomerativeBregmanModel = {
168174 transformSchema(dataset.schema, logging = true )
@@ -250,8 +256,8 @@ class AgglomerativeBregman(override val uid: String)
250256 }
251257
252258 def union (x : Int , y : Int ): Int = {
253- val px = find(x)
254- val py = find(y)
259+ val px = find(x)
260+ val py = find(y)
255261 if (px == py) return px
256262 val (root, child) = if (rank(px) < rank(py)) (py, px) else (px, py)
257263 parent(child) = root
@@ -331,8 +337,8 @@ class AgglomerativeBregman(override val uid: String)
331337 val assignments = Array .tabulate(n)(i => find(i))
332338
333339 // Relabel to 0..k-1
334- val uniqueLabels = assignments.distinct.sorted
335- val labelMap = uniqueLabels.zipWithIndex.toMap
340+ val uniqueLabels = assignments.distinct.sorted
341+ val labelMap = uniqueLabels.zipWithIndex.toMap
336342 val finalAssignments = assignments.map(labelMap)
337343
338344 (finalAssignments, dendrogram.toArray, mergeDistances.toArray)
@@ -380,8 +386,8 @@ class AgglomerativeBregman(override val uid: String)
380386 val centroidB = computeCentroid(clusterB, points, kernel)
381387
382388 // ESS increase = |A||B|/(|A|+|B|) * ||μ_A - μ_B||²
383- val nA = clusterA.size.toDouble
384- val nB = clusterB.size.toDouble
389+ val nA = clusterA.size.toDouble
390+ val nB = clusterB.size.toDouble
385391 val dist = kernel.divergence(centroidA, centroidB)
386392 (nA * nB / (nA + nB)) * dist
387393
@@ -561,24 +567,26 @@ object AgglomerativeBregmanModel extends MLReadable[AgglomerativeBregmanModel] {
561567 val dendrogramData = instance.dendrogram.zipWithIndex.map { case (m, i) =>
562568 (i, m.cluster1, m.cluster2, m.merged, m.distance)
563569 }.toSeq
564- spark.createDataFrame(dendrogramData)
570+ spark
571+ .createDataFrame(dendrogramData)
565572 .toDF(" id" , " cluster1" , " cluster2" , " merged" , " distance" )
566- .write.parquet(s " $path/dendrogram " )
573+ .write
574+ .parquet(s " $path/dendrogram " )
567575
568576 val params : Map [String , Any ] = Map (
569- " k" -> instance.k,
570- " featuresCol" -> instance.getOrDefault(instance.featuresCol),
577+ " k" -> instance.k,
578+ " featuresCol" -> instance.getOrDefault(instance.featuresCol),
571579 " predictionCol" -> instance.getOrDefault(instance.predictionCol),
572- " divergence" -> instance.modelDivergence,
573- " smoothing" -> instance.modelSmoothing,
574- " linkage" -> instance.modelLinkage
580+ " divergence" -> instance.modelDivergence,
581+ " smoothing" -> instance.modelSmoothing,
582+ " linkage" -> instance.modelLinkage
575583 )
576584
577585 val k = instance.k
578586 val dim = instance.clusterCenters.headOption.map(_.size).getOrElse(0 )
579587
580588 implicit val formats : DefaultFormats .type = DefaultFormats
581- val metaObj : Map [String , Any ] = Map (
589+ val metaObj : Map [String , Any ] = Map (
582590 " layoutVersion" -> LayoutVersion ,
583591 " algo" -> " AgglomerativeBregmanModel" ,
584592 " sparkMLVersion" -> org.apache.spark.SPARK_VERSION ,
@@ -609,7 +617,9 @@ object AgglomerativeBregmanModel extends MLReadable[AgglomerativeBregmanModel] {
609617 }
610618 }
611619
612- private class AgglomerativeBregmanModelReader extends MLReader [AgglomerativeBregmanModel ] with Logging {
620+ private class AgglomerativeBregmanModelReader
621+ extends MLReader [AgglomerativeBregmanModel ]
622+ with Logging {
613623 import com .massivedatascience .clusterer .ml .df .persistence .PersistenceLayoutV1 ._
614624 import org .json4s .DefaultFormats
615625 import org .json4s .jackson .JsonMethods
@@ -618,9 +628,9 @@ object AgglomerativeBregmanModel extends MLReadable[AgglomerativeBregmanModel] {
618628 val spark = sparkSession
619629 logInfo(s " Loading AgglomerativeBregmanModel from $path" )
620630
621- val metaStr = readMetadata(path)
631+ val metaStr = readMetadata(path)
622632 implicit val formats : DefaultFormats .type = DefaultFormats
623- val metaJ = JsonMethods .parse(metaStr)
633+ val metaJ = JsonMethods .parse(metaStr)
624634
625635 val layoutVersion = (metaJ \ " layoutVersion" ).extract[Int ]
626636 val k = (metaJ \ " k" ).extract[Int ]
@@ -633,7 +643,8 @@ object AgglomerativeBregmanModel extends MLReadable[AgglomerativeBregmanModel] {
633643
634644 val centers = rows.sortBy(_.getInt(0 )).map(_.getAs[Vector ](" vector" ))
635645
636- val dendrogram = spark.read.parquet(s " $path/dendrogram " )
646+ val dendrogram = spark.read
647+ .parquet(s " $path/dendrogram " )
637648 .orderBy(" id" )
638649 .collect()
639650 .map(r => MergeStep (r.getInt(1 ), r.getInt(2 ), r.getInt(3 ), r.getDouble(4 )))
0 commit comments