Skip to content

Commit 04a9ffc

Browse files
derrickburns and claude committed
feat: add persistence roundtrip examples for all model types
Added executable persistence roundtrip examples with comprehensive assertions: - PersistenceRoundTrip.scala (GeneralizedKMeansModel) - enhanced with assertions - PersistenceRoundTripKMedoids.scala - tests medoid preservation and indices - PersistenceRoundTripSoftKMeans.scala - tests beta, minMembership, probability col - PersistenceRoundTripStreamingKMeans.scala - tests weight preservation and streaming updates All examples follow save/load pattern and verify: - Model parameters roundtrip correctly - Centers preserve correct values - Model-specific state (medoids, weights, etc.) is restored - Predictions work after loading - Special behavior (streaming updates) continues after load Usage for cross-version testing: sbt -Dspark.version=3.4.3 "runMain examples.PersistenceRoundTrip save ./model" sbt -Dspark.version=3.5.1 "runMain examples.PersistenceRoundTrip load ./model" 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 7ba783f commit 04a9ffc

File tree

4 files changed

+262
-1
lines changed

4 files changed

+262
-1
lines changed

src/main/scala/examples/PersistenceRoundTrip.scala

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,23 @@ object PersistenceRoundTrip {
3737

3838
case "load" =>
3939
val loaded = com.massivedatascience.clusterer.ml.GeneralizedKMeansModel.load(path)
40+
41+
// Assertions to verify roundtrip correctness
42+
assert(loaded.numClusters == 2, s"Expected k=2, got ${loaded.numClusters}")
43+
assert(loaded.clusterCenters.length == 2, s"Expected 2 centers, got ${loaded.clusterCenters.length}")
44+
assert(loaded.numFeatures == 2, s"Expected dim=2, got ${loaded.numFeatures}")
45+
46+
// Verify predictions work
4047
val preds = loaded.transform(df)
4148
val n = preds.count()
4249
assert(n == 4, s"expected 4 rows after load, got $n")
43-
println(s"Loaded model from $path; predictions=$n")
50+
51+
// Verify center values are reasonable (should be near (0.5, 0.5) and (9.5, 9.5))
52+
val centers = loaded.clusterCenters.sortBy(_.apply(0))
53+
assert(math.abs(centers(0)(0) - 0.5) < 1.0, s"Center 0 x-coord should be near 0.5, got ${centers(0)(0)}")
54+
assert(math.abs(centers(1)(0) - 9.5) < 1.0, s"Center 1 x-coord should be near 9.5, got ${centers(1)(0)}")
55+
56+
println(s"✅ Loaded model from $path; predictions=$n; all assertions passed")
4457
case other =>
4558
sys.error(s"Unknown mode: $other")
4659
}
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
package examples

import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.linalg.Vectors
import com.massivedatascience.clusterer.ml.KMedoids

/**
 * Cross-Spark-version persistence round-trip check for a fitted KMedoids model.
 *
 * Run once in "save" mode under one Spark version, then in "load" mode under
 * another; the load path asserts that k, medoids, medoid indices, feature
 * dimension, and predictions all survived serialization.
 *
 * Usage:
 *   sbt -Dspark.version=3.4.3 "runMain examples.PersistenceRoundTripKMedoids save ./tmp_kmedoids_34"
 *   sbt -Dspark.version=3.5.1 "runMain examples.PersistenceRoundTripKMedoids load ./tmp_kmedoids_34"
 */
object PersistenceRoundTripKMedoids {

  def main(args: Array[String]): Unit = {
    require(args.length == 2, "args: save|load <path>")
    val mode = args(0)
    val path = args(1)

    val spark = SparkSession.builder().appName("PersistenceRoundTripKMedoids").master("local[*]").getOrCreate()
    import spark.implicits._

    // try/finally guarantees the SparkSession is stopped even when an
    // assertion fails or an unknown mode triggers sys.error; otherwise the
    // local JVM can hang on Spark's non-daemon threads.
    try {
      // Two well-separated clusters: three points near (0,0), three near (9.5,9.5).
      val df = Seq(
        Tuple1(Vectors.dense(0.0, 0.0)),
        Tuple1(Vectors.dense(0.1, 0.1)),
        Tuple1(Vectors.dense(1.0, 1.0)),
        Tuple1(Vectors.dense(9.0, 9.0)),
        Tuple1(Vectors.dense(9.1, 9.1)),
        Tuple1(Vectors.dense(10.0, 10.0))
      ).toDF("features")

      mode match {
        case "save" =>
          val kmedoids = new KMedoids()
            .setK(2)
            .setMaxIter(10)
            .setSeed(456)
          val model = kmedoids.fit(df)
          model.write.overwrite().save(path)
          println(s"Saved KMedoids model to $path")
          println(s" Medoids: ${model.medoids.mkString(", ")}")
          println(s" Medoid indices: ${model.medoidIndices.mkString(", ")}")

        case "load" =>
          val loaded = com.massivedatascience.clusterer.ml.KMedoidsModel.load(path)

          // Assertions to verify roundtrip correctness
          assert(loaded.numClusters == 2, s"Expected k=2, got ${loaded.numClusters}")
          assert(loaded.medoids.length == 2, s"Expected 2 medoids, got ${loaded.medoids.length}")
          assert(loaded.medoidIndices.length == 2, s"Expected 2 medoid indices, got ${loaded.medoidIndices.length}")
          assert(loaded.numFeatures == 2, s"Expected dim=2, got ${loaded.numFeatures}")

          // Verify predictions work
          val preds = loaded.transform(df)
          val n = preds.count()
          assert(n == 6, s"expected 6 rows after load, got $n")

          // Verify medoids are actual data points (one near 0, one near 9-10)
          val medoids = loaded.medoids.sortBy(_.apply(0))
          assert(medoids(0)(0) < 2.0, s"Medoid 0 should be near cluster at (0,0), got ${medoids(0)}")
          assert(medoids(1)(0) > 8.0, s"Medoid 1 should be near cluster at (9-10,9-10), got ${medoids(1)}")

          // Verify medoid indices are valid row positions in the 6-row dataset
          assert(loaded.medoidIndices.forall(i => i >= 0 && i < 6), s"Medoid indices should be in [0,5], got ${loaded.medoidIndices.mkString(", ")}")

          println(s"✅ Loaded KMedoids model from $path; predictions=$n; all assertions passed")
          println(s" Medoids: ${loaded.medoids.mkString(", ")}")
          println(s" Medoid indices: ${loaded.medoidIndices.mkString(", ")}")

        case other =>
          sys.error(s"Unknown mode: $other")
      }
    } finally {
      spark.stop()
    }
  }
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package examples

import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.linalg.Vectors
import com.massivedatascience.clusterer.ml.SoftKMeans

/**
 * Cross-Spark-version persistence round-trip check for a fitted SoftKMeans model.
 *
 * Run once in "save" mode under one Spark version, then in "load" mode under
 * another; the load path asserts that k, centers, beta, minMembership, and the
 * probability output column all survived serialization.
 *
 * Usage:
 *   sbt -Dspark.version=3.4.3 "runMain examples.PersistenceRoundTripSoftKMeans save ./tmp_soft_34"
 *   sbt -Dspark.version=3.5.1 "runMain examples.PersistenceRoundTripSoftKMeans load ./tmp_soft_34"
 */
object PersistenceRoundTripSoftKMeans {

  def main(args: Array[String]): Unit = {
    require(args.length == 2, "args: save|load <path>")
    val mode = args(0)
    val path = args(1)

    val spark = SparkSession.builder().appName("PersistenceRoundTripSoftKMeans").master("local[*]").getOrCreate()
    import spark.implicits._

    // try/finally guarantees the SparkSession is stopped even when an
    // assertion fails or an unknown mode triggers sys.error; otherwise the
    // local JVM can hang on Spark's non-daemon threads.
    try {
      // Two well-separated clusters: three points near (0,0), three near (9.5,9.5).
      val df = Seq(
        Tuple1(Vectors.dense(0.0, 0.0)),
        Tuple1(Vectors.dense(0.1, 0.1)),
        Tuple1(Vectors.dense(1.0, 1.0)),
        Tuple1(Vectors.dense(9.0, 9.0)),
        Tuple1(Vectors.dense(9.1, 9.1)),
        Tuple1(Vectors.dense(10.0, 10.0))
      ).toDF("features")

      mode match {
        case "save" =>
          val softKMeans = new SoftKMeans()
            .setK(2)
            .setDivergence("kullbackLeibler")
            .setBeta(2.0)
            .setMinMembership(0.01)
            .setSeed(789)
          val model = softKMeans.fit(df)
          model.write.overwrite().save(path)
          println(s"Saved SoftKMeans model to $path")
          println(s" Centers: ${model.clusterCenters.mkString(", ")}")
          println(s" Beta: ${model.betaValue}")

        case "load" =>
          val loaded = com.massivedatascience.clusterer.ml.SoftKMeansModel.load(path)

          // Assertions to verify roundtrip correctness
          assert(loaded.numClusters == 2, s"Expected k=2, got ${loaded.numClusters}")
          assert(loaded.clusterCenters.length == 2, s"Expected 2 centers, got ${loaded.clusterCenters.length}")
          assert(loaded.clusterCenters(0).size == 2, s"Expected dim=2, got ${loaded.clusterCenters(0).size}")

          // Verify soft clustering parameters
          assert(math.abs(loaded.betaValue - 2.0) < 0.001, s"Expected beta=2.0, got ${loaded.betaValue}")
          assert(math.abs(loaded.minMembershipValue - 0.01) < 0.001, s"Expected minMembership=0.01, got ${loaded.minMembershipValue}")

          // Verify predictions work and include probability column
          val preds = loaded.transform(df)
          val n = preds.count()
          assert(n == 6, s"expected 6 rows after load, got $n")

          // Verify probability column exists
          assert(preds.columns.contains("probability"), "Expected 'probability' column in predictions")

          // Verify centers are reasonable (one near 0, one near 9-10)
          val centers = loaded.clusterCenters.sortBy(_.apply(0))
          assert(centers(0)(0) < 2.0, s"Center 0 should be near cluster at (0,0), got ${centers(0)}")
          assert(centers(1)(0) > 8.0, s"Center 1 should be near cluster at (9-10,9-10), got ${centers(1)}")

          println(s"✅ Loaded SoftKMeans model from $path; predictions=$n; all assertions passed")
          println(s" Centers: ${loaded.clusterCenters.mkString(", ")}")
          println(s" Beta: ${loaded.betaValue}, MinMembership: ${loaded.minMembershipValue}")

        case other =>
          sys.error(s"Unknown mode: $other")
      }
    } finally {
      spark.stop()
    }
  }
}
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
package examples

import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.linalg.Vectors
import com.massivedatascience.clusterer.ml.StreamingKMeans

/**
 * Cross-Spark-version persistence round-trip check for a fitted
 * StreamingKMeans model.
 *
 * Save mode fits on a first batch, applies a streaming update with a second
 * batch, then persists. Load mode asserts that k, centers, divergence,
 * decayFactor, smoothing, and — critically for streaming — the per-cluster
 * weights were restored, and that streaming updates still work after load.
 *
 * Usage:
 *   sbt -Dspark.version=3.4.3 "runMain examples.PersistenceRoundTripStreamingKMeans save ./tmp_streaming_34"
 *   sbt -Dspark.version=3.5.1 "runMain examples.PersistenceRoundTripStreamingKMeans load ./tmp_streaming_34"
 */
object PersistenceRoundTripStreamingKMeans {

  def main(args: Array[String]): Unit = {
    require(args.length == 2, "args: save|load <path>")
    val mode = args(0)
    val path = args(1)

    val spark = SparkSession.builder().appName("PersistenceRoundTripStreamingKMeans").master("local[*]").getOrCreate()
    import spark.implicits._

    // try/finally guarantees the SparkSession is stopped even when an
    // assertion fails or an unknown mode triggers sys.error; otherwise the
    // local JVM can hang on Spark's non-daemon threads.
    try {
      // Batch 1: points near the origin.
      val df1 = Seq(
        Tuple1(Vectors.dense(0.0, 0.0)),
        Tuple1(Vectors.dense(0.1, 0.1)),
        Tuple1(Vectors.dense(1.0, 1.0))
      ).toDF("features")

      // Batch 2: points near (9.5, 9.5), used for the streaming update.
      val df2 = Seq(
        Tuple1(Vectors.dense(9.0, 9.0)),
        Tuple1(Vectors.dense(9.1, 9.1)),
        Tuple1(Vectors.dense(10.0, 10.0))
      ).toDF("features")

      mode match {
        case "save" =>
          val streamingKMeans = new StreamingKMeans()
            .setK(2)
            .setDivergence("squaredEuclidean")
            .setDecayFactor(0.9)
            .setSmoothing(1e-9)
            .setSeed(42)

          // Initialize model with first batch
          val model1 = streamingKMeans.fit(df1)
          println(s"After batch 1 - Centers: ${model1.clusterCenters.mkString(", ")}")
          println(s"After batch 1 - Weights: ${model1.currentWeights.mkString(", ")}")

          // Simulate streaming update with second batch
          val model2 = model1.update(df2)
          println(s"After batch 2 - Centers: ${model2.clusterCenters.mkString(", ")}")
          println(s"After batch 2 - Weights: ${model2.currentWeights.mkString(", ")}")

          model2.write.overwrite().save(path)
          println(s"Saved StreamingKMeans model to $path")

        case "load" =>
          val loaded = com.massivedatascience.clusterer.ml.StreamingKMeansModel.load(path)

          // Assertions to verify roundtrip correctness
          assert(loaded.numClusters == 2, s"Expected k=2, got ${loaded.numClusters}")
          assert(loaded.clusterCenters.length == 2, s"Expected 2 centers, got ${loaded.clusterCenters.length}")
          assert(loaded.numFeatures == 2, s"Expected dim=2, got ${loaded.numFeatures}")

          // Verify streaming-specific parameters
          assert(loaded.divergenceName == "squaredEuclidean", s"Expected squaredEuclidean divergence, got ${loaded.divergenceName}")
          assert(math.abs(loaded.decayFactorValue - 0.9) < 0.001, s"Expected decayFactor=0.9, got ${loaded.decayFactorValue}")
          assert(math.abs(loaded.smoothingValue - 1e-9) < 1e-10, s"Expected smoothing=1e-9, got ${loaded.smoothingValue}")

          // CRITICAL: Verify cluster weights were restored (essential for streaming!)
          val currentWeights = loaded.currentWeights
          assert(currentWeights.length == 2, s"Expected 2 cluster weights, got ${currentWeights.length}")
          assert(currentWeights.forall(_ > 0), s"Cluster weights should be positive, got ${currentWeights.mkString(", ")}")
          println(s"Cluster weights restored: ${currentWeights.mkString(", ")}")

          // Verify predictions work
          val preds = loaded.transform(df1)
          val n = preds.count()
          assert(n == 3, s"expected 3 rows after load, got $n")

          // Verify we can continue streaming after load
          val continued = loaded.update(df2)
          assert(continued.clusterCenters.length == 2, "Should be able to continue streaming after load")
          println(s"After continued update - Centers: ${continued.clusterCenters.mkString(", ")}")
          println(s"After continued update - Weights: ${continued.currentWeights.mkString(", ")}")

          println(s"✅ Loaded StreamingKMeans model from $path; predictions=$n; all assertions passed")
          println(s" Centers: ${loaded.clusterCenters.mkString(", ")}")
          println(s" Weights: ${loaded.currentWeights.mkString(", ")}")

        case other =>
          sys.error(s"Unknown mode: $other")
      }
    } finally {
      spark.stop()
    }
  }
}

0 commit comments

Comments
 (0)