Skip to content

Commit 32110fb

Browse files
committed
Fix RobustKMeans divergence case mismatch and SparseKMeans L1 ClassCastException
- ClusteringOps: normalize all divergence lookups with toLowerCase, and add aliases (se, is, logisticloss)
- ClusteringOps: align isValidDivergence and createUpdateStrategy with case-insensitive matching
- SparseKMeans.updateCenters: pattern match on BregmanKernel for the gradient-based update; use MedianUpdateStrategy for L1 (component-wise median, not mean)
1 parent 7cff629 commit 32110fb

File tree

2 files changed

+68
-42
lines changed

2 files changed

+68
-42
lines changed

src/main/scala/com/massivedatascience/clusterer/ml/SparseKMeans.scala

Lines changed: 50 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -369,39 +369,62 @@ class SparseKMeans(override val uid: String)
369369
kernel: ClusteringKernel,
370370
numClusters: Int
371371
): Array[Vector] = {
372-
val bregmanKernel = kernel.asInstanceOf[BregmanKernel]
373-
val bcKernel = assigned.sparkSession.sparkContext.broadcast(bregmanKernel)
374-
375-
val gradUDF = udf { (features: Vector) =>
376-
bcKernel.value.grad(features).toArray
377-
}
372+
val dim = assigned.select($(featuresCol)).head().getAs[Vector](0).size
378373

379-
val withGrad = assigned.withColumn("_grad", gradUDF(col($(featuresCol))))
374+
kernel match {
375+
case bk: BregmanKernel =>
376+
// Use gradient-based update for Bregman divergences
377+
val bcKernel = assigned.sparkSession.sparkContext.broadcast(bk)
380378

381-
val dim = assigned.select($(featuresCol)).head().getAs[Vector](0).size
379+
val gradUDF = udf { (features: Vector) =>
380+
bcKernel.value.grad(features).toArray
381+
}
382382

383-
val aggregated = withGrad
384-
.groupBy("_cluster")
385-
.agg(
386-
count("*").as("count"),
387-
array((0 until dim).map(i => sum(element_at(col("_grad"), i + 1))): _*).as("grad_sum")
388-
)
389-
.collect()
383+
val withGrad = assigned.withColumn("_grad", gradUDF(col($(featuresCol))))
384+
385+
val aggregated = withGrad
386+
.groupBy("_cluster")
387+
.agg(
388+
count("*").as("count"),
389+
array((0 until dim).map(i => sum(element_at(col("_grad"), i + 1))): _*).as("grad_sum")
390+
)
391+
.collect()
392+
393+
val centers = Array.fill(numClusters)(Vectors.zeros(dim))
394+
aggregated.foreach { row =>
395+
val clusterId = row.getInt(0)
396+
if (clusterId >= 0 && clusterId < numClusters) {
397+
val count = row.getLong(1)
398+
val gradSum = row.getSeq[Double](2).toArray
399+
if (count > 0) {
400+
val avgGrad = Vectors.dense(gradSum.map(_ / count))
401+
centers(clusterId) = bcKernel.value.invGrad(avgGrad)
402+
}
403+
}
404+
}
390405

391-
val centers = Array.fill(numClusters)(Vectors.zeros(dim))
392-
aggregated.foreach { row =>
393-
val clusterId = row.getInt(0)
394-
if (clusterId >= 0 && clusterId < numClusters) {
395-
val count = row.getLong(1)
396-
val gradSum = row.getSeq[Double](2).toArray
397-
if (count > 0) {
398-
val avgGrad = Vectors.dense(gradSum.map(_ / count))
399-
centers(clusterId) = bcKernel.value.invGrad(avgGrad)
406+
centers
407+
408+
case _ =>
409+
// Non-Bregman kernels (e.g., L1): use component-wise median via
410+
// MedianUpdateStrategy, which correctly minimizes L1 distance.
411+
val updateStrategy = ClusteringOps.createUpdateStrategy("l1")
412+
val renamed = assigned.withColumnRenamed("_cluster", "cluster")
413+
val medianCenters = updateStrategy.update(
414+
renamed,
415+
$(featuresCol),
416+
weightCol = None,
417+
k = numClusters,
418+
kernel
419+
)
420+
// MedianUpdateStrategy may return fewer centers (drops empty clusters).
421+
// Pad back to numClusters with zeros.
422+
val centers = Array.fill(numClusters)(Vectors.zeros(dim))
423+
medianCenters.zipWithIndex.foreach { case (c, i) =>
424+
if (i < numClusters) centers(i) = Vectors.dense(c)
400425
}
401-
}
426+
centers
402427
}
403-
404-
centers
405428
}
406429

407430
private def computeMovement(

src/main/scala/com/massivedatascience/clusterer/ml/df/ClusteringOps.scala

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,16 @@ private[ml] object ClusteringOps extends Logging {
5050
"spherical"
5151
)
5252

53-
/** All supported divergence names and aliases. */
53+
/** All supported divergence names and aliases (lowercased for case-insensitive matching). */
5454
private val validDivergenceNames: Set[String] = Set(
55-
"squaredEuclidean",
55+
"squaredeuclidean",
56+
"se",
5657
"kl",
57-
"itakuraSaito",
58-
"generalizedI",
58+
"itakurasaito",
59+
"is",
60+
"generalizedi",
5961
"logistic",
62+
"logisticloss",
6063
"l1",
6164
"manhattan",
6265
"spherical",
@@ -75,15 +78,15 @@ private[ml] object ClusteringOps extends Logging {
7578
* if divergence name is unknown
7679
*/
7780
def createKernel(divergence: String, smoothing: Double = 1e-10): ClusteringKernel = {
78-
divergence match {
79-
case "squaredEuclidean" => new SquaredEuclideanKernel()
80-
case "kl" => new KLDivergenceKernel(smoothing)
81-
case "itakuraSaito" => new ItakuraSaitoKernel(smoothing)
82-
case "generalizedI" => new GeneralizedIDivergenceKernel(smoothing)
83-
case "logistic" => new LogisticLossKernel(smoothing)
84-
case "l1" | "manhattan" => new L1Kernel()
85-
case "spherical" | "cosine" => new SphericalKernel()
86-
case _ =>
81+
divergence.toLowerCase match {
82+
case "squaredeuclidean" | "se" => new SquaredEuclideanKernel()
83+
case "kl" => new KLDivergenceKernel(smoothing)
84+
case "itakurasaito" | "is" => new ItakuraSaitoKernel(smoothing)
85+
case "generalizedi" => new GeneralizedIDivergenceKernel(smoothing)
86+
case "logistic" | "logisticloss" => new LogisticLossKernel(smoothing)
87+
case "l1" | "manhattan" => new L1Kernel()
88+
case "spherical" | "cosine" => new SphericalKernel()
89+
case _ =>
8790
throw new IllegalArgumentException(
8891
s"Unknown divergence: '$divergence'. " +
8992
s"Valid options: ${supportedDivergences.mkString(", ")}"
@@ -122,7 +125,7 @@ private[ml] object ClusteringOps extends Logging {
122125
* configured UpdateStrategy
123126
*/
124127
def createUpdateStrategy(divergence: String): UpdateStrategy = {
125-
divergence match {
128+
divergence.toLowerCase match {
126129
case "l1" | "manhattan" => new MedianUpdateStrategy()
127130
case _ => new GradMeanUDAFUpdate()
128131
}
@@ -181,5 +184,5 @@ private[ml] object ClusteringOps extends Logging {
181184
/** Check if a divergence name is valid.
182185
*/
183186
def isValidDivergence(divergence: String): Boolean =
184-
validDivergenceNames.contains(divergence)
187+
validDivergenceNames.contains(divergence.toLowerCase)
185188
}

0 commit comments

Comments (0)