Skip to content

Commit 6f063d1

Browse files
committed
Formatting
1 parent 8c67142 commit 6f063d1

File tree

1 file changed

+52
-28
lines changed

1 file changed

+52
-28
lines changed

src/test/scala/com/massivedatascience/clusterer/IntegrationTestSuite.scala

Lines changed: 52 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -17,8 +17,8 @@
1717

1818
package com.massivedatascience.clusterer
1919

20-
import com.massivedatascience.clusterer.TestingUtils._
2120
import com.massivedatascience.clusterer.KMeans.RunConfig
21+
import com.massivedatascience.clusterer.TestingUtils._
2222
import com.massivedatascience.linalg.WeightedVector
2323
import com.massivedatascience.transforms.Embedding
2424
import org.apache.spark.ml.linalg.Vectors
@@ -29,18 +29,24 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
2929
private val seedRng = new scala.util.Random(42L)
3030

3131
/** Produce strictly positive 2D data suitable for KL / Generalized-I. */
32-
private def positive2D(n: Int, clusters: Int): org.apache.spark.rdd.RDD[org.apache.spark.ml.linalg.Vector] = {
32+
private def positive2D(
33+
n: Int,
34+
clusters: Int
35+
): org.apache.spark.rdd.RDD[org.apache.spark.ml.linalg.Vector] = {
3336
val eps = 1e-6
3437
sc.parallelize((0 until n).map { i =>
3538
val base = (i % clusters) * 5.0
36-
val x = math.exp(seedRng.nextGaussian() + base) + eps
37-
val y = math.exp(seedRng.nextGaussian() + base) + eps
39+
val x = math.exp(seedRng.nextGaussian() + base) + eps
40+
val y = math.exp(seedRng.nextGaussian() + base) + eps
3841
Vectors.dense(x, y)
3942
})
4043
}
4144

4245
/** Real-valued 2D Gaussian-ish data for Euclidean. */
43-
private def gaussian2D(n: Int, clusters: Int): org.apache.spark.rdd.RDD[org.apache.spark.ml.linalg.Vector] = {
46+
private def gaussian2D(
47+
n: Int,
48+
clusters: Int
49+
): org.apache.spark.rdd.RDD[org.apache.spark.ml.linalg.Vector] = {
4450
sc.parallelize((0 until n).map { i =>
4551
Vectors.dense(
4652
seedRng.nextGaussian() + (i % clusters) * 5,
@@ -54,9 +60,9 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
5460
val dataPositive = positive2D(n = 100, clusters = 3)
5561

5662
val distances = Seq(
57-
BregmanPointOps.EUCLIDEAN -> dataEuclidean,
63+
BregmanPointOps.EUCLIDEAN -> dataEuclidean,
5864
BregmanPointOps.RELATIVE_ENTROPY -> dataPositive, // KL on positive data
59-
BregmanPointOps.GENERALIZED_I -> dataPositive // Generalized I on positive data
65+
BregmanPointOps.GENERALIZED_I -> dataPositive // Generalized I on positive data
6066
)
6167

6268
distances.foreach { case (distanceFunction, data) =>
@@ -68,14 +74,19 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
6874
)
6975

7076
assert(model.centers.length <= 3, s"Too many centers for $distanceFunction")
77+
7178
val cost = model.computeCost(data)
7279
assert(java.lang.Double.isFinite(cost), s"Non-finite cost for $distanceFunction")
7380
assert(cost >= 0.0, s"Negative cost for $distanceFunction: $cost")
7481

7582
val predictions = model.predict(data).collect()
76-
assert(predictions.forall(p => p >= 0 && p < model.centers.length), s"Invalid prediction for $distanceFunction")
83+
assert(
84+
predictions.forall(p => p >= 0 && p < model.centers.length),
85+
s"Invalid prediction for $distanceFunction"
86+
)
7787

78-
val clusterCounts = predictions.groupBy(identity).map { case (k, arr) => k -> arr.length }
88+
val clusterCounts =
89+
predictions.groupBy(identity).map { case (k, arr) => k -> arr.length }
7990
assert(clusterCounts.size <= 3, s"Too many clusters observed for $distanceFunction")
8091
assert(clusterCounts.values.forall(_ > 0), s"Empty cluster detected for $distanceFunction")
8192
}
@@ -84,7 +95,7 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
8495
test("embedding pipeline integration with time series") {
8596
val timeSeries = sc.parallelize((0 until 50).map { i =>
8697
val pattern = i % 3
87-
val values = (0 until 16).map { t =>
98+
val values = (0 until 16).map { t =>
8899
pattern match {
89100
case 0 => math.sin(t * 0.5) + seedRng.nextGaussian() * 0.1
90101
case 1 => math.cos(t * 0.3) + seedRng.nextGaussian() * 0.1
@@ -148,8 +159,7 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
148159
val cost = model.computeCostWeighted(data)
149160
assert(cost >= 0.0 && java.lang.Double.isFinite(cost))
150161
} catch {
151-
case e: IllegalArgumentException
152-
if e.getMessage.contains("requires at least one valid center") =>
162+
case e: IllegalArgumentException if e.getMessage.contains("requires at least one valid center") =>
153163
succeed
154164
case e: IllegalArgumentException if e.getMessage.contains("requirement failed") =>
155165
succeed
@@ -202,7 +212,10 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
202212
val minCost = costs.min
203213
val maxCost = costs.max
204214
val eps = 1e-9
205-
assert((maxCost + eps) / (minCost + eps) <= 100.0, "Clustering implementations produce very different costs")
215+
assert(
216+
(maxCost + eps) / (minCost + eps) <= 100.0,
217+
"Clustering implementations produce very different costs"
218+
)
206219
}
207220

208221
test("large dataset performance test (log only, no timing assert)") {
@@ -235,23 +248,30 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
235248
info(s"KMeans training duration: ${duration}ms")
236249

237250
assert(model.centers.length <= numClusters)
251+
238252
val predictions = model.predict(data).collect()
239253
assert(predictions.forall(p => p >= 0 && p < model.centers.length))
240254

241255
val trueLabels = (0 until numPoints).map(_ % numClusters).toArray
242256

243-
val clusterToTrueLabel = predictions.zipWithIndex
244-
.groupBy(_._1)
245-
.map { case (pred, pairs) =>
246-
val labelCounts = pairs.map { case (_, idx) => trueLabels(idx) }.groupBy(identity).map { case (k, v) => k -> v.length }
247-
val majority = labelCounts.maxBy(_._2)._1
248-
pred -> majority
249-
}
257+
val clusterToTrueLabel =
258+
predictions.zipWithIndex
259+
.groupBy(_._1)
260+
.map { case (pred, pairs) =>
261+
val labelCounts =
262+
pairs
263+
.map { case (_, idx) => trueLabels(idx) }
264+
.groupBy(identity)
265+
.map { case (k, v) => k -> v.length }
266+
val majority = labelCounts.maxBy(_._2)._1
267+
pred -> majority
268+
}
250269

251-
val correctAssignments = predictions.zipWithIndex.count { case (prediction, index) =>
252-
clusterToTrueLabel.get(prediction).contains(trueLabels(index))
253-
}
254-
val accuracy = correctAssignments.toDouble / numPoints
270+
val correctAssignments =
271+
predictions.zipWithIndex.count { case (prediction, index) =>
272+
clusterToTrueLabel.get(prediction).contains(trueLabels(index))
273+
}
274+
val accuracy = correctAssignments.toDouble / numPoints
255275
assert(accuracy > 0.5, s"Poor clustering accuracy: $accuracy")
256276
}
257277

@@ -355,14 +375,15 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
355375
val normalData = sc.parallelize(normalPoints)
356376
val normalPredictions = model.predict(normalData).collect()
357377

358-
val clusterCounts = normalPredictions.groupBy(identity).map { case (k, arr) => k -> arr.length }
378+
val clusterCounts =
379+
normalPredictions.groupBy(identity).map { case (k, arr) => k -> arr.length }
359380
assert(clusterCounts.size <= 3)
360381
// Loosened threshold to reduce flakiness while still catching degenerate solutions
361382
assert(clusterCounts.values.forall(_ >= 3))
362383
}
363384

364385
test("reproducibility with fixed seeds (seeded data + algorithm)") {
365-
val rng = new scala.util.Random(1234L)
386+
val rng = new scala.util.Random(1234L)
366387
val data = sc.parallelize((0 until 50).map { _ =>
367388
WeightedVector(
368389
Vectors.dense(
@@ -397,8 +418,11 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
397418
val cost1 = model1.computeCostWeighted(data)
398419
val cost2 = model2.computeCostWeighted(data)
399420

400-
val tol = 0.05 // 5% to absorb non-determinism in reduction order/FP
421+
val tol = 0.05 // 5% to absorb non-determinism in reduction order/FP
401422
val relDiff = math.abs(cost1 - cost2) / math.max(1e-9, math.abs(cost1))
402-
assert(relDiff <= tol, s"Results differ more than ${tol * 100}%% with same seed: $cost1 vs $cost2")
423+
assert(
424+
relDiff <= tol,
425+
f"Results differ more than ${tol * 100}%.0f%% with same seed: $cost1%.6f vs $cost2%.6f"
426+
)
403427
}
404428
}

0 commit comments

Comments (0)