Commit 52ddf03

Authored by Ievgen Prokhorenko, committed by Dongjoon Hyun
[SPARK-28440][MLLIB][TEST] Use TestingUtils to compare floating point values
## What changes were proposed in this pull request?

Use the `org.apache.spark.mllib.util.TestingUtils` object across the MLLIB component to compare floating point values in tests.

## How was this patch tested?

`build/mvn test` - existing tests run against the updated code.

Closes apache#25191 from eugen-prokhorenko/mllib-testingutils-double-comparison.

Authored-by: Ievgen Prokhorenko <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent: 127bc89 · commit: 52ddf03

File tree: 9 files changed, +52 −46 lines
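Every file below shows the same mechanical change: a hand-rolled `math.abs(actual - expected) < eps` check becomes `actual ~== expected absTol eps` (or the Boolean-returning `~=` inside predicates). As a rough idea of how that syntax hangs together, here is a minimal sketch — a simplified stand-in, not the Spark source; the real `org.apache.spark.mllib.util.TestingUtils` also offers `relTol` plus vector and matrix variants, and throws ScalaTest's `TestFailedException`:

```scala
// Minimal sketch of TestingUtils-style tolerance operators (illustrative only).
object ApproxDoubles {

  // Right-hand side of a comparison: expected value plus tolerance,
  // built by `expected absTol eps`.
  case class CompareRightSide(expected: Double, eps: Double)

  implicit class DoubleWithAlmostEquals(val value: Double) extends AnyVal {

    def absTol(eps: Double): CompareRightSide = CompareRightSide(value, eps)

    // Boolean form, handy inside count()/filter(). Because the operator ends
    // in '=', Scala treats it as an assignment operator with the lowest
    // precedence, so `a ~= b absTol e` parses as `a ~= (b absTol e)`
    // without extra parentheses. The same applies to `~==`.
    def ~=(r: CompareRightSide): Boolean =
      math.abs(value - r.expected) < r.eps

    // Throwing form: a failed assert reports both values and the tolerance
    // instead of a bare "assertion failed".
    def ~==(r: CompareRightSide): Boolean = {
      if (math.abs(value - r.expected) >= r.eps) {
        throw new AssertionError(
          s"Expected $value and ${r.expected} to be within ${r.eps} (absolute tolerance).")
      }
      true
    }
  }
}
```

With that in scope, `assert(0.1 + 0.2 ~== 0.3 absTol 1e-9)` passes, and a failing comparison explains itself. The commit sticks to `absTol` throughout because it mirrors the replaced `math.abs(a - b) < eps` checks one-for-one; the `relTol` variant would change the semantics to a scale-relative comparison.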

mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala
19 additions, 18 deletions

@@ -19,6 +19,7 @@ package org.apache.spark.mllib.evaluation
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.rdd.RDD
 
 class MultilabelMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -79,24 +80,24 @@ class MultilabelMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
     val hammingLoss = (1.0 / (7 * 3)) * (2 + 2 + 1 + 0 + 0 + 1 + 1)
     val strictAccuracy = 2.0 / 7
     val accuracy = 1.0 / 7 * (1.0 / 3 + 1.0 /3 + 0 + 1.0 / 1 + 2.0 / 2 + 2.0 / 3 + 1.0 / 2)
-    assert(math.abs(metrics.precision(0.0) - precision0) < delta)
-    assert(math.abs(metrics.precision(1.0) - precision1) < delta)
-    assert(math.abs(metrics.precision(2.0) - precision2) < delta)
-    assert(math.abs(metrics.recall(0.0) - recall0) < delta)
-    assert(math.abs(metrics.recall(1.0) - recall1) < delta)
-    assert(math.abs(metrics.recall(2.0) - recall2) < delta)
-    assert(math.abs(metrics.f1Measure(0.0) - f1measure0) < delta)
-    assert(math.abs(metrics.f1Measure(1.0) - f1measure1) < delta)
-    assert(math.abs(metrics.f1Measure(2.0) - f1measure2) < delta)
-    assert(math.abs(metrics.microPrecision - microPrecisionClass) < delta)
-    assert(math.abs(metrics.microRecall - microRecallClass) < delta)
-    assert(math.abs(metrics.microF1Measure - microF1MeasureClass) < delta)
-    assert(math.abs(metrics.precision - macroPrecisionDoc) < delta)
-    assert(math.abs(metrics.recall - macroRecallDoc) < delta)
-    assert(math.abs(metrics.f1Measure - macroF1MeasureDoc) < delta)
-    assert(math.abs(metrics.hammingLoss - hammingLoss) < delta)
-    assert(math.abs(metrics.subsetAccuracy - strictAccuracy) < delta)
-    assert(math.abs(metrics.accuracy - accuracy) < delta)
+    assert(metrics.precision(0.0) ~== precision0 absTol delta)
+    assert(metrics.precision(1.0) ~== precision1 absTol delta)
+    assert(metrics.precision(2.0) ~== precision2 absTol delta)
+    assert(metrics.recall(0.0) ~== recall0 absTol delta)
+    assert(metrics.recall(1.0) ~== recall1 absTol delta)
+    assert(metrics.recall(2.0) ~== recall2 absTol delta)
+    assert(metrics.f1Measure(0.0) ~== f1measure0 absTol delta)
+    assert(metrics.f1Measure(1.0) ~== f1measure1 absTol delta)
+    assert(metrics.f1Measure(2.0) ~== f1measure2 absTol delta)
+    assert(metrics.microPrecision ~== microPrecisionClass absTol delta)
+    assert(metrics.microRecall ~== microRecallClass absTol delta)
+    assert(metrics.microF1Measure ~== microF1MeasureClass absTol delta)
+    assert(metrics.precision ~== macroPrecisionDoc absTol delta)
+    assert(metrics.recall ~== macroRecallDoc absTol delta)
+    assert(metrics.f1Measure ~== macroF1MeasureDoc absTol delta)
+    assert(metrics.hammingLoss ~== hammingLoss absTol delta)
+    assert(metrics.subsetAccuracy ~== strictAccuracy absTol delta)
+    assert(metrics.accuracy ~== accuracy absTol delta)
     assert(metrics.labels.sameElements(Array(0.0, 1.0, 2.0)))
   }
 }

mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala
3 additions, 2 deletions

@@ -18,6 +18,7 @@ package org.apache.spark.mllib.fpm
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
 
 class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext {
 
@@ -63,7 +64,7 @@ class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext {
       [1] 23
      */
    assert(results1.size === 23)
-    assert(results1.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23)
+    assert(results1.count(rule => rule.confidence ~= 1.0D absTol 1e-6) == 23)
 
    val results2 = ar
      .setMinConfidence(0)
@@ -84,7 +85,7 @@ class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext {
      [1] 23
      */
    assert(results2.size === 30)
-    assert(results2.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23)
+    assert(results2.count(rule => rule.confidence ~= 1.0D absTol 1e-6) == 23)
   }
 }
 
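Note the single-equals `~=` in the two hunks above, unlike the `~==` used elsewhere in this commit: inside a `count` predicate the comparison must yield `false` on a mismatch rather than abort the test. A small contrast, assuming the semantics sketched earlier:

```scala
import org.apache.spark.mllib.util.TestingUtils._

val confidences = Seq(1.0, 0.9999999, 0.5)

// Boolean form: values outside the tolerance are simply not counted.
assert(confidences.count(c => c ~= 1.0 absTol 1e-6) == 2)

// Throwing form: use it when the comparison itself is the assertion,
// so a mismatch reports both values instead of a bare "assertion failed".
assert(confidences.head ~== 1.0 absTol 1e-6)
```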

mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
2 additions, 1 deletion

@@ -18,6 +18,7 @@ package org.apache.spark.mllib.fpm
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.util.Utils
 
 class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -172,7 +173,7 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
       .collect()
 
     assert(rules.size === 23)
-    assert(rules.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23)
+    assert(rules.count(rule => rule.confidence ~= 1.0D absTol 1e-6) == 23)
   }
 
   test("FP-Growth using Int type") {

mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
2 additions, 1 deletion

@@ -22,6 +22,7 @@ import breeze.linalg.{diag => brzDiag, DenseMatrix => BDM, DenseVector => BDV}
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.rdd.RDD
 
 class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -238,7 +239,7 @@ class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
 
     for (i <- 0 until n; j <- i + 1 until n) {
       val trueResult = gram(i, j) / scala.math.sqrt(gram(i, i) * gram(j, j))
-      assert(math.abs(G(i, j) - trueResult) < 1e-6)
+      assert(G(i, j) ~== trueResult absTol 1e-6)
     }
   }
 

mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala
3 additions, 3 deletions

@@ -20,9 +20,9 @@ package org.apache.spark.mllib.random
 import org.apache.commons.math3.special.Gamma
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.util.StatCounter
 
-// TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged
 class RandomDataGeneratorSuite extends SparkFunSuite {
 
   def apiChecks(gen: RandomDataGenerator[Double]) {
@@ -61,8 +61,8 @@ class RandomDataGeneratorSuite extends SparkFunSuite {
     gen.setSeed(seed.toLong)
     val sample = (0 until 100000).map { _ => gen.nextValue()}
     val stats = new StatCounter(sample)
-    assert(math.abs(stats.mean - mean) < epsilon)
-    assert(math.abs(stats.stdev - stddev) < epsilon)
+    assert(stats.mean ~== mean absTol epsilon)
+    assert(stats.stdev ~== stddev absTol epsilon)
   }
 }
 

mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala
5 additions, 6 deletions

@@ -23,14 +23,13 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.rdd.{RandomRDD, RandomRDDPartition}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.StatCounter
 
 /*
  * Note: avoid including APIs that do not set the seed for the RNG in unit tests
  * in order to guarantee deterministic behavior.
- *
- * TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged
  */
 class RandomRDDsSuite extends SparkFunSuite with MLlibTestSparkContext with Serializable {
 
@@ -43,8 +42,8 @@ class RandomRDDsSuite extends SparkFunSuite with MLlibTestSparkContext with Seri
     val stats = rdd.stats()
     assert(expectedSize === stats.count)
     assert(expectedNumPartitions === rdd.partitions.size)
-    assert(math.abs(stats.mean - expectedMean) < epsilon)
-    assert(math.abs(stats.stdev - expectedStddev) < epsilon)
+    assert(stats.mean ~== expectedMean absTol epsilon)
+    assert(stats.stdev ~== expectedStddev absTol epsilon)
   }
 
   // assume test RDDs are small
@@ -63,8 +62,8 @@ class RandomRDDsSuite extends SparkFunSuite with MLlibTestSparkContext with Seri
     }}
     assert(expectedRows === values.size / expectedColumns)
     val stats = new StatCounter(values)
-    assert(math.abs(stats.mean - expectedMean) < epsilon)
-    assert(math.abs(stats.stdev - expectedStddev) < epsilon)
+    assert(stats.mean ~== expectedMean absTol epsilon)
+    assert(stats.stdev ~== expectedStddev absTol epsilon)
   }
 
   test("RandomRDD sizes") {

mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala
8 additions, 7 deletions

@@ -26,6 +26,7 @@ import org.apache.spark.mllib.random.RandomRDDs
 import org.apache.spark.mllib.stat.correlation.{Correlations, PearsonCorrelation,
   SpearmanCorrelation}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
 
 class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
 
@@ -57,15 +58,15 @@ class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext with Log
     val expected = 0.6546537
     val default = Statistics.corr(x, y)
     val p1 = Statistics.corr(x, y, "pearson")
-    assert(approxEqual(expected, default))
-    assert(approxEqual(expected, p1))
+    assert(expected ~== default absTol 1e-6)
+    assert(expected ~== p1 absTol 1e-6)
 
     // numPartitions >= size for input RDDs
     for (numParts <- List(xData.size, xData.size * 2)) {
       val x1 = sc.parallelize(xData, numParts)
       val y1 = sc.parallelize(yData, numParts)
       val p2 = Statistics.corr(x1, y1)
-      assert(approxEqual(expected, p2))
+      assert(expected ~== p2 absTol 1e-6)
     }
 
     // RDD of zero variance
@@ -78,14 +79,14 @@ class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext with Log
     val y = sc.parallelize(yData)
     val expected = 0.5
     val s1 = Statistics.corr(x, y, "spearman")
-    assert(approxEqual(expected, s1))
+    assert(expected ~== s1 absTol 1e-6)
 
     // numPartitions >= size for input RDDs
     for (numParts <- List(xData.size, xData.size * 2)) {
       val x1 = sc.parallelize(xData, numParts)
       val y1 = sc.parallelize(yData, numParts)
       val s2 = Statistics.corr(x1, y1, "spearman")
-      assert(approxEqual(expected, s2))
+      assert(expected ~== s2 absTol 1e-6)
     }
 
     // RDD of zero variance => zero variance in ranks
@@ -141,14 +142,14 @@ class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext with Log
     val a = RandomRDDs.normalRDD(sc, 100000, 10).map(_ + 1000000000.0)
     val b = RandomRDDs.normalRDD(sc, 100000, 10).map(_ + 1000000000.0)
     val p = Statistics.corr(a, b, method = "pearson")
-    assert(approxEqual(p, 0.0, 0.01))
+    assert(p ~== 0.0 absTol 0.01)
   }
 
   def approxEqual(v1: Double, v2: Double, threshold: Double = 1e-6): Boolean = {
     if (v1.isNaN) {
      v2.isNaN
    } else {
-      math.abs(v1 - v2) <= threshold
+      v1 ~== v2 absTol threshold
    }
  }
 
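A subtlety in the final `approxEqual` hunk: assuming the usual TestingUtils contract, `~==` throws a descriptive exception on a mismatch rather than returning `false`, so the rewritten helper now fails inside the comparison, with both values and the tolerance in the message, instead of handing `false` back to its enclosing `assert`. The explicit `isNaN` branch survives because a tolerance comparison involving NaN can never hold, so NaN-equals-NaN still needs its own check.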

mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala
7 additions, 6 deletions

@@ -21,6 +21,7 @@ import org.apache.commons.math3.distribution.NormalDistribution
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
 
 class KernelDensitySuite extends SparkFunSuite with MLlibTestSparkContext {
   test("kernel density single sample") {
@@ -29,8 +30,8 @@ class KernelDensitySuite extends SparkFunSuite with MLlibTestSparkContext {
     val densities = new KernelDensity().setSample(rdd).setBandwidth(3.0).estimate(evaluationPoints)
     val normal = new NormalDistribution(5.0, 3.0)
     val acceptableErr = 1e-6
-    assert(math.abs(densities(0) - normal.density(5.0)) < acceptableErr)
-    assert(math.abs(densities(1) - normal.density(6.0)) < acceptableErr)
+    assert(densities(0) ~== normal.density(5.0) absTol acceptableErr)
+    assert(densities(1) ~== normal.density(6.0) absTol acceptableErr)
   }
 
   test("kernel density multiple samples") {
@@ -40,9 +41,9 @@ class KernelDensitySuite extends SparkFunSuite with MLlibTestSparkContext {
     val normal1 = new NormalDistribution(5.0, 3.0)
     val normal2 = new NormalDistribution(10.0, 3.0)
     val acceptableErr = 1e-6
-    assert(math.abs(
-      densities(0) - (normal1.density(5.0) + normal2.density(5.0)) / 2) < acceptableErr)
-    assert(math.abs(
-      densities(1) - (normal1.density(6.0) + normal2.density(6.0)) / 2) < acceptableErr)
+    assert(
+      densities(0) ~== ((normal1.density(5.0) + normal2.density(5.0)) / 2) absTol acceptableErr)
+    assert(
+      densities(1) ~== ((normal1.density(6.0) + normal2.density(6.0)) / 2) absTol acceptableErr)
   }
 }
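The expected values in these assertions come straight from the definition of a Gaussian kernel density estimate: evaluated at a point p, the estimate over samples x1..xn with bandwidth h is the average of the n normal densities N(xi, h) at p, which is why the multi-sample test averages two `density` calls. A sketch using the same commons-math3 class the suite already imports:

```scala
import org.apache.commons.math3.distribution.NormalDistribution

// Gaussian KDE at point p: average of Normal(sample, bandwidth) densities.
def gaussianKde(samples: Seq[Double], bandwidth: Double)(p: Double): Double =
  samples.map(x => new NormalDistribution(x, bandwidth).density(p)).sum / samples.size

// With samples 5.0 and 10.0 and bandwidth 3.0, gaussianKde(Seq(5.0, 10.0), 3.0)(5.0)
// reproduces (normal1.density(5.0) + normal2.density(5.0)) / 2 from the test above.
```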

mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
3 additions, 2 deletions

@@ -22,6 +22,7 @@ import scala.collection.mutable
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.model.TreeEnsembleModel
+import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.util.StatCounter
 
 object EnsembleTestHelper {
@@ -43,8 +44,8 @@ object EnsembleTestHelper {
       values ++= row
     }
     val stats = new StatCounter(values)
-    assert(math.abs(stats.mean - expectedMean) < epsilon)
-    assert(math.abs(stats.stdev - expectedStddev) < epsilon)
+    assert(stats.mean ~== expectedMean absTol epsilon)
+    assert(stats.stdev ~== expectedStddev absTol epsilon)
   }
 
   def validateClassifier(
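Several of the touched helpers funnel sampled values through `StatCounter` before comparing: it accumulates count, mean, and (population) standard deviation in one pass. A quick illustration of the pattern, assuming the TestingUtils import:

```scala
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.StatCounter

val stats = new StatCounter(Seq(1.0, 2.0, 3.0))
// Population standard deviation of {1, 2, 3} is sqrt(2/3).
assert(stats.mean ~== 2.0 absTol 1e-12)
assert(stats.stdev ~== math.sqrt(2.0 / 3.0) absTol 1e-12)
```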
