1717
1818package com .massivedatascience .clusterer
1919
20- import com .massivedatascience .clusterer .KMeans .RunConfig
2120import com .massivedatascience .clusterer .TestingUtils ._
21+ import com .massivedatascience .clusterer .KMeans .RunConfig
2222import com .massivedatascience .linalg .WeightedVector
2323import com .massivedatascience .transforms .Embedding
2424import org .apache .spark .ml .linalg .Vectors
@@ -29,24 +29,18 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
2929 private val seedRng = new scala.util.Random (42L )
3030
3131 /** Produce strictly positive 2D data suitable for KL / Generalized-I. */
32- private def positive2D (
33- n : Int ,
34- clusters : Int
35- ): org.apache.spark.rdd.RDD [org.apache.spark.ml.linalg.Vector ] = {
32+ private def positive2D (n : Int , clusters : Int ): org.apache.spark.rdd.RDD [org.apache.spark.ml.linalg.Vector ] = {
3633 val eps = 1e-6
3734 sc.parallelize((0 until n).map { i =>
3835 val base = (i % clusters) * 5.0
39- val x = math.exp(seedRng.nextGaussian() + base) + eps
40- val y = math.exp(seedRng.nextGaussian() + base) + eps
36+ val x = math.exp(seedRng.nextGaussian() + base) + eps
37+ val y = math.exp(seedRng.nextGaussian() + base) + eps
4138 Vectors .dense(x, y)
4239 })
4340 }
4441
4542 /** Real-valued 2D Gaussian-ish data for Euclidean. */
46- private def gaussian2D (
47- n : Int ,
48- clusters : Int
49- ): org.apache.spark.rdd.RDD [org.apache.spark.ml.linalg.Vector ] = {
43+ private def gaussian2D (n : Int , clusters : Int ): org.apache.spark.rdd.RDD [org.apache.spark.ml.linalg.Vector ] = {
5044 sc.parallelize((0 until n).map { i =>
5145 Vectors .dense(
5246 seedRng.nextGaussian() + (i % clusters) * 5 ,
@@ -60,9 +54,9 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
6054 val dataPositive = positive2D(n = 100 , clusters = 3 )
6155
6256 val distances = Seq (
63- BregmanPointOps .EUCLIDEAN -> dataEuclidean,
57+ BregmanPointOps .EUCLIDEAN -> dataEuclidean,
6458 BregmanPointOps .RELATIVE_ENTROPY -> dataPositive, // KL on positive data
65- BregmanPointOps .GENERALIZED_I -> dataPositive // Generalized I on positive data
59+ BregmanPointOps .GENERALIZED_I -> dataPositive // Generalized I on positive data
6660 )
6761
6862 distances.foreach { case (distanceFunction, data) =>
@@ -74,19 +68,14 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
7468 )
7569
7670 assert(model.centers.length <= 3 , s " Too many centers for $distanceFunction" )
77-
7871 val cost = model.computeCost(data)
7972 assert(java.lang.Double .isFinite(cost), s " Non-finite cost for $distanceFunction" )
8073 assert(cost >= 0.0 , s " Negative cost for $distanceFunction: $cost" )
8174
8275 val predictions = model.predict(data).collect()
83- assert(
84- predictions.forall(p => p >= 0 && p < model.centers.length),
85- s " Invalid prediction for $distanceFunction"
86- )
76+ assert(predictions.forall(p => p >= 0 && p < model.centers.length), s " Invalid prediction for $distanceFunction" )
8777
88- val clusterCounts =
89- predictions.groupBy(identity).map { case (k, arr) => k -> arr.length }
78+ val clusterCounts = predictions.groupBy(identity).map { case (k, arr) => k -> arr.length }
9079 assert(clusterCounts.size <= 3 , s " Too many clusters observed for $distanceFunction" )
9180 assert(clusterCounts.values.forall(_ > 0 ), s " Empty cluster detected for $distanceFunction" )
9281 }
@@ -95,7 +84,7 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
9584 test(" embedding pipeline integration with time series" ) {
9685 val timeSeries = sc.parallelize((0 until 50 ).map { i =>
9786 val pattern = i % 3
98- val values = (0 until 16 ).map { t =>
87+ val values = (0 until 16 ).map { t =>
9988 pattern match {
10089 case 0 => math.sin(t * 0.5 ) + seedRng.nextGaussian() * 0.1
10190 case 1 => math.cos(t * 0.3 ) + seedRng.nextGaussian() * 0.1
@@ -159,7 +148,8 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
159148 val cost = model.computeCostWeighted(data)
160149 assert(cost >= 0.0 && java.lang.Double .isFinite(cost))
161150 } catch {
162- case e : IllegalArgumentException if e.getMessage.contains(" requires at least one valid center" ) =>
151+ case e : IllegalArgumentException
152+ if e.getMessage.contains(" requires at least one valid center" ) =>
163153 succeed
164154 case e : IllegalArgumentException if e.getMessage.contains(" requirement failed" ) =>
165155 succeed
@@ -212,10 +202,7 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
212202 val minCost = costs.min
213203 val maxCost = costs.max
214204 val eps = 1e-9
215- assert(
216- (maxCost + eps) / (minCost + eps) <= 100.0 ,
217- " Clustering implementations produce very different costs"
218- )
205+ assert((maxCost + eps) / (minCost + eps) <= 100.0 , " Clustering implementations produce very different costs" )
219206 }
220207
221208 test(" large dataset performance test (log only, no timing assert)" ) {
@@ -248,30 +235,23 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
248235 info(s " KMeans training duration: ${duration}ms " )
249236
250237 assert(model.centers.length <= numClusters)
251-
252238 val predictions = model.predict(data).collect()
253239 assert(predictions.forall(p => p >= 0 && p < model.centers.length))
254240
255241 val trueLabels = (0 until numPoints).map(_ % numClusters).toArray
256242
257- val clusterToTrueLabel =
258- predictions.zipWithIndex
259- .groupBy(_._1)
260- .map { case (pred, pairs) =>
261- val labelCounts =
262- pairs
263- .map { case (_, idx) => trueLabels(idx) }
264- .groupBy(identity)
265- .map { case (k, v) => k -> v.length }
266- val majority = labelCounts.maxBy(_._2)._1
267- pred -> majority
268- }
269-
270- val correctAssignments =
271- predictions.zipWithIndex.count { case (prediction, index) =>
272- clusterToTrueLabel.get(prediction).contains(trueLabels(index))
243+ val clusterToTrueLabel = predictions.zipWithIndex
244+ .groupBy(_._1)
245+ .map { case (pred, pairs) =>
246+ val labelCounts = pairs.map { case (_, idx) => trueLabels(idx) }.groupBy(identity).map { case (k, v) => k -> v.length }
247+ val majority = labelCounts.maxBy(_._2)._1
248+ pred -> majority
273249 }
274- val accuracy = correctAssignments.toDouble / numPoints
250+
251+ val correctAssignments = predictions.zipWithIndex.count { case (prediction, index) =>
252+ clusterToTrueLabel.get(prediction).contains(trueLabels(index))
253+ }
254+ val accuracy = correctAssignments.toDouble / numPoints
275255 assert(accuracy > 0.5 , s " Poor clustering accuracy: $accuracy" )
276256 }
277257
@@ -375,15 +355,14 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
375355 val normalData = sc.parallelize(normalPoints)
376356 val normalPredictions = model.predict(normalData).collect()
377357
378- val clusterCounts =
379- normalPredictions.groupBy(identity).map { case (k, arr) => k -> arr.length }
358+ val clusterCounts = normalPredictions.groupBy(identity).map { case (k, arr) => k -> arr.length }
380359 assert(clusterCounts.size <= 3 )
381360 // Loosened threshold to reduce flakiness while still catching degenerate solutions
382361 assert(clusterCounts.values.forall(_ >= 3 ))
383362 }
384363
385364 test(" reproducibility with fixed seeds (seeded data + algorithm)" ) {
386- val rng = new scala.util.Random (1234L )
365+ val rng = new scala.util.Random (1234L )
387366 val data = sc.parallelize((0 until 50 ).map { _ =>
388367 WeightedVector (
389368 Vectors .dense(
@@ -418,11 +397,8 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
418397 val cost1 = model1.computeCostWeighted(data)
419398 val cost2 = model2.computeCostWeighted(data)
420399
421- val tol = 0.05 // 5% to absorb non-determinism in reduction order/FP
400+ val tol = 0.05 // 5% to absorb non-determinism in reduction order/FP
422401 val relDiff = math.abs(cost1 - cost2) / math.max(1e-9 , math.abs(cost1))
423- assert(
424- relDiff <= tol,
425- f " Results differ more than ${tol * 100 }%.0f%% with same seed: $cost1%.6f vs $cost2%.6f "
426- )
402+ assert(relDiff <= tol, s " Results differ more than ${tol * 100 }%% with same seed: $cost1 vs $cost2" )
427403 }
428404}
0 commit comments