Skip to content

Commit 8975970

Browse files
committed
Formatting
1 parent 8d16f5b commit 8975970

File tree

1 file changed

+34
-18
lines changed

1 file changed

+34
-18
lines changed

src/test/scala/com/massivedatascience/clusterer/IntegrationTestSuite.scala

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,24 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
2929
private val seedRng = new scala.util.Random(42L)
3030

3131
/** Produce strictly positive 2D data suitable for KL / Generalized-I. */
32-
private def positive2D(n: Int, clusters: Int): org.apache.spark.rdd.RDD[org.apache.spark.ml.linalg.Vector] = {
32+
private def positive2D(
33+
n: Int,
34+
clusters: Int
35+
): org.apache.spark.rdd.RDD[org.apache.spark.ml.linalg.Vector] = {
3336
val eps = 1e-6
3437
sc.parallelize((0 until n).map { i =>
3538
val base = (i % clusters) * 5.0
36-
val x = math.exp(seedRng.nextGaussian() + base) + eps
37-
val y = math.exp(seedRng.nextGaussian() + base) + eps
39+
val x = math.exp(seedRng.nextGaussian() + base) + eps
40+
val y = math.exp(seedRng.nextGaussian() + base) + eps
3841
Vectors.dense(x, y)
3942
})
4043
}
4144

4245
/** Real-valued 2D Gaussian-ish data for Euclidean. */
43-
private def gaussian2D(n: Int, clusters: Int): org.apache.spark.rdd.RDD[org.apache.spark.ml.linalg.Vector] = {
46+
private def gaussian2D(
47+
n: Int,
48+
clusters: Int
49+
): org.apache.spark.rdd.RDD[org.apache.spark.ml.linalg.Vector] = {
4450
sc.parallelize((0 until n).map { i =>
4551
Vectors.dense(
4652
seedRng.nextGaussian() + (i % clusters) * 5,
@@ -54,9 +60,9 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
5460
val dataPositive = positive2D(n = 100, clusters = 3)
5561

5662
val distances = Seq(
57-
BregmanPointOps.EUCLIDEAN -> dataEuclidean,
63+
BregmanPointOps.EUCLIDEAN -> dataEuclidean,
5864
BregmanPointOps.RELATIVE_ENTROPY -> dataPositive, // KL on positive data
59-
BregmanPointOps.GENERALIZED_I -> dataPositive // Generalized I on positive data
65+
BregmanPointOps.GENERALIZED_I -> dataPositive // Generalized I on positive data
6066
)
6167

6268
distances.foreach { case (distanceFunction, data) =>
@@ -73,7 +79,10 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
7379
assert(cost >= 0.0, s"Negative cost for $distanceFunction: $cost")
7480

7581
val predictions = model.predict(data).collect()
76-
assert(predictions.forall(p => p >= 0 && p < model.centers.length), s"Invalid prediction for $distanceFunction")
82+
assert(
83+
predictions.forall(p => p >= 0 && p < model.centers.length),
84+
s"Invalid prediction for $distanceFunction"
85+
)
7786

7887
val clusterCounts = predictions.groupBy(identity).map { case (k, arr) => k -> arr.length }
7988
assert(clusterCounts.size <= 3, s"Too many clusters observed for $distanceFunction")
@@ -202,7 +211,10 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
202211
val minCost = costs.min
203212
val maxCost = costs.max
204213
val eps = 1e-9
205-
assert((maxCost + eps) / (minCost + eps) <= 100.0, "Clustering implementations produce very different costs")
214+
assert(
215+
(maxCost + eps) / (minCost + eps) <= 100.0,
216+
"Clustering implementations produce very different costs"
217+
)
206218
}
207219

208220
test("large dataset performance test (log only, no timing assert)") {
@@ -240,13 +252,14 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
240252

241253
val trueLabels = (0 until numPoints).map(_ % numClusters).toArray
242254

243-
val clusterToTrueLabel = predictions.zipWithIndex
244-
.groupBy(_._1)
245-
.map { case (pred, pairs) =>
246-
val labelCounts = pairs.map { case (_, idx) => trueLabels(idx) }.groupBy(identity).map { case (k, v) => k -> v.length }
247-
val majority = labelCounts.maxBy(_._2)._1
248-
pred -> majority
249-
}
255+
val clusterToTrueLabel = predictions.zipWithIndex.groupBy(_._1).map { case (pred, pairs) =>
256+
val labelCounts =
257+
pairs.map { case (_, idx) => trueLabels(idx) }.groupBy(identity).map { case (k, v) =>
258+
k -> v.length
259+
}
260+
val majority = labelCounts.maxBy(_._2)._1
261+
pred -> majority
262+
}
250263

251264
val correctAssignments = predictions.zipWithIndex.count { case (prediction, index) =>
252265
clusterToTrueLabel.get(prediction).contains(trueLabels(index))
@@ -362,7 +375,7 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
362375
}
363376

364377
test("reproducibility with fixed seeds (seeded data + algorithm)") {
365-
val rng = new scala.util.Random(1234L)
378+
val rng = new scala.util.Random(1234L)
366379
val data = sc.parallelize((0 until 50).map { _ =>
367380
WeightedVector(
368381
Vectors.dense(
@@ -397,8 +410,11 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
397410
val cost1 = model1.computeCostWeighted(data)
398411
val cost2 = model2.computeCostWeighted(data)
399412

400-
val tol = 0.05 // 5% to absorb non-determinism in reduction order/FP
413+
val tol = 0.05 // 5% to absorb non-determinism in reduction order/FP
401414
val relDiff = math.abs(cost1 - cost2) / math.max(1e-9, math.abs(cost1))
402-
assert(relDiff <= tol, s"Results differ more than ${tol * 100}%% with same seed: $cost1 vs $cost2")
415+
assert(
416+
relDiff <= tol,
417+
s"Results differ more than ${tol * 100}%% with same seed: $cost1 vs $cost2"
418+
)
403419
}
404420
}

0 commit comments

Comments
 (0)