Commit 1c09b8e

derrickburns and claude committed
feat: Major release with new algorithms and performance improvements
New algorithms:
- DPMeans: Bayesian nonparametric clustering with automatic k selection
- CoClustering: ML Estimator/Model pattern for simultaneous row/column clustering
- SphericalKernel: Cosine similarity support for text/embedding clustering

Performance improvements:
- ElkanLloydsIterator: Triangle inequality acceleration for SE (10-50x speedup)
- AcceleratedSEAssignment: Center-distance pruning for single iterations
- AdaptiveBroadcastAssignment: Memory-aware broadcast chunk sizing
- Vectorized BLAS: Native nrm2, squaredNorm, asum, normalize operations

Architecture:
- BregmanFunction: Unified trait as single source of truth for divergences
- BregmanFunctionAdapter: Bridges to both RDD and DataFrame APIs

Bug fixes:
- BLAS.doMax comparison operator (was computing minimum)
- Division by zero guards in Strategies.scala and CoClusteringInitializer
- build.sbt javac version format ("17.0" -> "17")

Documentation:
- Comprehensive Scaladoc for all 5 estimators
- SphericalKMeansExample with executable assertions
- ROADMAP.md tracking planned improvements

Tests: 942 tests passing (added ~200 new tests)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 1a7bc0f commit 1c09b8e
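The commit message credits `ElkanLloydsIterator` with triangle-inequality acceleration for squared Euclidean. The standard Elkan bound behind that speedup: if `d(c_best, c) >= 2 * d(x, c_best)`, then `c` cannot be closer to `x` than the current best, so its distance never needs to be computed. A minimal Python illustration of just that pruning rule (the repo's Scala iterator is the actual implementation; all names below are hypothetical):

```python
import math

def dist(a, b):
    # Euclidean distance between two points
    return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))

def assign_with_pruning(point, centers):
    """Assign `point` to its nearest center, skipping candidates the
    triangle inequality rules out: if d(c_best, c) >= 2*d(x, c_best),
    then d(x, c) >= d(c_best, c) - d(x, c_best) >= d(x, c_best)."""
    best = 0
    best_d = dist(point, centers[0])
    skipped = 0
    for j in range(1, len(centers)):
        # center-to-center distances are precomputed once per iteration
        # in Elkan's algorithm; recomputed here for brevity
        if dist(centers[best], centers[j]) >= 2 * best_d:
            skipped += 1
            continue  # pruned: centers[j] provably no closer than current best
        d = dist(point, centers[j])
        if d < best_d:
            best, best_d = j, d
    return best, best_d, skipped

centers = [(0.0, 0.0), (10.0, 0.0), (0.0, 10.0), (10.0, 10.0)]
# nearest is centers[0]; the other three are pruned with no distance computation
idx, d, skipped = assign_with_pruning((0.5, 0.5), centers)
```

Well-separated centers make the bound fire often, which is where the reported 10-50x speedups come from.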

36 files changed (+6381 −132 lines)

CLAUDE.md

Lines changed: 29 additions & 6 deletions
@@ -7,19 +7,42 @@
 > **Versions:** Scala **2.13** (primary) / 2.12, Spark **4.0.x / 3.5.x / 3.4.x**
 > - **Spark 4.0.x**: Scala 2.13 only (2.12 dropped in Spark 4.0)
 > - **Spark 3.x**: Both Scala 2.13 and 2.12 supported
-> **Math:** Bregman family — divergences include `squaredEuclidean`, `kl`, `itakuraSaito`, `l1`, `generalizedI`, `logistic`.
+> **Math:** Bregman family — divergences include `squaredEuclidean`, `kl`, `itakuraSaito`, `l1`, `generalizedI`, `logistic`, `spherical`/`cosine`.
 > **Variants:** Bisecting, X-Means, Soft/Fuzzy, Streaming, K-Medians, K-Medoids.
 > **Determinism + persistence** are non-negotiable; RDD API is **archived** (reference only).
+> **Roadmap:** See `ROADMAP.md` for planned improvements and technical debt.

 ---

 ## 0) Operating Principles (do these every time)

-1. **Prefer the DataFrame/ML API.** Code and examples use Estimator/Model patterns and Params from this codebase.
-2. **No silent API breaks.** If you touch params, model JSON, or persistence schemas, include migration/round-trip tests.
-3. **Mathematical fidelity first.** Correct Bregman formulations beat micro-perf. Perf changes must not alter semantics.
-4. **Determinism matters.** Same seed ⇒ identical results. Avoid nondeterministic ops in core loops.
+1. **Prefer the DataFrame/ML API.** Code and examples use Estimator/Model patterns and Params from this codebase.
+2. **No silent API breaks.** If you touch params, model JSON, or persistence schemas, include migration/round-trip tests.
+3. **Mathematical fidelity first.** Correct Bregman formulations beat micro-perf. Perf changes must not alter semantics.
+4. **Determinism matters.** Same seed ⇒ identical results. Avoid nondeterministic ops in core loops.
 5. **Tight PRs.** Small, test-backed, CI-friendly. No speculative abstractions.
+6. **Maintain the roadmap.** When making changes, update `ROADMAP.md` to reflect completed work, new issues discovered, or priority changes.
+
+---
+
+## 0.1) Roadmap Maintenance
+
+**IMPORTANT:** The file `ROADMAP.md` contains the project's technical roadmap, including:
+- Bug fixes (completed and pending)
+- Architecture improvements
+- Algorithm additions
+- Performance improvements
+- Documentation needs
+
+**Claude must:**
+1. **Inspect `ROADMAP.md`** at the start of significant work to understand current priorities and context.
+2. **Update `ROADMAP.md`** when:
+   - Completing a bug fix → mark as ✅ FIXED with date
+   - Discovering a new bug → add to Bug Fixes section with priority
+   - Completing a feature → move to Completed Items section
+   - Identifying technical debt → add to appropriate section
+   - Making architectural decisions → add to Decision Log
+3. **Reference roadmap items** in commit messages and PR descriptions where applicable.

 ---

@@ -31,7 +54,7 @@
 - **Spark 3.x**: Both Scala 2.13 and 2.12 supported
 - **Scala:** 2.13.x primary (keep code Scala-3-friendly where feasible).
 - **Java:** 17.
-- **Divergences:** `squaredEuclidean | kl | itakuraSaito | l1 | generalizedI | logistic`.
+- **Divergences:** `squaredEuclidean | kl | itakuraSaito | l1 | generalizedI | logistic | spherical | cosine`.
 - **Assignment strategies:** `auto | crossJoin (SE fast path) | broadcastUDF (general Bregman)`.
 - **Input transforms:** `none | log1p | epsilonShift(shiftValue)`; ensure domain validity for KL/IS.
 - **Persistence:** Models round-trip across Spark 3.4↔3.5↔4.0, Scala 2.12↔2.13.
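The commit describes `BregmanFunction` as the single source of truth for divergences. For reference, the scalar forms of three of the listed divergences all follow the Bregman recipe `D_F(x, y) = F(x) - F(y) - F'(y)(x - y)` for a convex generator `F`. A minimal Python illustration of those forms (not the library's Scala code; function names here are hypothetical, and the library's `kl` may differ in normalization):

```python
import math

# Scalar Bregman divergences, summed over vector components.
# D_F(x, y) = F(x) - F(y) - F'(y) * (x - y) for a convex generator F.

def se(x, y):
    # F(t) = t^2          ->  D = (x - y)^2        (squared Euclidean)
    return (x - y) ** 2

def kl(x, y):
    # F(t) = t * log(t)   ->  D = x*log(x/y) - x + y   (generalized KL; x, y > 0)
    return x * math.log(x / y) - x + y

def itakura_saito(x, y):
    # F(t) = -log(t)      ->  D = x/y - log(x/y) - 1   (x, y > 0)
    return x / y - math.log(x / y) - 1

def divergence(f, xs, ys):
    # sum the per-component divergence over a vector pair
    return sum(f(x, y) for x, y in zip(xs, ys))

# A Bregman divergence is zero iff x == y and positive otherwise
print(divergence(kl, [1.0, 2.0], [1.0, 2.0]))  # 0.0
print(divergence(se, [1.0, 0.0], [0.0, 1.0]))  # 2.0
```

The domain guards in the comments match the input-transform note above: KL and Itakura-Saito blow up at zero, hence `log1p`/`epsilonShift`.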

README.md

Lines changed: 53 additions & 5 deletions
@@ -21,7 +21,7 @@ This project generalizes K-Means to multiple Bregman divergences and advanced va

 ## What's in here

-- Multiple divergences: Squared Euclidean, KL, Itakura–Saito, L1/Manhattan (K-Medians), Generalized-I, Logistic-loss
+- Multiple divergences: Squared Euclidean, KL, Itakura–Saito, L1/Manhattan (K-Medians), Generalized-I, Logistic-loss, Spherical/Cosine
 - Variants: Bisecting, X-Means (BIC/AIC), Soft K-Means, Structured-Streaming K-Means, K-Medoids (PAM/CLARA)
 - Scale: Tested on tens of millions of points in 700+ dimensions
 - Tooling: Scala 2.13 (primary) / 2.12, Spark 4.0.x / 3.5.x / 3.4.x
@@ -47,7 +47,7 @@ val df = spark.createDataFrame(Seq(

 val gkm = new GeneralizedKMeans()
   .setK(2)
-  .setDivergence("kl") // "squaredEuclidean", "itakuraSaito", "l1", "generalizedI", "logistic"
+  .setDivergence("kl") // "squaredEuclidean", "itakuraSaito", "l1", "generalizedI", "logistic", "spherical"
   .setAssignmentStrategy("auto") // "auto" | "crossJoin" (SE fast path) | "broadcastUDF" (general Bregman)
   .setMaxIter(20)
@@ -69,7 +69,7 @@ Our comprehensive CI pipeline ensures quality across multiple dimensions:
 | **Lint & Style** | Scalastyle compliance, code formatting | Part of main CI |
 | **Build Matrix** | Scala 2.12.18 & 2.13.14 × Spark 3.4.3 / 3.5.1 / 4.0.1 | [![CI](https://github.com/derrickburns/generalized-kmeans-clustering/actions/workflows/ci.yml/badge.svg)](https://github.com/derrickburns/generalized-kmeans-clustering/actions/workflows/ci.yml) |
 | **Test Matrix** | 730 tests across all Scala/Spark combinations<br/>• 62 kernel accuracy tests (divergence formulas, gradients, inverse gradients)<br/>• 19 Lloyd's iterator tests (core k-means loop)<br/>• Determinism, edge cases, numerical stability | Part of main CI |
-| **Executable Documentation** | All examples run with assertions that verify correctness ([ExamplesSuite](src/test/scala/examples/ExamplesSuite.scala)):<br/>• [BisectingExample](src/main/scala/examples/BisectingExample.scala) - validates cluster count<br/>• [SoftKMeansExample](src/main/scala/examples/SoftKMeansExample.scala) - validates probability columns<br/>• [XMeansExample](src/main/scala/examples/XMeansExample.scala) - validates automatic k selection<br/>• [PersistenceRoundTrip](src/main/scala/examples/PersistenceRoundTrip.scala) - validates save/load with center accuracy<br/>• [PersistenceRoundTripKMedoids](src/main/scala/examples/PersistenceRoundTripKMedoids.scala) - validates medoid preservation | Part of main CI |
+| **Executable Documentation** | All examples run with assertions that verify correctness ([ExamplesSuite](src/test/scala/examples/ExamplesSuite.scala)):<br/>• [BisectingExample](src/main/scala/examples/BisectingExample.scala) - validates cluster count<br/>• [SoftKMeansExample](src/main/scala/examples/SoftKMeansExample.scala) - validates probability columns<br/>• [XMeansExample](src/main/scala/examples/XMeansExample.scala) - validates automatic k selection<br/>• [SphericalKMeansExample](src/main/scala/examples/SphericalKMeansExample.scala) - validates cosine similarity clustering<br/>• [PersistenceRoundTrip](src/main/scala/examples/PersistenceRoundTrip.scala) - validates save/load with center accuracy<br/>• [PersistenceRoundTripKMedoids](src/main/scala/examples/PersistenceRoundTripKMedoids.scala) - validates medoid preservation | Part of main CI |
 | **Cross-version Persistence** | Models save/load across Scala 2.12↔2.13 and Spark 3.4↔3.5↔4.0 | Part of main CI |
 | **Performance Sanity** | Basic performance regression check (30s budget) | Part of main CI |
 | **Python Smoke Test** | PySpark wrapper with both SE and non-SE divergences | Part of main CI |
@@ -92,16 +92,17 @@ Truth-linked to code, tests, and examples for full transparency:
 | **Streaming K-Means** || [Code](src/main/scala/com/massivedatascience/clusterer/ml/StreamingKMeans.scala) | [Tests](src/test/scala/com/massivedatascience/clusterer/StreamingKMeansSuite.scala) | [Persistence](src/main/scala/examples/PersistenceRoundTripStreamingKMeans.scala) | Real-time with exponential forgetting |
 | **K-Medoids** || [Code](src/main/scala/com/massivedatascience/clusterer/ml/KMedoids.scala) | [Tests](src/test/scala/com/massivedatascience/clusterer/KMedoidsSuite.scala) | [Persistence](src/main/scala/examples/PersistenceRoundTripKMedoids.scala) | Outlier-robust, custom distances |
 | **K-Medians** || [Code](src/main/scala/com/massivedatascience/clusterer/ml/df/L1Kernel.scala) | [Tests](src/test/scala/com/massivedatascience/clusterer/ml/GeneralizedKMeansSuite.scala) | [Example](src/main/scala/examples/BisectingExample.scala) | L1/Manhattan robustness |
+| **Spherical K-Means** || [Code](src/main/scala/com/massivedatascience/clusterer/ml/df/BregmanKernel.scala) | [Tests](src/test/scala/com/massivedatascience/clusterer/ml/df/BregmanKernelAccuracySuite.scala) | [Example](src/main/scala/examples/SphericalKMeansExample.scala) | Text/embedding clustering (cosine) |
 | **Coreset K-Means** || [Code](src/main/scala/com/massivedatascience/clusterer/ml/CoresetKMeans.scala) | [Tests](src/test/scala/com/massivedatascience/clusterer/ml/CoresetKMeansSuite.scala) | [Persistence](src/main/scala/examples/PersistenceRoundTripCoresetKMeans.scala) | Large-scale approximation (10-100x speedup) |
 | Constrained K-Means | ⚠️ RDD only | [Code](src/main/scala/com/massivedatascience/clusterer) | Legacy || Balance/capacity constraints |
 | Mini-Batch K-Means | ⚠️ RDD only | [Code](src/main/scala/com/massivedatascience/clusterer) | Legacy || Massive datasets via sampling |

-**Divergences Available**: Squared Euclidean, KL, Itakura-Saito, L1/Manhattan, Generalized-I, Logistic Loss
+**Divergences Available**: Squared Euclidean, KL, Itakura-Saito, L1/Manhattan, Generalized-I, Logistic Loss, Spherical/Cosine

 All DataFrame API algorithms include:
 - ✅ Model persistence (save/load across Spark 3.4↔3.5↔4.0, Scala 2.12↔2.13)
 - ✅ Comprehensive test coverage (740 tests, 100% passing)
-- ✅ Executable documentation with assertions (8 examples validate correctness in CI)
+- ✅ Executable documentation with assertions (9 examples validate correctness in CI)
 - ✅ Deterministic behavior (same seed → identical results)
 - ✅ CI validation on every commit
@@ -204,6 +205,7 @@ Note: Cluster centers are learned in the transformed space. If you need original
 |------------|-------------------|-------------|
 | **squaredEuclidean** | Any finite values (x ∈ ℝ) | None needed |
 | **l1** / **manhattan** | Any finite values (x ∈ ℝ) | None needed |
+| **spherical** / **cosine** | Non-zero vectors (‖x‖ > 0) | None needed (auto-normalized) |
 | **kl** | Strictly positive (x > 0) | Use `log1p` or `epsilonShift` transform |
 | **itakuraSaito** | Strictly positive (x > 0) | Use `log1p` or `epsilonShift` transform |
 | **generalizedI** | Non-negative (x ≥ 0) | Take absolute values or shift data |
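The domain table pairs KL and Itakura-Saito with the `log1p` / `epsilonShift` input transforms. A hedged Python sketch of what those transforms do to out-of-domain data (the real implementations are the library's input-transform options; these standalone functions are illustrative only):

```python
import math

def log1p_transform(v):
    # log(1 + x): compresses dynamic range; maps x >= 0 into [0, inf)
    return [math.log1p(x) for x in v]

def epsilon_shift(v, shift_value=1e-6):
    # x + shift: moves non-negative data strictly into the x > 0 domain
    # required by KL and Itakura-Saito
    return [x + shift_value for x in v]

raw = [0.0, 3.0, 0.0, 12.0]         # e.g. word counts; zeros violate KL's x > 0
shifted = epsilon_shift(raw)
assert all(x > 0 for x in shifted)  # now safe for kl / itakuraSaito

# Note: log1p alone keeps zeros at zero (log1p(0) == 0), so strictly-positive
# domains still need a shift when the data contains exact zeros.
assert math.log1p(0.0) == 0.0
```

The shift value trades domain safety against distortion: the smaller the shift, the less it perturbs the divergence, but the closer transformed zeros sit to the singularity at 0.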
@@ -296,6 +298,50 @@ Example:

 ---

+## Spherical K-Means (Cosine Similarity)
+
+Spherical K-Means clusters data on the unit hypersphere using cosine similarity. This is ideal for:
+- **Text/document clustering** (TF-IDF vectors, word embeddings)
+- **Image feature clustering** (CNN embeddings)
+- **Recommendation systems** (user/item embeddings)
+- **Any high-dimensional sparse data** where direction matters more than magnitude
+
+**How it works:**
+1. All vectors are automatically L2-normalized to unit length
+2. Distance: `D(x, μ) = 1 - cos(x, μ) = 1 - (x · μ)` for unit vectors
+3. Centers are computed as normalized mean of assigned points
+
+**Example:**
+
+```scala
+import com.massivedatascience.clusterer.ml.GeneralizedKMeans
+
+// Example: Clustering text embeddings
+val embeddings = spark.createDataFrame(Seq(
+  Tuple1(Vectors.dense(0.8, 0.6, 0.0)),  // Document about topic A
+  Tuple1(Vectors.dense(0.9, 0.5, 0.1)),  // Also topic A (similar direction)
+  Tuple1(Vectors.dense(0.1, 0.2, 0.95)), // Document about topic B
+  Tuple1(Vectors.dense(0.0, 0.3, 0.9))   // Also topic B
+)).toDF("features")
+
+val sphericalKMeans = new GeneralizedKMeans()
+  .setK(2)
+  .setDivergence("spherical") // or "cosine"
+  .setMaxIter(20)
+
+val model = sphericalKMeans.fit(embeddings)
+val predictions = model.transform(embeddings)
+predictions.show()
+```
+
+**Key properties:**
+- Distance range: `[0, 2]` (0 = identical direction, 2 = opposite direction)
+- Equivalent to squared Euclidean on normalized data: `‖x - μ‖² = 2(1 - x·μ)`
+- No domain restrictions except non-zero vectors
+- Available in all estimators: `GeneralizedKMeans`, `BisectingKMeans`, `SoftKMeans`, `StreamingKMeans`
+
+---
+
 ## Bisecting K-Means — efficiency note

 The driver maintains a cluster_id column. For each split:
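The new Spherical K-Means section states the identity `‖x - μ‖² = 2(1 - x·μ)` for unit vectors, and that centers are the normalized mean of assigned points. Both can be checked numerically; a small standalone Python sketch (separate from the repo's Scala kernels, with hypothetical helper names):

```python
import math

def normalize(v):
    # L2-normalize to unit length; requires a non-zero vector
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]

def cosine_distance(x, mu):
    # For unit vectors: D(x, mu) = 1 - cos(x, mu) = 1 - x . mu, in [0, 2]
    return 1.0 - sum(a * b for a, b in zip(x, mu))

def spherical_center(points):
    # Center update: L2-normalized mean of the assigned unit vectors
    mean = [sum(c) / len(points) for c in zip(*points)]
    return normalize(mean)

x = normalize([0.8, 0.6, 0.0])
mu = normalize([0.9, 0.5, 0.1])

# Identity on the unit sphere: ||x - mu||^2 == 2 * (1 - x . mu)
sq = sum((a - b) ** 2 for a, b in zip(x, mu))
assert abs(sq - 2.0 * cosine_distance(x, mu)) < 1e-12

# Opposite directions give the maximum distance of 2
assert abs(cosine_distance([1.0, 0.0], [-1.0, 0.0]) - 2.0) < 1e-12
```

The identity is why the squared-Euclidean fast path remains usable after normalization: minimizing `1 - x·μ` over unit vectors is equivalent to minimizing `‖x - μ‖²`.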
@@ -404,6 +450,8 @@ For brevity in this chat, I’m not duplicating it again, but in your repo, plac
 - Installation / Versions
 - Scaling & Assignment Strategy
 - Input Transforms & Interpretation
+- Domain Requirements & Validation
+- Spherical K-Means (Cosine Similarity)
 - Bisecting K-Means — efficiency note
 - Structured Streaming K-Means
 - Persistence (Spark ML)

0 commit comments
