Skip to content

Commit 8d16f5b

Browse files
committed
Revert "Formatting"
This reverts commit 6f063d1.
1 parent 6f063d1 commit 8d16f5b

File tree

1 file changed

+28
-52
lines changed

1 file changed

+28
-52
lines changed

src/test/scala/com/massivedatascience/clusterer/IntegrationTestSuite.scala

Lines changed: 28 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717

1818
package com.massivedatascience.clusterer
1919

20-
import com.massivedatascience.clusterer.KMeans.RunConfig
2120
import com.massivedatascience.clusterer.TestingUtils._
21+
import com.massivedatascience.clusterer.KMeans.RunConfig
2222
import com.massivedatascience.linalg.WeightedVector
2323
import com.massivedatascience.transforms.Embedding
2424
import org.apache.spark.ml.linalg.Vectors
@@ -29,24 +29,18 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
2929
private val seedRng = new scala.util.Random(42L)
3030

3131
/** Produce strictly positive 2D data suitable for KL / Generalized-I. */
32-
private def positive2D(
33-
n: Int,
34-
clusters: Int
35-
): org.apache.spark.rdd.RDD[org.apache.spark.ml.linalg.Vector] = {
32+
private def positive2D(n: Int, clusters: Int): org.apache.spark.rdd.RDD[org.apache.spark.ml.linalg.Vector] = {
3633
val eps = 1e-6
3734
sc.parallelize((0 until n).map { i =>
3835
val base = (i % clusters) * 5.0
39-
val x = math.exp(seedRng.nextGaussian() + base) + eps
40-
val y = math.exp(seedRng.nextGaussian() + base) + eps
36+
val x = math.exp(seedRng.nextGaussian() + base) + eps
37+
val y = math.exp(seedRng.nextGaussian() + base) + eps
4138
Vectors.dense(x, y)
4239
})
4340
}
4441

4542
/** Real-valued 2D Gaussian-ish data for Euclidean. */
46-
private def gaussian2D(
47-
n: Int,
48-
clusters: Int
49-
): org.apache.spark.rdd.RDD[org.apache.spark.ml.linalg.Vector] = {
43+
private def gaussian2D(n: Int, clusters: Int): org.apache.spark.rdd.RDD[org.apache.spark.ml.linalg.Vector] = {
5044
sc.parallelize((0 until n).map { i =>
5145
Vectors.dense(
5246
seedRng.nextGaussian() + (i % clusters) * 5,
@@ -60,9 +54,9 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
6054
val dataPositive = positive2D(n = 100, clusters = 3)
6155

6256
val distances = Seq(
63-
BregmanPointOps.EUCLIDEAN -> dataEuclidean,
57+
BregmanPointOps.EUCLIDEAN -> dataEuclidean,
6458
BregmanPointOps.RELATIVE_ENTROPY -> dataPositive, // KL on positive data
65-
BregmanPointOps.GENERALIZED_I -> dataPositive // Generalized I on positive data
59+
BregmanPointOps.GENERALIZED_I -> dataPositive // Generalized I on positive data
6660
)
6761

6862
distances.foreach { case (distanceFunction, data) =>
@@ -74,19 +68,14 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
7468
)
7569

7670
assert(model.centers.length <= 3, s"Too many centers for $distanceFunction")
77-
7871
val cost = model.computeCost(data)
7972
assert(java.lang.Double.isFinite(cost), s"Non-finite cost for $distanceFunction")
8073
assert(cost >= 0.0, s"Negative cost for $distanceFunction: $cost")
8174

8275
val predictions = model.predict(data).collect()
83-
assert(
84-
predictions.forall(p => p >= 0 && p < model.centers.length),
85-
s"Invalid prediction for $distanceFunction"
86-
)
76+
assert(predictions.forall(p => p >= 0 && p < model.centers.length), s"Invalid prediction for $distanceFunction")
8777

88-
val clusterCounts =
89-
predictions.groupBy(identity).map { case (k, arr) => k -> arr.length }
78+
val clusterCounts = predictions.groupBy(identity).map { case (k, arr) => k -> arr.length }
9079
assert(clusterCounts.size <= 3, s"Too many clusters observed for $distanceFunction")
9180
assert(clusterCounts.values.forall(_ > 0), s"Empty cluster detected for $distanceFunction")
9281
}
@@ -95,7 +84,7 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
9584
test("embedding pipeline integration with time series") {
9685
val timeSeries = sc.parallelize((0 until 50).map { i =>
9786
val pattern = i % 3
98-
val values = (0 until 16).map { t =>
87+
val values = (0 until 16).map { t =>
9988
pattern match {
10089
case 0 => math.sin(t * 0.5) + seedRng.nextGaussian() * 0.1
10190
case 1 => math.cos(t * 0.3) + seedRng.nextGaussian() * 0.1
@@ -159,7 +148,8 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
159148
val cost = model.computeCostWeighted(data)
160149
assert(cost >= 0.0 && java.lang.Double.isFinite(cost))
161150
} catch {
162-
case e: IllegalArgumentException if e.getMessage.contains("requires at least one valid center") =>
151+
case e: IllegalArgumentException
152+
if e.getMessage.contains("requires at least one valid center") =>
163153
succeed
164154
case e: IllegalArgumentException if e.getMessage.contains("requirement failed") =>
165155
succeed
@@ -212,10 +202,7 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
212202
val minCost = costs.min
213203
val maxCost = costs.max
214204
val eps = 1e-9
215-
assert(
216-
(maxCost + eps) / (minCost + eps) <= 100.0,
217-
"Clustering implementations produce very different costs"
218-
)
205+
assert((maxCost + eps) / (minCost + eps) <= 100.0, "Clustering implementations produce very different costs")
219206
}
220207

221208
test("large dataset performance test (log only, no timing assert)") {
@@ -248,30 +235,23 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
248235
info(s"KMeans training duration: ${duration}ms")
249236

250237
assert(model.centers.length <= numClusters)
251-
252238
val predictions = model.predict(data).collect()
253239
assert(predictions.forall(p => p >= 0 && p < model.centers.length))
254240

255241
val trueLabels = (0 until numPoints).map(_ % numClusters).toArray
256242

257-
val clusterToTrueLabel =
258-
predictions.zipWithIndex
259-
.groupBy(_._1)
260-
.map { case (pred, pairs) =>
261-
val labelCounts =
262-
pairs
263-
.map { case (_, idx) => trueLabels(idx) }
264-
.groupBy(identity)
265-
.map { case (k, v) => k -> v.length }
266-
val majority = labelCounts.maxBy(_._2)._1
267-
pred -> majority
268-
}
269-
270-
val correctAssignments =
271-
predictions.zipWithIndex.count { case (prediction, index) =>
272-
clusterToTrueLabel.get(prediction).contains(trueLabels(index))
243+
val clusterToTrueLabel = predictions.zipWithIndex
244+
.groupBy(_._1)
245+
.map { case (pred, pairs) =>
246+
val labelCounts = pairs.map { case (_, idx) => trueLabels(idx) }.groupBy(identity).map { case (k, v) => k -> v.length }
247+
val majority = labelCounts.maxBy(_._2)._1
248+
pred -> majority
273249
}
274-
val accuracy = correctAssignments.toDouble / numPoints
250+
251+
val correctAssignments = predictions.zipWithIndex.count { case (prediction, index) =>
252+
clusterToTrueLabel.get(prediction).contains(trueLabels(index))
253+
}
254+
val accuracy = correctAssignments.toDouble / numPoints
275255
assert(accuracy > 0.5, s"Poor clustering accuracy: $accuracy")
276256
}
277257

@@ -375,15 +355,14 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
375355
val normalData = sc.parallelize(normalPoints)
376356
val normalPredictions = model.predict(normalData).collect()
377357

378-
val clusterCounts =
379-
normalPredictions.groupBy(identity).map { case (k, arr) => k -> arr.length }
358+
val clusterCounts = normalPredictions.groupBy(identity).map { case (k, arr) => k -> arr.length }
380359
assert(clusterCounts.size <= 3)
381360
// Loosened threshold to reduce flakiness while still catching degenerate solutions
382361
assert(clusterCounts.values.forall(_ >= 3))
383362
}
384363

385364
test("reproducibility with fixed seeds (seeded data + algorithm)") {
386-
val rng = new scala.util.Random(1234L)
365+
val rng = new scala.util.Random(1234L)
387366
val data = sc.parallelize((0 until 50).map { _ =>
388367
WeightedVector(
389368
Vectors.dense(
@@ -418,11 +397,8 @@ class IntegrationTestSuite extends AnyFunSuite with LocalClusterSparkContext {
418397
val cost1 = model1.computeCostWeighted(data)
419398
val cost2 = model2.computeCostWeighted(data)
420399

421-
val tol = 0.05 // 5% to absorb non-determinism in reduction order/FP
400+
val tol = 0.05 // 5% to absorb non-determinism in reduction order/FP
422401
val relDiff = math.abs(cost1 - cost2) / math.max(1e-9, math.abs(cost1))
423-
assert(
424-
relDiff <= tol,
425-
f"Results differ more than ${tol * 100}%.0f%% with same seed: $cost1%.6f vs $cost2%.6f"
426-
)
402+
// The s-interpolator does not process "%%" escapes (only f"..." does), so a single '%' is used here.
assert(relDiff <= tol, s"Results differ more than ${tol * 100}% with same seed: $cost1 vs $cost2")
427403
}
428404
}

0 commit comments

Comments
 (0)