Commit fd643d5

derrickburns and claude committed
refactor: Introduce ClusteringKernel type hierarchy, shared orchestration, and center initialization
L1Kernel incorrectly extended BregmanKernel despite having no valid gradient or inverse gradient, which could silently produce wrong centers when paired with GradMeanUDAFUpdate. This refactor introduces a proper type hierarchy and centralizes duplicated factory/initialization code across estimators.

Type hierarchy:
- New ClusteringKernel root trait (divergence, validate, name)
- BregmanKernel extends ClusteringKernel (adds grad/invGrad)
- L1Kernel reclassified to extend ClusteringKernel directly
- SparseClusteringKernel trait for sparse-optimized non-Bregman kernels
- GradMeanUDAFUpdate now has runtime require guard for BregmanKernel
- All consumer signatures widened from BregmanKernel to ClusteringKernel

Shared orchestration (ClusteringOps):
- Single source of truth for createKernel, createAssignmentStrategy, createUpdateStrategy, createEmptyClusterHandler, validateDomain
- 10 estimators updated to delegate to ClusteringOps

Center initialization (CenterInitializer):
- Extracted k-means++ and random initialization into shared utility
- GeneralizedKMeans and BalancedKMeans now share initialization code

44 files changed, 235 insertions, 569 deletions. All 249 non-Spark tests pass.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 4bb21d5 commit fd643d5

44 files changed: +698 −569 lines

CHANGELOG.md
Lines changed: 8 additions & 0 deletions

@@ -7,6 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+### Changed
+- **Kernel type hierarchy refactor:** Introduced `ClusteringKernel` as the root trait for all clustering kernels. `BregmanKernel` now extends `ClusteringKernel` and adds `grad`/`invGrad` declarations. `L1Kernel` reclassified to extend `ClusteringKernel` directly (not `BregmanKernel`) since L1 has no valid gradient/inverse gradient. All consumer signatures widened from `BregmanKernel` to `ClusteringKernel`.
+- **GradMeanUDAFUpdate runtime guard:** Now throws `IllegalArgumentException` with actionable message when passed a non-Bregman kernel (e.g., L1Kernel), preventing silent wrong-answer bugs.
+- **Shared orchestration:** Created `ClusteringOps` object centralizing `createKernel`, `createAssignmentStrategy`, `createUpdateStrategy`, `createEmptyClusterHandler`, and `validateDomain` factory methods. All estimators (GeneralizedKMeans, BisectingKMeans, BalancedKMeans, DPMeans, MiniBatchKMeans, SoftKMeans, StreamingKMeans, CoresetKMeans, ConstrainedKMeans, RobustKMeans) now delegate to `ClusteringOps`.
+- **Shared initialization:** Created `CenterInitializer` utility extracting k-means++ and random initialization from GeneralizedKMeans. GeneralizedKMeans and BalancedKMeans now share the same initialization code.
+- **Sparse kernel hierarchy:** Added `SparseClusteringKernel` trait; `SparseBregmanKernel` extends both `BregmanKernel` and `SparseClusteringKernel`. `SparseL1Kernel` correctly extends `L1Kernel with SparseClusteringKernel`.
+- Backward compatibility maintained via type aliases in package objects.
+
 ### Added
 - Comprehensive CI validation DAG with cross-version testing
 - SECURITY.md with vulnerability reporting guidelines
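
The hierarchy and guard described above can be sketched in plain Scala. This is a minimal, dependency-free sketch: the trait and method names follow the changelog, but the bodies, the `SquaredEuclideanKernel` internals, and the `GradMeanGuard` helper are illustrative stand-ins, not the library's actual code.

```scala
// Root trait for all clustering kernels (per the changelog: divergence, validate, name).
trait ClusteringKernel {
  def name: String
  def divergence(x: Array[Double], y: Array[Double]): Double
  def validate(x: Array[Double]): Unit = ()
}

// Bregman kernels additionally declare a gradient and its inverse.
trait BregmanKernel extends ClusteringKernel {
  def grad(x: Array[Double]): Array[Double]
  def invGrad(g: Array[Double]): Array[Double]
}

final class SquaredEuclideanKernel extends BregmanKernel {
  val name = "squaredEuclidean"
  def divergence(x: Array[Double], y: Array[Double]): Double =
    x.zip(y).map { case (a, b) => (a - b) * (a - b) }.sum
  // For F(x) = ||x||^2 the gradient is linear; identity suffices for the mean update sketch.
  def grad(x: Array[Double]): Array[Double] = x
  def invGrad(g: Array[Double]): Array[Double] = g
}

// L1 is deliberately NOT a BregmanKernel: it has no valid grad/invGrad.
final class L1Kernel extends ClusteringKernel {
  val name = "l1"
  def divergence(x: Array[Double], y: Array[Double]): Double =
    x.zip(y).map { case (a, b) => math.abs(a - b) }.sum
}

// Sketch of the runtime guard GradMeanUDAFUpdate now applies: widen signatures to
// ClusteringKernel, but fail loudly when a gradient-based update gets a non-Bregman kernel.
object GradMeanGuard {
  def requireBregman(k: ClusteringKernel): BregmanKernel = k match {
    case b: BregmanKernel => b
    case other =>
      throw new IllegalArgumentException(
        s"Gradient-mean update requires a BregmanKernel; got ${other.name}. " +
          "Use a median-based update for L1/manhattan.")
  }
}
```

With this split, passing an `L1Kernel` to the gradient-mean path is a type-system or runtime error rather than a silently wrong centroid.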

ROADMAP.md
Lines changed: 2 additions & 0 deletions

@@ -143,6 +143,8 @@ These frameworks unblock multiple roadmap items; prefer delivering them before d
 | 2024 | Keep L1 listed alongside Bregman divergences | Practical utility outweighs theoretical purity |
 | 2025-12-15 | Prioritize robust/sparse/multi-view work next | Highest user demand and unlocks downstream variants |
 | 2025-12-15 | Maintain kernels in a single module (`BregmanKernel.scala`) | Consistency and discoverability |
+| 2026-02-11 | Introduced `ClusteringKernel` as root trait; reclassified L1Kernel | L1 is not a true Bregman divergence; type system now prevents misuse with GradMeanUDAFUpdate |
+| 2026-02-11 | Created `ClusteringOps` and `CenterInitializer` shared utilities | Centralized factory methods and k-means++ initialization; eliminates copy-paste across estimators |
 | 2025-12-15 | Use phased delivery for accelerations and new iterators | Keep CI stable while iterating |
 | 2025-12-16 | Created `KernelFactory` for unified kernel creation | Single API for dense/sparse kernels, reduces duplication |
 | 2025-12-16 | Moved assignment strategies to `impl/` subpackage | Better organization, backward-compatible via type aliases |

src/main/scala/com/massivedatascience/clusterer/ml/AgglomerativeBregman.scala
Lines changed: 3 additions & 3 deletions

@@ -17,7 +17,7 @@
 
 package com.massivedatascience.clusterer.ml
 
-import com.massivedatascience.clusterer.ml.df.BregmanKernel
+import com.massivedatascience.clusterer.ml.df.{ BregmanKernel, ClusteringKernel }
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.{ Estimator, Model }
 import org.apache.spark.ml.linalg.{ Vector, Vectors }
@@ -434,7 +434,7 @@ class AgglomerativeBregman(override val uid: String)
   }
 
   private def createKernel(): BregmanKernel = {
-    BregmanKernel.create($(divergence), $(smoothing))
+    BregmanKernel.create($(divergence), $(smoothing)).asInstanceOf[BregmanKernel]
   }
 
   override def copy(extra: ParamMap): AgglomerativeBregman = defaultCopy(extra)
@@ -481,7 +481,7 @@ class AgglomerativeBregmanModel(
   private[ml] var modelDivergence: String = "squaredEuclidean"
   private[ml] var modelSmoothing: Double = 1e-10
   private[ml] var modelLinkage: String = "average"
-  private[ml] var kernel: BregmanKernel = _
+  private[ml] var kernel: ClusteringKernel = _
 
   /** Cluster centers as vectors for downstream consumers/tests. */
   def clusterCentersAsVectors: Array[Vector] = clusterCenters

src/main/scala/com/massivedatascience/clusterer/ml/BalancedKMeans.scala
Lines changed: 15 additions & 66 deletions

@@ -18,7 +18,7 @@
 package com.massivedatascience.clusterer.ml
 
 import com.massivedatascience.clusterer.ml.df._
-import com.massivedatascience.clusterer.ml.df.kernels._
+import com.massivedatascience.clusterer.ml.df.kernels.ClusteringKernel
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.Estimator
 import org.apache.spark.ml.linalg.{ Vector, Vectors }
@@ -191,10 +191,21 @@ class BalancedKMeans(override val uid: String)
     )
 
     // Create kernel
-    val kernel = createKernel($(divergence), $(smoothing))
+    val kernel = ClusteringOps.createKernel($(divergence), $(smoothing))
 
     // Initialize centers
-    val initialCenters = initializeCenters(df, $(featuresCol), kernel)
+    val initialCenters = CenterInitializer
+      .initialize(
+        df,
+        $(featuresCol),
+        weightCol = None,
+        $(k),
+        $(initMode),
+        $(initSteps),
+        $(seed),
+        kernel
+      )
+      .map(arr => Vectors.dense(arr))
 
     logInfo(s"Initialized ${initialCenters.length} centers using ${$(initMode)}")
 
@@ -239,7 +250,7 @@ class BalancedKMeans(override val uid: String)
   private def runBalancedLloyds(
     df: DataFrame,
     initialCenters: Array[Vector],
-    kernel: BregmanKernel,
+    kernel: ClusteringKernel,
     minSize: Int,
     maxSize: Int
   ): LloydResult = {
@@ -502,68 +513,6 @@ class BalancedKMeans(override val uid: String)
     withDistances.withColumn("_assignment", assignUdf(col("_row_id")))
   }
 
-  private def createKernel(divergenceName: String, smoothing: Double): BregmanKernel = {
-    divergenceName.toLowerCase match {
-      case "squaredeuclidean" | "se" | "euclidean" => new SquaredEuclideanKernel()
-      case "kl" | "kullbackleibler" => new KLDivergenceKernel(smoothing)
-      case "itakurasaito" | "is" => new ItakuraSaitoKernel(smoothing)
-      case "l1" | "manhattan" => new L1Kernel()
-      case "spherical" | "cosine" => new SphericalKernel()
-      case "generalizedi" | "gi" => new GeneralizedIDivergenceKernel(smoothing)
-      case "logistic" => new LogisticLossKernel()
-      case other => throw new IllegalArgumentException(s"Unknown divergence: $other")
-    }
-  }
-
-  private def initializeCenters(
-    df: DataFrame,
-    featuresCol: String,
-    kernel: BregmanKernel
-  ): Array[Vector] = {
-    val rng = new Random($(seed))
-
-    $(initMode).toLowerCase match {
-      case "random" =>
-        val fraction = math.min(1.0, $(k).toDouble / df.count() * 10)
-        df.select(featuresCol)
-          .sample(withReplacement = false, fraction, $(seed))
-          .limit($(k))
-          .collect()
-          .map(_.getAs[Vector](0))
-
-      case "k-means||" | "kmeansparallel" =>
-        // Simplified k-means|| initialization
-        val allPoints = df.select(featuresCol).collect().map(_.getAs[Vector](0))
-        if (allPoints.length <= $(k)) {
-          allPoints
-        } else {
-          val centers = scala.collection.mutable.ArrayBuffer.empty[Vector]
-          centers += allPoints(rng.nextInt(allPoints.length))
-
-          while (centers.length < $(k)) {
-            val currentCenters = centers.toArray
-            val distances: Array[Double] = allPoints.map { point =>
-              val dists: Array[Double] = currentCenters.map(c => kernel.divergence(point, c))
-              dists.min
-            }
-            val totalDist: Double = distances.sum
-            if (totalDist > 0) {
-              val probabilities: Array[Double] = distances.map(d => d / totalDist)
-              val cumProbs: Array[Double] = probabilities.scanLeft(0.0)((a, b) => a + b).tail
-              val r = rng.nextDouble()
-              val idx = cumProbs.indexWhere(_ >= r)
-              centers += allPoints(if (idx >= 0) idx else allPoints.length - 1)
-            } else {
-              centers += allPoints(rng.nextInt(allPoints.length))
-            }
-          }
-          centers.toArray
-        }
-
-      case other =>
-        throw new IllegalArgumentException(s"Unknown initialization mode: $other")
-    }
-  }
 
   override def transformSchema(schema: StructType): StructType = {
     require(
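
The deleted `initializeCenters` above is the distance-weighted ("k-means++"-style) seeding that `CenterInitializer` now centralizes. Its core logic can be shown in pure Scala, stripped of the DataFrame plumbing; `KMeansPlusPlusSketch` and its signature are hypothetical names for illustration, and `divergence` stands in for `kernel.divergence`.

```scala
import scala.util.Random

object KMeansPlusPlusSketch {
  /** Choose k seeds: first uniformly at random, each subsequent one with
   *  probability proportional to its distance from the nearest chosen seed.
   */
  def choose(points: Array[Array[Double]], k: Int, seed: Long)(
      divergence: (Array[Double], Array[Double]) => Double): Array[Array[Double]] = {
    val rng = new Random(seed)
    if (points.length <= k) return points
    val centers = scala.collection.mutable.ArrayBuffer(points(rng.nextInt(points.length)))
    while (centers.length < k) {
      // Distance from each point to its nearest already-chosen center.
      val d = points.map(p => centers.map(c => divergence(p, c)).min)
      val total = d.sum
      if (total > 0) {
        // Sample an index with probability d(i) / total via the cumulative sums.
        val r = rng.nextDouble() * total
        val cum = d.scanLeft(0.0)(_ + _).tail
        val idx = cum.indexWhere(_ >= r)
        centers += points(if (idx >= 0) idx else points.length - 1)
      } else {
        // All points coincide with an existing center; fall back to uniform.
        centers += points(rng.nextInt(points.length))
      }
    }
    centers.toArray
  }
}
```

Because the sampling only needs a divergence function, the same utility serves any `ClusteringKernel`, which is what lets GeneralizedKMeans and BalancedKMeans share one implementation.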

src/main/scala/com/massivedatascience/clusterer/ml/BisectingKMeans.scala
Lines changed: 9 additions & 44 deletions

@@ -144,15 +144,15 @@ class BisectingKMeans(override val uid: String)
     )
 
     // Validate input data domain requirements for the selected divergence
-    com.massivedatascience.util.DivergenceDomainValidator.validateDataFrame(
+    ClusteringOps.validateDomain(
       df,
       $(featuresCol),
       $(divergence),
-      maxSamples = Some(1000)
+      maxSamples = 1000
     )
 
     // Create kernel
-    val kernel = createKernel($(divergence), $(smoothing))
+    val kernel = ClusteringOps.createKernel($(divergence), $(smoothing))
 
     // Bisecting algorithm with timing
     val startTime = System.currentTimeMillis()
@@ -203,7 +203,7 @@
     df: DataFrame,
     featuresCol: String,
     weightCol: Option[String],
-    kernel: BregmanKernel
+    kernel: ClusteringKernel
   ): (Array[Array[Double]], Int) = {
 
     val targetK = $(k)
@@ -324,7 +324,7 @@
     clusterData: DataFrame,
     featuresCol: String,
     weightCol: Option[String],
-    kernel: BregmanKernel
+    kernel: ClusteringKernel
   ): (Array[Double], Array[Double]) = {
 
     // Drop the "cluster" column if it exists to avoid conflicts with assignment strategy
@@ -348,8 +348,8 @@
     )
 
     // Create strategies for k=2 clustering
-    val assigner = createAssignmentStrategy("auto")
-    val updater = createUpdateStrategy($(divergence))
+    val assigner = ClusteringOps.createAssignmentStrategy("auto")
+    val updater = ClusteringOps.createUpdateStrategy($(divergence))
 
     // Run Lloyd's for a few iterations
     var iteration = 0
@@ -389,10 +389,10 @@
     data: DataFrame,
     featuresCol: String,
     weightCol: Option[String],
-    kernel: BregmanKernel
+    kernel: ClusteringKernel
   ): Array[Double] = {
 
-    val updater = createUpdateStrategy($(divergence))
+    val updater = ClusteringOps.createUpdateStrategy($(divergence))
     val centers = updater.update(
       data.withColumn("cluster", lit(0)),
       featuresCol,
@@ -404,41 +404,6 @@
     if (centers.nonEmpty) centers(0) else Array.empty[Double]
   }
 
-  /** Create Bregman kernel based on divergence name.
-    */
-  private def createKernel(divName: String, smooth: Double): BregmanKernel = {
-    divName match {
-      case "squaredEuclidean" => new SquaredEuclideanKernel()
-      case "kl" => new KLDivergenceKernel(smooth)
-      case "itakuraSaito" => new ItakuraSaitoKernel(smooth)
-      case "generalizedI" => new GeneralizedIDivergenceKernel(smooth)
-      case "logistic" => new LogisticLossKernel(smooth)
-      case "l1" | "manhattan" => new L1Kernel()
-      case "spherical" | "cosine" => new SphericalKernel()
-      case _ => throw new IllegalArgumentException(s"Unknown divergence: $divName")
-    }
-  }
-
-  /** Create assignment strategy.
-    */
-  private def createAssignmentStrategy(strategy: String): AssignmentStrategy = {
-    strategy match {
-      case "broadcast" => new BroadcastUDFAssignment()
-      case "crossJoin" => new SECrossJoinAssignment()
-      case "auto" => new AutoAssignment()
-      case _ => throw new IllegalArgumentException(s"Unknown assignment strategy: $strategy")
-    }
-  }
-
-  /** Create update strategy based on divergence.
-    */
-  private def createUpdateStrategy(divName: String): UpdateStrategy = {
-    divName match {
-      case "l1" | "manhattan" => new MedianUpdateStrategy()
-      case _ => new GradMeanUDAFUpdate()
-    }
-  }
 
   override def transformSchema(schema: StructType): StructType = {
     validateAndTransformSchema(schema)
   }
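
The factories deleted above move into `ClusteringOps` so every estimator shares one divergence-name match. A minimal sketch of that centralization (the strategy classes here are stand-ins for the real Spark-backed ones, and `ClusteringOpsSketch` is a hypothetical name):

```scala
trait UpdateStrategy { def name: String }
final class MedianUpdateStrategy extends UpdateStrategy { val name = "median" }
final class GradMeanUDAFUpdate extends UpdateStrategy { val name = "gradMean" }

object ClusteringOpsSketch {
  // One shared factory instead of a private copy in each estimator.
  // L1/manhattan has no Bregman gradient, so it takes the median path;
  // every other divergence uses the gradient-mean update.
  def createUpdateStrategy(divergence: String): UpdateStrategy =
    divergence.toLowerCase match {
      case "l1" | "manhattan" => new MedianUpdateStrategy()
      case _                  => new GradMeanUDAFUpdate()
    }
}
```

Centralizing the match means a new divergence or strategy is wired up once, rather than edited into ten estimators.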

src/main/scala/com/massivedatascience/clusterer/ml/BregmanMixtureModel.scala
Lines changed: 2 additions & 2 deletions

@@ -233,7 +233,7 @@ class BregmanMixture(override val uid: String)
     }
   }
 
-  private def createKernel(): BregmanKernel = {
+  private def createKernel(): ClusteringKernel = {
     BregmanKernel.create($(divergence), $(smoothing))
   }
 
@@ -270,7 +270,7 @@ class BregmanMixtureModelInstance(
   override val uid: String,
   val means: Array[Vector],
   val weights: Array[Double],
-  val kernel: BregmanKernel
+  val kernel: ClusteringKernel
 ) extends Model[BregmanMixtureModelInstance]
   with BregmanMixtureParams
   with MLWritable

src/main/scala/com/massivedatascience/clusterer/ml/CoClustering.scala
Lines changed: 5 additions & 5 deletions

@@ -296,7 +296,7 @@ class CoClustering(override val uid: String)
     rowClusters: Map[Long, Int],
     colClusters: Map[Long, Int],
     blockCenters: Array[Array[Double]],
-    kernel: BregmanKernel,
+    kernel: ClusteringKernel,
     iterations: Int,
     objective: Double
   ): CoClusteringModel = {
@@ -343,7 +343,7 @@
     rowClusters: Map[Long, Int],
     colClusters: Map[Long, Int],
     blockCenters: Array[Array[Double]],
-    kernel: BregmanKernel
+    kernel: ClusteringKernel
   ): Double = {
 
     val bcBlockCenters = df.sparkSession.sparkContext.broadcast(blockCenters)
@@ -373,7 +373,7 @@
     df: DataFrame,
     colClusters: Map[Long, Int],
     blockCenters: Array[Array[Double]],
-    kernel: BregmanKernel
+    kernel: ClusteringKernel
   ): Map[Long, Int] = {
 
     val spark = df.sparkSession
@@ -413,7 +413,7 @@
     df: DataFrame,
     rowClusters: Map[Long, Int],
     blockCenters: Array[Array[Double]],
-    kernel: BregmanKernel
+    kernel: ClusteringKernel
  ): Map[Long, Int] = {
 
     val spark = df.sparkSession
@@ -471,7 +471,7 @@ class CoClusteringModel(
   val rowClusters: Map[Long, Int],
   val colClusters: Map[Long, Int],
   val blockCenters: Array[Array[Double]],
-  private val kernel: BregmanKernel
+  private val kernel: ClusteringKernel
 ) extends Model[CoClusteringModel]
   with CoClusteringParams
   with MLWritable
