Skip to content

Commit a1ae3cf

Browse files
derrickburnsclaude
andcommitted
refactor: Add KernelFactory and consolidate assignment strategies
- Add KernelFactory for unified dense/sparse kernel creation - Single API for all 8 Bregman divergences - Auto-selection based on data sparsity - Clear documentation of supported divergences - Move AcceleratedSEAssignment to strategies/impl/ subpackage - Better organization alongside other assignment strategies - Maintain backward compatibility via type aliases - Update models to use KernelFactory - GeneralizedKMeansModel uses KernelFactory for kernel creation - SoftKMeansModel persistence uses KernelFactory - Update package objects with re-exports for backward compatibility 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 15d5d57 commit a1ae3cf

File tree

8 files changed

+256
-23
lines changed

8 files changed

+256
-23
lines changed

src/main/scala/com/massivedatascience/clusterer/ml/GeneralizedKMeansModel.scala

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -219,16 +219,19 @@ class GeneralizedKMeansModel(
219219
/** Create Bregman kernel based on kernel name.
220220
*/
221221
private def createKernel(kernelName: String, smoothing: Double): BregmanKernel = {
222-
kernelName match {
223-
case "SquaredEuclidean" => new SquaredEuclideanKernel()
224-
case name if name.startsWith("KL(") => new KLDivergenceKernel(smoothing)
225-
case name if name.startsWith("ItakuraSaito(") => new ItakuraSaitoKernel(smoothing)
226-
case name if name.startsWith("GeneralizedI(") => new GeneralizedIDivergenceKernel(smoothing)
227-
case name if name.startsWith("LogisticLoss(") => new LogisticLossKernel(smoothing)
228-
case "L1" => new L1Kernel()
229-
case "Spherical" => new SphericalKernel()
230-
case _ => throw new IllegalArgumentException(s"Unknown kernel: $kernelName")
222+
import com.massivedatascience.clusterer.ml.df.kernels.KernelFactory
223+
// Map stored kernel names to divergence names for KernelFactory
224+
val divergence = kernelName match {
225+
case "SquaredEuclidean" => "squaredEuclidean"
226+
case name if name.startsWith("KL(") => "kl"
227+
case name if name.startsWith("ItakuraSaito(") => "itakuraSaito"
228+
case name if name.startsWith("GeneralizedI(") => "generalizedI"
229+
case name if name.startsWith("LogisticLoss(") => "logistic"
230+
case "L1" => "l1"
231+
case "Spherical" => "spherical"
232+
case other => other.toLowerCase
231233
}
234+
KernelFactory.create(divergence, smoothing = smoothing)
232235
}
233236

234237
override def write: MLWriter = new GeneralizedKMeansModel.GeneralizedKMeansModelWriter(this)

src/main/scala/com/massivedatascience/clusterer/ml/SoftKMeansModel.scala

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -267,17 +267,8 @@ object SoftKMeansModel extends MLReadable[SoftKMeansModel] {
267267
val minMembership = (paramsJ \ "minMembership").extract[Double]
268268
val smoothing = (paramsJ \ "smoothing").extract[Double]
269269

270-
import com.massivedatascience.clusterer.ml.df._
271-
val kernel: BregmanKernel = divergence match {
272-
case "squaredEuclidean" => new SquaredEuclideanKernel()
273-
case "kl" => new KLDivergenceKernel(smoothing)
274-
case "itakuraSaito" => new ItakuraSaitoKernel(smoothing)
275-
case "generalizedI" => new GeneralizedIDivergenceKernel(smoothing)
276-
case "logistic" => new LogisticLossKernel(smoothing)
277-
case "l1" | "manhattan" => new L1Kernel()
278-
case "spherical" | "cosine" => new SphericalKernel()
279-
case _ => new SquaredEuclideanKernel()
280-
}
270+
import com.massivedatascience.clusterer.ml.df.kernels.KernelFactory
271+
val kernel = KernelFactory.create(divergence, smoothing = smoothing)
281272

282273
val model = new SoftKMeansModel(uid, centers, beta, minMembership, kernel)
283274
model.modelDivergence = divergence
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
/*
2+
* Licensed to the Massive Data Science and Derrick R. Burns under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* Massive Data Science and Derrick R. Burns licenses this file to You under the
6+
* Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package com.massivedatascience.clusterer.ml.df.kernels
19+
20+
/** Unified factory for creating Bregman kernels.
 *
 * This factory provides a single entry point for kernel creation with support for:
 *   - Dense kernels (standard implementation)
 *   - Sparse-optimized kernels (for high-dimensional sparse data)
 *   - Auto-selection based on data characteristics
 *
 * Divergence names and aliases are matched case-insensitively (locale-independent) after
 * trimming surrounding whitespace.
 *
 * ==Supported Divergences==
 *
 * | Name             | Aliases         | Sparse Support | Domain  | Use Case                  |
 * |:-----------------|:----------------|:---------------|:--------|:--------------------------|
 * | squaredEuclidean | se, euclidean   | Yes            | R^n     | General clustering        |
 * | kl               | kullbackLeibler | Yes            | R+^n    | Probability distributions |
 * | itakuraSaito     | is              | No             | R+^n    | Audio/spectrum analysis   |
 * | generalizedI     | genI            | No             | R+^n    | Count data                |
 * | logistic         | -               | No             | [0,1]^n | Bounded probabilities     |
 * | l1               | manhattan       | Yes            | R^n     | Robust clustering         |
 * | spherical        | cosine          | Yes            | R^n     | Text/documents            |
 *
 * ==Example Usage==
 *
 * {{{
 * // Standard dense kernel
 * val seKernel = KernelFactory.create("squaredEuclidean")
 *
 * // Sparse-optimized kernel for text data
 * val klKernel = KernelFactory.create("kl", sparse = true)
 *
 * // Auto-select based on sparsity
 * val autoKernel = KernelFactory.forSparsity("squaredEuclidean", sparsityRatio = 0.1)
 * }}}
 *
 * @see
 *   [[BregmanKernel]] for the kernel interface
 * @see
 *   [[SparseBregmanKernel]] for sparse-optimized implementations
 */
object KernelFactory {

  import java.util.Locale

  /** Canonical divergence names. */
  object Divergence {
    val SquaredEuclidean: String = "squaredEuclidean"
    val KL: String               = "kl"
    val ItakuraSaito: String     = "itakuraSaito"
    val GeneralizedI: String     = "generalizedI"
    val Logistic: String         = "logistic"
    val L1: String               = "l1"
    val Spherical: String        = "spherical"

    /** All supported divergence names (canonical form). */
    val all: Seq[String] = Seq(
      SquaredEuclidean,
      KL,
      ItakuraSaito,
      GeneralizedI,
      Logistic,
      L1,
      Spherical
    )
  }

  /** Divergence names and aliases with sparse-optimized implementations.
   *
   * The contents are kept exactly as originally published so that callers reading this set
   * directly are unaffected. Note that it mixes a camelCase canonical name with lower-case
   * aliases; use [[supportsSparse]] for a case-insensitive check.
   */
  val sparseSupported: Set[String] = Set(
    "squaredEuclidean",
    "se",
    "euclidean",
    "kl",
    "kullbackleibler",
    "l1",
    "manhattan",
    "spherical",
    "cosine"
  )

  /** Canonical names of the divergences that `createSparse` has a dedicated kernel for. */
  private val sparseCanonical: Set[String] = Set(
    Divergence.SquaredEuclidean,
    Divergence.KL,
    Divergence.L1,
    Divergence.Spherical
  )

  /** Lookup key for a user-supplied name: trimmed and lower-cased.
   *
   * `Locale.ROOT` is used so that names containing 'I' (e.g. "generalizedI", "IS") lower-case
   * to an ASCII 'i' regardless of the JVM default locale.
   */
  private def lookupKey(name: String): String = name.trim.toLowerCase(Locale.ROOT)

  /** Single source of truth mapping every accepted lookup key (alias or canonical name) to its
   * canonical divergence name.
   */
  private val canonicalByKey: Map[String, String] = {
    val aliases = Map(
      "se"              -> Divergence.SquaredEuclidean,
      "euclidean"       -> Divergence.SquaredEuclidean,
      "kullbackleibler" -> Divergence.KL,
      "is"              -> Divergence.ItakuraSaito,
      "geni"            -> Divergence.GeneralizedI,
      "manhattan"       -> Divergence.L1,
      "cosine"          -> Divergence.Spherical
    )
    aliases ++ Divergence.all.map(canonical => lookupKey(canonical) -> canonical)
  }

  /** Create a Bregman kernel for the specified divergence.
   *
   * @param divergence
   *   divergence name or alias (case-insensitive, surrounding whitespace ignored)
   * @param sparse
   *   if true, use sparse-optimized implementation when available; otherwise (or when no sparse
   *   implementation exists for the divergence) the dense implementation is returned
   * @param smoothing
   *   smoothing parameter passed to the KL, Itakura-Saito, generalized-I and logistic kernels;
   *   ignored by the other kernels
   * @return
   *   configured BregmanKernel instance
   * @throws IllegalArgumentException
   *   if divergence name is unknown
   */
  def create(
      divergence: String,
      sparse: Boolean = false,
      smoothing: Double = 1e-10
  ): BregmanKernel = {
    val canonical = normalize(divergence)
    if (sparse && sparseCanonical.contains(canonical)) {
      createSparse(canonical, smoothing)
    } else {
      createDense(canonical, smoothing)
    }
  }

  /** Create a kernel with auto-selection based on data sparsity.
   *
   * Selects sparse implementation when sparsity ratio is below threshold and sparse
   * implementation is available.
   *
   * @param divergence
   *   divergence name or alias (case-insensitive)
   * @param sparsityRatio
   *   fraction of non-zero elements (0.0 = all zeros, 1.0 = dense)
   * @param smoothing
   *   smoothing parameter
   * @param sparseThreshold
   *   use sparse when sparsityRatio < this value (default 0.3)
   * @return
   *   kernel optimized for the data sparsity
   * @throws IllegalArgumentException
   *   if divergence name is unknown
   */
  def forSparsity(
      divergence: String,
      sparsityRatio: Double,
      smoothing: Double = 1e-10,
      sparseThreshold: Double = 0.3
  ): BregmanKernel = {
    val useSparse = sparsityRatio < sparseThreshold && supportsSparse(divergence)
    create(divergence, sparse = useSparse, smoothing = smoothing)
  }

  /** Check if sparse optimization is available for the divergence.
   *
   * The name is resolved through [[normalize]] first, so canonical names and aliases are
   * accepted in any letter case (e.g. "squaredEuclidean", "SE" and "Cosine" all return true).
   *
   * @param divergence
   *   divergence name or alias (case-insensitive)
   * @return
   *   true if sparse-optimized implementation exists
   */
  def supportsSparse(divergence: String): Boolean =
    sparseCanonical.contains(normalize(divergence))

  /** Normalize divergence name to canonical form.
   *
   * @param divergence
   *   any valid divergence name or alias (case-insensitive, surrounding whitespace ignored)
   * @return
   *   the matching entry of [[Divergence.all]]; for an unrecognized name, the trimmed,
   *   lower-cased input is returned unchanged so that callers can report it
   */
  def normalize(divergence: String): String = {
    val key = lookupKey(divergence)
    canonicalByKey.getOrElse(key, key)
  }

  /** Create a dense (standard) kernel implementation.
   *
   * Expects a name that has already been passed through [[normalize]].
   */
  private def createDense(divergence: String, smoothing: Double): BregmanKernel =
    divergence match {
      case Divergence.SquaredEuclidean => new SquaredEuclideanKernel()
      case Divergence.KL               => new KLDivergenceKernel(smoothing)
      case Divergence.ItakuraSaito     => new ItakuraSaitoKernel(smoothing)
      case Divergence.GeneralizedI     => new GeneralizedIDivergenceKernel(smoothing)
      case Divergence.Logistic         => new LogisticLossKernel(smoothing)
      case Divergence.L1               => new L1Kernel()
      case Divergence.Spherical        => new SphericalKernel()
      case other =>
        throw new IllegalArgumentException(
          s"Unknown divergence: '$other'. Supported: ${Divergence.all.mkString(", ")}"
        )
    }

  /** Create a sparse-optimized kernel implementation.
   *
   * Expects a name that has already been passed through [[normalize]].
   */
  private def createSparse(divergence: String, smoothing: Double): BregmanKernel =
    divergence match {
      case Divergence.SquaredEuclidean => new SparseSEKernel()
      case Divergence.KL               => new SparseKLKernel(smoothing)
      case Divergence.L1               => new SparseL1Kernel()
      case Divergence.Spherical        => new SparseSphericalKernel()
      // Fall back to dense for others (no sparse optimization available)
      case other => createDense(other, smoothing)
    }
}

src/main/scala/com/massivedatascience/clusterer/ml/df/kernels/package.scala

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,38 @@ package com.massivedatascience.clusterer.ml.df
44
*
55
* This package contains kernel implementations for different Bregman divergences:
66
*
7+
* ==Factory==
8+
*
9+
* - [[kernels.KernelFactory]]: Unified factory for dense/sparse kernel selection
10+
*
11+
* ==Dense Kernels==
12+
*
713
* - [[kernels.SquaredEuclideanKernel]]: Standard k-means (L2 squared)
814
* - [[kernels.KLDivergenceKernel]]: Kullback-Leibler divergence
915
* - [[kernels.ItakuraSaitoKernel]]: Itakura-Saito divergence
1016
* - [[kernels.GeneralizedIDivergenceKernel]]: Generalized I-divergence
1117
* - [[kernels.LogisticLossKernel]]: Logistic loss
1218
* - [[kernels.L1Kernel]]: Manhattan distance (K-Medians)
1319
* - [[kernels.SphericalKernel]]: Cosine similarity (Spherical K-Means)
20+
*
21+
* ==Sparse-Optimized Kernels==
22+
*
23+
* - [[kernels.SparseSEKernel]]: Sparse Squared Euclidean
24+
* - [[kernels.SparseKLKernel]]: Sparse KL Divergence
25+
* - [[kernels.SparseL1Kernel]]: Sparse L1/Manhattan
26+
* - [[kernels.SparseSphericalKernel]]: Sparse Cosine/Spherical
27+
*
28+
* ==Usage==
29+
*
30+
* {{{
31+
* // Create kernel via factory
32+
* val kernel = KernelFactory.create("squaredEuclidean", sparse = false)
33+
*
34+
* // Auto-select based on data sparsity
35+
* val sparseKernel = KernelFactory.forSparsity("kl", sparsityRatio = 0.1)
36+
* }}}
1437
*/
1538
package object kernels {
1639
// All types are defined in their respective files
17-
// This package object serves as documentation
40+
// KernelFactory provides the main API for kernel creation
1841
}

src/main/scala/com/massivedatascience/clusterer/ml/df/package.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ package object df {
5353
type ChunkedBroadcastAssignment = strategies.ChunkedBroadcastAssignment
5454
type AdaptiveBroadcastAssignment = strategies.AdaptiveBroadcastAssignment
5555
type AutoAssignment = strategies.AutoAssignment
56+
type AcceleratedSEAssignment = strategies.impl.AcceleratedSEAssignment
57+
val AcceleratedAssignment = strategies.impl.AcceleratedAssignment
5658

5759
// Update strategies
5860
type UpdateStrategy = strategies.UpdateStrategy
@@ -81,4 +83,14 @@ package object df {
8183
type LogisticLossKernel = kernels.LogisticLossKernel
8284
type L1Kernel = kernels.L1Kernel
8385
type SphericalKernel = kernels.SphericalKernel
86+
87+
// Sparse kernel types
88+
type SparseBregmanKernel = kernels.SparseBregmanKernel
89+
type SparseSEKernel = kernels.SparseSEKernel
90+
type SparseKLKernel = kernels.SparseKLKernel
91+
type SparseL1Kernel = kernels.SparseL1Kernel
92+
type SparseSphericalKernel = kernels.SparseSphericalKernel
93+
94+
// Kernel factory
95+
val KernelFactory = kernels.KernelFactory
8496
}

src/main/scala/com/massivedatascience/clusterer/ml/df/AcceleratedSEAssignment.scala renamed to src/main/scala/com/massivedatascience/clusterer/ml/df/strategies/impl/AcceleratedSEAssignment.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515
* limitations under the License.
1616
*/
1717

18-
package com.massivedatascience.clusterer.ml.df
18+
package com.massivedatascience.clusterer.ml.df.strategies.impl
1919

20+
import com.massivedatascience.clusterer.ml.df.BregmanKernel
21+
import com.massivedatascience.clusterer.ml.df.strategies.AssignmentStrategy
2022
import org.apache.spark.internal.Logging
2123
import org.apache.spark.ml.linalg.Vector
2224
import org.apache.spark.sql.DataFrame
@@ -28,7 +30,7 @@ import org.apache.spark.sql.functions._
2830
*
2931
* '''Key Insight (Elkan's Lemma 1):''' If d(x, c) ≤ d(c, c')/2, then d(x, c) ≤ d(x, c')
3032
*
31-
* This means: once we find a center c with distance d, we can skip any center c' where d(c, c')
33+
* This means: once we find a center c with distance d, we can skip any center c' where d(c, c') >=
3234
* 2*d (because the triangle inequality guarantees c' is farther).
3335
*
3436
* ==Algorithm==

src/main/scala/com/massivedatascience/clusterer/ml/df/strategies/impl/package.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ package com.massivedatascience.clusterer.ml.df.strategies
99
* - [[ChunkedBroadcastAssignment]]: Memory-efficient chunked processing
1010
* - [[AdaptiveBroadcastAssignment]]: Memory-adaptive strategy
1111
* - [[AutoAssignment]]: Automatic strategy selection
12+
* - [[AcceleratedSEAssignment]]: Triangle-inequality accelerated SE assignment
1213
*/
1314
package object impl {
1415
// All implementations are defined in their respective files

src/main/scala/com/massivedatascience/clusterer/ml/df/strategies/package.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,6 @@ package object strategies {
1919
type ChunkedBroadcastAssignment = impl.ChunkedBroadcastAssignment
2020
type AdaptiveBroadcastAssignment = impl.AdaptiveBroadcastAssignment
2121
type AutoAssignment = impl.AutoAssignment
22+
type AcceleratedSEAssignment = impl.AcceleratedSEAssignment
23+
val AcceleratedAssignment = impl.AcceleratedAssignment
2224
}

0 commit comments

Comments
 (0)