1717
1818package com .massivedatascience .clusterer
1919
20- import com .massivedatascience .clusterer .TestingUtils ._
2120import com .massivedatascience .clusterer .KMeans .RunConfig
21+ import com .massivedatascience .clusterer .TestingUtils ._
2222import com .massivedatascience .linalg .WeightedVector
2323import com .massivedatascience .transforms .Embedding
2424import org .apache .spark .ml .linalg .Vectors
@@ -29,18 +29,24 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
2929 private val seedRng = new scala.util.Random (42L )
3030
/**
 * Builds an RDD of `n` strictly positive 2D points, suitable for KL / Generalized-I.
 *
 * Point `i` is assigned the offset `(i % clusters) * 5.0`; each coordinate is
 * `exp(gaussian + offset) + eps`, so every component is greater than zero.
 * Random draws come from the suite-level `seedRng`, x before y for each point.
 *
 * @param n        number of points to generate
 * @param clusters number of distinct offsets the points cycle through
 * @return an RDD holding the generated dense vectors
 */
private def positive2D(
  n: Int,
  clusters: Int
): org.apache.spark.rdd.RDD[org.apache.spark.ml.linalg.Vector] = {
  val eps = 1e-6

  // One positive coordinate: exponentiate a shifted Gaussian and nudge it away from zero.
  def draw(offset: Double): Double = math.exp(seedRng.nextGaussian() + offset) + eps

  val points = for (i <- 0 until n) yield {
    val offset = (i % clusters) * 5.0
    // Arguments are evaluated left to right, so x is drawn before y.
    Vectors.dense(draw(offset), draw(offset))
  }

  sc.parallelize(points)
}
4144
4245 /** Real-valued 2D Gaussian-ish data for Euclidean. */
43- private def gaussian2D (n : Int , clusters : Int ): org.apache.spark.rdd.RDD [org.apache.spark.ml.linalg.Vector ] = {
46+ private def gaussian2D (
47+ n : Int ,
48+ clusters : Int
49+ ): org.apache.spark.rdd.RDD [org.apache.spark.ml.linalg.Vector ] = {
4450 sc.parallelize((0 until n).map { i =>
4551 Vectors .dense(
4652 seedRng.nextGaussian() + (i % clusters) * 5 ,
@@ -54,9 +60,9 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
5460 val dataPositive = positive2D(n = 100 , clusters = 3 )
5561
5662 val distances = Seq (
57- BregmanPointOps .EUCLIDEAN -> dataEuclidean,
63+ BregmanPointOps .EUCLIDEAN -> dataEuclidean,
5864 BregmanPointOps .RELATIVE_ENTROPY -> dataPositive, // KL on positive data
59- BregmanPointOps .GENERALIZED_I -> dataPositive // Generalized I on positive data
65+ BregmanPointOps .GENERALIZED_I -> dataPositive // Generalized I on positive data
6066 )
6167
6268 distances.foreach { case (distanceFunction, data) =>
@@ -68,14 +74,19 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
6874 )
6975
7076 assert(model.centers.length <= 3 , s " Too many centers for $distanceFunction" )
77+
7178 val cost = model.computeCost(data)
7279 assert(java.lang.Double .isFinite(cost), s " Non-finite cost for $distanceFunction" )
7380 assert(cost >= 0.0 , s " Negative cost for $distanceFunction: $cost" )
7481
7582 val predictions = model.predict(data).collect()
76- assert(predictions.forall(p => p >= 0 && p < model.centers.length), s " Invalid prediction for $distanceFunction" )
83+ assert(
84+ predictions.forall(p => p >= 0 && p < model.centers.length),
85+ s " Invalid prediction for $distanceFunction"
86+ )
7787
78- val clusterCounts = predictions.groupBy(identity).map { case (k, arr) => k -> arr.length }
88+ val clusterCounts =
89+ predictions.groupBy(identity).map { case (k, arr) => k -> arr.length }
7990 assert(clusterCounts.size <= 3 , s " Too many clusters observed for $distanceFunction" )
8091 assert(clusterCounts.values.forall(_ > 0 ), s " Empty cluster detected for $distanceFunction" )
8192 }
@@ -84,7 +95,7 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
8495 test(" embedding pipeline integration with time series" ) {
8596 val timeSeries = sc.parallelize((0 until 50 ).map { i =>
8697 val pattern = i % 3
87- val values = (0 until 16 ).map { t =>
98+ val values = (0 until 16 ).map { t =>
8899 pattern match {
89100 case 0 => math.sin(t * 0.5 ) + seedRng.nextGaussian() * 0.1
90101 case 1 => math.cos(t * 0.3 ) + seedRng.nextGaussian() * 0.1
@@ -148,8 +159,7 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
148159 val cost = model.computeCostWeighted(data)
149160 assert(cost >= 0.0 && java.lang.Double .isFinite(cost))
150161 } catch {
151- case e : IllegalArgumentException
152- if e.getMessage.contains(" requires at least one valid center" ) =>
162+ case e : IllegalArgumentException if e.getMessage.contains(" requires at least one valid center" ) =>
153163 succeed
154164 case e : IllegalArgumentException if e.getMessage.contains(" requirement failed" ) =>
155165 succeed
@@ -202,7 +212,10 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
202212 val minCost = costs.min
203213 val maxCost = costs.max
204214 val eps = 1e-9
205- assert((maxCost + eps) / (minCost + eps) <= 100.0 , " Clustering implementations produce very different costs" )
215+ assert(
216+ (maxCost + eps) / (minCost + eps) <= 100.0 ,
217+ " Clustering implementations produce very different costs"
218+ )
206219 }
207220
208221 test(" large dataset performance test (log only, no timing assert)" ) {
@@ -235,23 +248,30 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
235248 info(s " KMeans training duration: ${duration}ms " )
236249
237250 assert(model.centers.length <= numClusters)
251+
238252 val predictions = model.predict(data).collect()
239253 assert(predictions.forall(p => p >= 0 && p < model.centers.length))
240254
241255 val trueLabels = (0 until numPoints).map(_ % numClusters).toArray
242256
243- val clusterToTrueLabel = predictions.zipWithIndex
244- .groupBy(_._1)
245- .map { case (pred, pairs) =>
246- val labelCounts = pairs.map { case (_, idx) => trueLabels(idx) }.groupBy(identity).map { case (k, v) => k -> v.length }
247- val majority = labelCounts.maxBy(_._2)._1
248- pred -> majority
249- }
257+ val clusterToTrueLabel =
258+ predictions.zipWithIndex
259+ .groupBy(_._1)
260+ .map { case (pred, pairs) =>
261+ val labelCounts =
262+ pairs
263+ .map { case (_, idx) => trueLabels(idx) }
264+ .groupBy(identity)
265+ .map { case (k, v) => k -> v.length }
266+ val majority = labelCounts.maxBy(_._2)._1
267+ pred -> majority
268+ }
250269
251- val correctAssignments = predictions.zipWithIndex.count { case (prediction, index) =>
252- clusterToTrueLabel.get(prediction).contains(trueLabels(index))
253- }
254- val accuracy = correctAssignments.toDouble / numPoints
270+ val correctAssignments =
271+ predictions.zipWithIndex.count { case (prediction, index) =>
272+ clusterToTrueLabel.get(prediction).contains(trueLabels(index))
273+ }
274+ val accuracy = correctAssignments.toDouble / numPoints
255275 assert(accuracy > 0.5 , s " Poor clustering accuracy: $accuracy" )
256276 }
257277
@@ -355,14 +375,15 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
355375 val normalData = sc.parallelize(normalPoints)
356376 val normalPredictions = model.predict(normalData).collect()
357377
358- val clusterCounts = normalPredictions.groupBy(identity).map { case (k, arr) => k -> arr.length }
378+ val clusterCounts =
379+ normalPredictions.groupBy(identity).map { case (k, arr) => k -> arr.length }
359380 assert(clusterCounts.size <= 3 )
360381 // Loosened threshold to reduce flakiness while still catching degenerate solutions
361382 assert(clusterCounts.values.forall(_ >= 3 ))
362383 }
363384
364385 test(" reproducibility with fixed seeds (seeded data + algorithm)" ) {
365- val rng = new scala.util.Random (1234L )
386+ val rng = new scala.util.Random (1234L )
366387 val data = sc.parallelize((0 until 50 ).map { _ =>
367388 WeightedVector (
368389 Vectors .dense(
@@ -397,8 +418,11 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
397418 val cost1 = model1.computeCostWeighted(data)
398419 val cost2 = model2.computeCostWeighted(data)
399420
400- val tol = 0.05 // 5% to absorb non-determinism in reduction order/FP
421+ val tol = 0.05 // 5% to absorb non-determinism in reduction order/FP
401422 val relDiff = math.abs(cost1 - cost2) / math.max(1e-9 , math.abs(cost1))
402- assert(relDiff <= tol, s " Results differ more than ${tol * 100 }%% with same seed: $cost1 vs $cost2" )
423+ assert(
424+ relDiff <= tol,
425+ f " Results differ more than ${tol * 100 }%.0f%% with same seed: $cost1%.6f vs $cost2%.6f "
426+ )
403427 }
404428}
0 commit comments