Skip to content

Commit 15c38ed

Browse files
derrickburns authored
and claude committed
style: Apply scalafmt formatting
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 2cbd3f6 commit 15c38ed

21 files changed

+392
-426
lines changed

src/main/scala/com/massivedatascience/clusterer/ml/AgglomerativeBregman.scala

Lines changed: 49 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,15 @@ import org.apache.spark.ml.{ Estimator, Model }
2323
import org.apache.spark.ml.linalg.{ Vector, Vectors }
2424
import org.apache.spark.ml.param._
2525
import org.apache.spark.ml.param.shared._
26-
import org.apache.spark.ml.util.{ DefaultParamsReadable, DefaultParamsWritable, Identifiable, MLReadable, MLReader, MLWritable, MLWriter }
26+
import org.apache.spark.ml.util.{
27+
DefaultParamsReadable,
28+
DefaultParamsWritable,
29+
Identifiable,
30+
MLReadable,
31+
MLReader,
32+
MLWritable,
33+
MLWriter
34+
}
2735
import org.apache.spark.sql.{ DataFrame, Dataset }
2836
import org.apache.spark.sql.functions._
2937
import org.apache.spark.sql.types.StructType
@@ -47,8 +55,8 @@ trait AgglomerativeBregmanParams
4755
)
4856
def getNumClusters: Int = $(numClusters)
4957

50-
/** Distance threshold for merging (alternative to numClusters).
51-
* If set > 0, clustering stops when min merge distance exceeds threshold.
58+
/** Distance threshold for merging (alternative to numClusters). If set > 0, clustering stops when
59+
* min merge distance exceeds threshold.
5260
*/
5361
final val distanceThreshold: DoubleParam = new DoubleParam(
5462
this,
@@ -103,16 +111,14 @@ trait AgglomerativeBregmanParams
103111

104112
/** Agglomerative (bottom-up) hierarchical clustering with Bregman divergences.
105113
*
106-
* Starts with each point as its own cluster and iteratively merges the
107-
* closest pair of clusters until the desired number is reached.
114+
* Starts with each point as its own cluster and iteratively merges the closest pair of clusters
115+
* until the desired number is reached.
108116
*
109117
* ==Algorithm==
110118
*
111-
* 1. Initialize: Each point is a singleton cluster
112-
* 2. Compute pairwise distances/divergences between all clusters
113-
* 3. Find and merge the closest pair
114-
* 4. Update distances to the merged cluster
115-
* 5. Repeat until numClusters reached or distanceThreshold exceeded
119+
* 1. Initialize: Each point is a singleton cluster. 2. Compute pairwise distances/divergences
120+
* between all clusters. 3. Find and merge the closest pair. 4. Update distances to the merged
121+
* cluster. 5. Repeat until numClusters reached or distanceThreshold exceeded.
116122
*
117123
* ==Linkage Criteria==
118124
*
@@ -139,9 +145,9 @@ trait AgglomerativeBregmanParams
139145
*
140146
* ==Scalability Note==
141147
*
142-
* Standard agglomerative clustering has O(n³) or O(n²log n) complexity.
143-
* This implementation is suitable for datasets up to ~10,000 points.
144-
* For larger datasets, consider [[BisectingKMeans]] (top-down approach).
148+
* Standard agglomerative clustering has O(n³) or O(n²log n) complexity. This implementation is
149+
* suitable for datasets up to ~10,000 points. For larger datasets, consider [[BisectingKMeans]]
150+
* (top-down approach).
145151
*
146152
* @see
147153
* [[BisectingKMeans]] for top-down hierarchical clustering
@@ -155,14 +161,14 @@ class AgglomerativeBregman(override val uid: String)
155161
def this() = this(Identifiable.randomUID("agglomerative"))
156162

157163
// Parameter setters
158-
def setNumClusters(value: Int): this.type = set(numClusters, value)
164+
def setNumClusters(value: Int): this.type = set(numClusters, value)
159165
def setDistanceThreshold(value: Double): this.type = set(distanceThreshold, value)
160-
def setLinkage(value: String): this.type = set(linkage, value)
161-
def setDivergence(value: String): this.type = set(divergence, value)
162-
def setSmoothing(value: Double): this.type = set(smoothing, value)
163-
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
164-
def setPredictionCol(value: String): this.type = set(predictionCol, value)
165-
def setSeed(value: Long): this.type = set(seed, value)
166+
def setLinkage(value: String): this.type = set(linkage, value)
167+
def setDivergence(value: String): this.type = set(divergence, value)
168+
def setSmoothing(value: Double): this.type = set(smoothing, value)
169+
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
170+
def setPredictionCol(value: String): this.type = set(predictionCol, value)
171+
def setSeed(value: Long): this.type = set(seed, value)
166172

167173
override def fit(dataset: Dataset[_]): AgglomerativeBregmanModel = {
168174
transformSchema(dataset.schema, logging = true)
@@ -250,8 +256,8 @@ class AgglomerativeBregman(override val uid: String)
250256
}
251257

252258
def union(x: Int, y: Int): Int = {
253-
val px = find(x)
254-
val py = find(y)
259+
val px = find(x)
260+
val py = find(y)
255261
if (px == py) return px
256262
val (root, child) = if (rank(px) < rank(py)) (py, px) else (px, py)
257263
parent(child) = root
@@ -331,8 +337,8 @@ class AgglomerativeBregman(override val uid: String)
331337
val assignments = Array.tabulate(n)(i => find(i))
332338

333339
// Relabel to 0..k-1
334-
val uniqueLabels = assignments.distinct.sorted
335-
val labelMap = uniqueLabels.zipWithIndex.toMap
340+
val uniqueLabels = assignments.distinct.sorted
341+
val labelMap = uniqueLabels.zipWithIndex.toMap
336342
val finalAssignments = assignments.map(labelMap)
337343

338344
(finalAssignments, dendrogram.toArray, mergeDistances.toArray)
@@ -380,8 +386,8 @@ class AgglomerativeBregman(override val uid: String)
380386
val centroidB = computeCentroid(clusterB, points, kernel)
381387

382388
// ESS increase = |A||B|/(|A|+|B|) * ||μ_A - μ_B||²
383-
val nA = clusterA.size.toDouble
384-
val nB = clusterB.size.toDouble
389+
val nA = clusterA.size.toDouble
390+
val nB = clusterB.size.toDouble
385391
val dist = kernel.divergence(centroidA, centroidB)
386392
(nA * nB / (nA + nB)) * dist
387393

@@ -561,24 +567,26 @@ object AgglomerativeBregmanModel extends MLReadable[AgglomerativeBregmanModel] {
561567
val dendrogramData = instance.dendrogram.zipWithIndex.map { case (m, i) =>
562568
(i, m.cluster1, m.cluster2, m.merged, m.distance)
563569
}.toSeq
564-
spark.createDataFrame(dendrogramData)
570+
spark
571+
.createDataFrame(dendrogramData)
565572
.toDF("id", "cluster1", "cluster2", "merged", "distance")
566-
.write.parquet(s"$path/dendrogram")
573+
.write
574+
.parquet(s"$path/dendrogram")
567575

568576
val params: Map[String, Any] = Map(
569-
"k" -> instance.k,
570-
"featuresCol" -> instance.getOrDefault(instance.featuresCol),
577+
"k" -> instance.k,
578+
"featuresCol" -> instance.getOrDefault(instance.featuresCol),
571579
"predictionCol" -> instance.getOrDefault(instance.predictionCol),
572-
"divergence" -> instance.modelDivergence,
573-
"smoothing" -> instance.modelSmoothing,
574-
"linkage" -> instance.modelLinkage
580+
"divergence" -> instance.modelDivergence,
581+
"smoothing" -> instance.modelSmoothing,
582+
"linkage" -> instance.modelLinkage
575583
)
576584

577585
val k = instance.k
578586
val dim = instance.clusterCenters.headOption.map(_.size).getOrElse(0)
579587

580588
implicit val formats: DefaultFormats.type = DefaultFormats
581-
val metaObj: Map[String, Any] = Map(
589+
val metaObj: Map[String, Any] = Map(
582590
"layoutVersion" -> LayoutVersion,
583591
"algo" -> "AgglomerativeBregmanModel",
584592
"sparkMLVersion" -> org.apache.spark.SPARK_VERSION,
@@ -609,7 +617,9 @@ object AgglomerativeBregmanModel extends MLReadable[AgglomerativeBregmanModel] {
609617
}
610618
}
611619

612-
private class AgglomerativeBregmanModelReader extends MLReader[AgglomerativeBregmanModel] with Logging {
620+
private class AgglomerativeBregmanModelReader
621+
extends MLReader[AgglomerativeBregmanModel]
622+
with Logging {
613623
import com.massivedatascience.clusterer.ml.df.persistence.PersistenceLayoutV1._
614624
import org.json4s.DefaultFormats
615625
import org.json4s.jackson.JsonMethods
@@ -618,9 +628,9 @@ object AgglomerativeBregmanModel extends MLReadable[AgglomerativeBregmanModel] {
618628
val spark = sparkSession
619629
logInfo(s"Loading AgglomerativeBregmanModel from $path")
620630

621-
val metaStr = readMetadata(path)
631+
val metaStr = readMetadata(path)
622632
implicit val formats: DefaultFormats.type = DefaultFormats
623-
val metaJ = JsonMethods.parse(metaStr)
633+
val metaJ = JsonMethods.parse(metaStr)
624634

625635
val layoutVersion = (metaJ \ "layoutVersion").extract[Int]
626636
val k = (metaJ \ "k").extract[Int]
@@ -633,7 +643,8 @@ object AgglomerativeBregmanModel extends MLReadable[AgglomerativeBregmanModel] {
633643

634644
val centers = rows.sortBy(_.getInt(0)).map(_.getAs[Vector]("vector"))
635645

636-
val dendrogram = spark.read.parquet(s"$path/dendrogram")
646+
val dendrogram = spark.read
647+
.parquet(s"$path/dendrogram")
637648
.orderBy("id")
638649
.collect()
639650
.map(r => MergeStep(r.getInt(1), r.getInt(2), r.getInt(3), r.getDouble(4)))

src/main/scala/com/massivedatascience/clusterer/ml/BregmanMixtureModel.scala

Lines changed: 38 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,15 @@ import org.apache.spark.internal.Logging
2222
import org.apache.spark.ml.{ Estimator, Model }
2323
import org.apache.spark.ml.linalg.{ Vector, Vectors }
2424
import org.apache.spark.ml.param._
25-
import org.apache.spark.ml.util.{ DefaultParamsReadable, DefaultParamsWritable, Identifiable, MLReadable, MLReader, MLWritable, MLWriter }
25+
import org.apache.spark.ml.util.{
26+
DefaultParamsReadable,
27+
DefaultParamsWritable,
28+
Identifiable,
29+
MLReadable,
30+
MLReader,
31+
MLWritable,
32+
MLWriter
33+
}
2634
import org.apache.spark.sql.{ DataFrame, Dataset }
2735
import org.apache.spark.sql.functions._
2836
import org.apache.spark.sql.types.StructType
@@ -35,8 +43,7 @@ trait BregmanMixtureParams extends GeneralizedKMeansParams {
3543
*/
3644
def getNumComponents: Int = $(k)
3745

38-
/** Regularization parameter (Dirichlet prior for component weights).
39-
* 0 = no regularization (MLE)
46+
/** Regularization parameter (Dirichlet prior for component weights). 0 = no regularization (MLE)
4047
* > 0 = MAP estimation with symmetric Dirichlet prior
4148
*/
4249
final val regularization: DoubleParam = new DoubleParam(
@@ -63,8 +70,8 @@ trait BregmanMixtureParams extends GeneralizedKMeansParams {
6370

6471
/** Bregman Mixture Model - probabilistic clustering via EM algorithm.
6572
*
66-
* Fits a mixture model where each component is parameterized by an
67-
* exponential family distribution corresponding to the chosen Bregman divergence:
73+
* Fits a mixture model where each component is parameterized by an exponential family distribution
74+
* corresponding to the chosen Bregman divergence:
6875
*
6976
* - Squared Euclidean → Gaussian mixture
7077
* - KL divergence → Multinomial mixture
@@ -127,18 +134,18 @@ class BregmanMixture(override val uid: String)
127134
def this() = this(Identifiable.randomUID("bregmanmixture"))
128135

129136
// Parameter setters
130-
def setK(value: Int): this.type = set(k, value)
131-
def setNumComponents(value: Int): this.type = set(k, value)
132-
def setDivergence(value: String): this.type = set(divergence, value)
133-
def setSmoothing(value: Double): this.type = set(smoothing, value)
137+
def setK(value: Int): this.type = set(k, value)
138+
def setNumComponents(value: Int): this.type = set(k, value)
139+
def setDivergence(value: String): this.type = set(divergence, value)
140+
def setSmoothing(value: Double): this.type = set(smoothing, value)
134141
def setRegularization(value: Double): this.type = set(regularization, value)
135-
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
136-
def setPredictionCol(value: String): this.type = set(predictionCol, value)
142+
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
143+
def setPredictionCol(value: String): this.type = set(predictionCol, value)
137144
def setProbabilityCol(value: String): this.type = set(probabilityCol, value)
138-
def setWeightCol(value: String): this.type = set(weightCol, value)
139-
def setMaxIter(value: Int): this.type = set(maxIter, value)
140-
def setTol(value: Double): this.type = set(tol, value)
141-
def setSeed(value: Long): this.type = set(seed, value)
145+
def setWeightCol(value: String): this.type = set(weightCol, value)
146+
def setMaxIter(value: Int): this.type = set(maxIter, value)
147+
def setTol(value: Double): this.type = set(tol, value)
148+
def setSeed(value: Long): this.type = set(seed, value)
142149

143150
override def fit(dataset: Dataset[_]): BregmanMixtureModelInstance = {
144151
transformSchema(dataset.schema, logging = true)
@@ -176,7 +183,7 @@ class BregmanMixture(override val uid: String)
176183

177184
// Run EM
178185
val emIterator = new BregmanEMIterator()
179-
val result = emIterator.runEM(
186+
val result = emIterator.runEM(
180187
df,
181188
$(featuresCol),
182189
if (hasWeightCol) Some($(weightCol)) else None,
@@ -187,12 +194,13 @@ class BregmanMixture(override val uid: String)
187194
val elapsed = System.currentTimeMillis() - startTime
188195
logInfo(
189196
s"Bregman Mixture Model completed: ${result.iterations} iterations, " +
190-
s"converged=${result.converged}, finalLogLik=${result.logLikelihoodHistory.lastOption.getOrElse(Double.NaN)}"
197+
s"converged=${result.converged}, finalLogLik=${result.logLikelihoodHistory.lastOption
198+
.getOrElse(Double.NaN)}"
191199
)
192200

193201
// Create model
194202
val centersAsVectors = result.centers.map(Vectors.dense)
195-
val model = new BregmanMixtureModelInstance(
203+
val model = new BregmanMixtureModelInstance(
196204
uid,
197205
centersAsVectors,
198206
result.weights,
@@ -315,7 +323,8 @@ class BregmanMixtureModelInstance(
315323
(prediction, Vectors.dense(probs))
316324
}
317325

318-
val result = df.withColumn("_bmm_result", predictUDF(col($(featuresCol))))
326+
val result = df
327+
.withColumn("_bmm_result", predictUDF(col($(featuresCol))))
319328
.withColumn($(predictionCol), col("_bmm_result._1"))
320329
.withColumn($(probabilityCol), col("_bmm_result._2"))
321330
.drop("_bmm_result")
@@ -343,17 +352,14 @@ class BregmanMixtureModelInstance(
343352
maxLogProb + math.log(expSum)
344353
}
345354

346-
df.select(logLikUDF(col($(featuresCol))).as("loglik"))
347-
.agg(sum("loglik"))
348-
.head()
349-
.getDouble(0)
355+
df.select(logLikUDF(col($(featuresCol))).as("loglik")).agg(sum("loglik")).head().getDouble(0)
350356
}
351357

352358
/** Compute BIC (Bayesian Information Criterion). Lower is better. */
353359
def bic(dataset: Dataset[_]): Double = {
354-
val n = dataset.count()
355-
val logLik = logLikelihood(dataset)
356-
val dim = means.headOption.map(_.size).getOrElse(0)
360+
val n = dataset.count()
361+
val logLik = logLikelihood(dataset)
362+
val dim = means.headOption.map(_.size).getOrElse(0)
357363
val numParams = numComponents * dim + numComponents - 1 // means + weights
358364
-2 * logLik + numParams * math.log(n.toDouble)
359365
}
@@ -415,7 +421,7 @@ object BregmanMixtureModelInstance extends MLReadable[BregmanMixtureModelInstanc
415421
val dim = instance.means.headOption.map(_.size).getOrElse(0)
416422

417423
implicit val formats: DefaultFormats.type = DefaultFormats
418-
val metaObj: Map[String, Any] = Map(
424+
val metaObj: Map[String, Any] = Map(
419425
"layoutVersion" -> LayoutVersion,
420426
"algo" -> "BregmanMixtureModelInstance",
421427
"sparkMLVersion" -> org.apache.spark.SPARK_VERSION,
@@ -441,7 +447,9 @@ object BregmanMixtureModelInstance extends MLReadable[BregmanMixtureModelInstanc
441447
}
442448
}
443449

444-
private class BregmanMixtureModelReader extends MLReader[BregmanMixtureModelInstance] with Logging {
450+
private class BregmanMixtureModelReader
451+
extends MLReader[BregmanMixtureModelInstance]
452+
with Logging {
445453
import com.massivedatascience.clusterer.ml.df.persistence.PersistenceLayoutV1._
446454
import org.json4s.DefaultFormats
447455
import org.json4s.jackson.JsonMethods
@@ -450,9 +458,9 @@ object BregmanMixtureModelInstance extends MLReadable[BregmanMixtureModelInstanc
450458
val spark = sparkSession
451459
logInfo(s"Loading BregmanMixtureModelInstance from $path")
452460

453-
val metaStr = readMetadata(path)
461+
val metaStr = readMetadata(path)
454462
implicit val formats: DefaultFormats.type = DefaultFormats
455-
val metaJ = JsonMethods.parse(metaStr)
463+
val metaJ = JsonMethods.parse(metaStr)
456464

457465
val layoutVersion = (metaJ \ "layoutVersion").extract[Int]
458466
val k = (metaJ \ "k").extract[Int]

src/main/scala/com/massivedatascience/clusterer/ml/ClusteringModel.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ import org.apache.spark.ml.linalg.Vector
66

77
/** Shared training summary handling for clustering models.
88
*
9-
* Models mix this in to get consistent summary/hasSummary behavior while
10-
* keeping the summary payload optionally available for persisted models.
9+
* Models mix this in to get consistent summary/hasSummary behavior while keeping the summary
10+
* payload optionally available for persisted models.
1111
*/
1212
trait HasTrainingSummary extends Params { self: Logging =>
1313

0 commit comments

Comments (0)