Skip to content

Commit 67bd124

Browse files
committed
[MINOR][TEST] Speed up slow tests in QuantileDiscretizerSuite
## What changes were proposed in this pull request?

This should reduce the total runtime of these tests from about 2 minutes to about 25 seconds.

## How was this patch tested?

Existing tests.

Closes apache#24360 from srowen/SpeedQDS.

Authored-by: Sean Owen <[email protected]>
Signed-off-by: Sean Owen <[email protected]>
1 parent 38fc8e2 commit 67bd124

File tree

1 file changed

+38
-29
lines changed

1 file changed

+38
-29
lines changed

mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala

Lines changed: 38 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -29,25 +29,30 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest {
2929
val spark = this.spark
3030
import spark.implicits._
3131

32-
val datasetSize = 100000
32+
val datasetSize = 30000
3333
val numBuckets = 5
34-
val df = sc.parallelize(1 to datasetSize).map(_.toDouble).map(Tuple1.apply).toDF("input")
34+
val df = sc.parallelize(1 to datasetSize).map(_.toDouble).toDF("input")
3535
val discretizer = new QuantileDiscretizer()
3636
.setInputCol("input")
3737
.setOutputCol("result")
3838
.setNumBuckets(numBuckets)
3939
val model = discretizer.fit(df)
4040

41-
testTransformerByGlobalCheckFunc[(Double)](df, model, "result") { rows =>
42-
val result = rows.map { r => Tuple1(r.getDouble(0)) }.toDF("result")
43-
val observedNumBuckets = result.select("result").distinct.count
44-
assert(observedNumBuckets === numBuckets,
45-
"Observed number of buckets does not equal expected number of buckets.")
46-
val relativeError = discretizer.getRelativeError
47-
val numGoodBuckets = result.groupBy("result").count
48-
.filter(s"abs(count - ${datasetSize / numBuckets}) <= ${relativeError * datasetSize}").count
49-
assert(numGoodBuckets === numBuckets,
50-
"Bucket sizes are not within expected relative error tolerance.")
41+
testTransformerByGlobalCheckFunc[Double](df, model, "result") { rows =>
42+
val result = rows.map(_.getDouble(0)).toDF("result").cache()
43+
try {
44+
val observedNumBuckets = result.select("result").distinct().count()
45+
assert(observedNumBuckets === numBuckets,
46+
"Observed number of buckets does not equal expected number of buckets.")
47+
val relativeError = discretizer.getRelativeError
48+
val numGoodBuckets = result.groupBy("result").count()
49+
.filter(s"abs(count - ${datasetSize / numBuckets}) <= ${relativeError * datasetSize}")
50+
.count()
51+
assert(numGoodBuckets === numBuckets,
52+
"Bucket sizes are not within expected relative error tolerance.")
53+
} finally {
54+
result.unpersist()
55+
}
5156
}
5257
}
5358

@@ -162,10 +167,10 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest {
162167
val spark = this.spark
163168
import spark.implicits._
164169

165-
val datasetSize = 100000
170+
val datasetSize = 30000
166171
val numBuckets = 5
167-
val data1 = Array.range(1, 100001, 1).map(_.toDouble)
168-
val data2 = Array.range(1, 200000, 2).map(_.toDouble)
172+
val data1 = Array.range(1, datasetSize + 1, 1).map(_.toDouble)
173+
val data2 = Array.range(1, 2 * datasetSize, 2).map(_.toDouble)
169174
val df = data1.zip(data2).toSeq.toDF("input1", "input2")
170175

171176
val discretizer = new QuantileDiscretizer()
@@ -175,20 +180,24 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest {
175180
val model = discretizer.fit(df)
176181
testTransformerByGlobalCheckFunc[(Double, Double)](df, model, "result1", "result2") { rows =>
177182
val result =
178-
rows.map { r => Tuple2(r.getDouble(0), r.getDouble(1)) }.toDF("result1", "result2")
179-
val relativeError = discretizer.getRelativeError
180-
for (i <- 1 to 2) {
181-
val observedNumBuckets = result.select("result" + i).distinct.count
182-
assert(observedNumBuckets === numBuckets,
183-
"Observed number of buckets does not equal expected number of buckets.")
184-
185-
val numGoodBuckets = result
186-
.groupBy("result" + i)
187-
.count
188-
.filter(s"abs(count - ${datasetSize / numBuckets}) <= ${relativeError * datasetSize}")
189-
.count
190-
assert(numGoodBuckets === numBuckets,
191-
"Bucket sizes are not within expected relative error tolerance.")
183+
rows.map(r => (r.getDouble(0), r.getDouble(1))).toDF("result1", "result2").cache()
184+
try {
185+
val relativeError = discretizer.getRelativeError
186+
for (i <- 1 to 2) {
187+
val observedNumBuckets = result.select("result" + i).distinct().count()
188+
assert(observedNumBuckets === numBuckets,
189+
"Observed number of buckets does not equal expected number of buckets.")
190+
191+
val numGoodBuckets = result
192+
.groupBy("result" + i)
193+
.count()
194+
.filter(s"abs(count - ${datasetSize / numBuckets}) <= ${relativeError * datasetSize}")
195+
.count()
196+
assert(numGoodBuckets === numBuckets,
197+
"Bucket sizes are not within expected relative error tolerance.")
198+
}
199+
} finally {
200+
result.unpersist()
192201
}
193202
}
194203
}

0 commit comments

Comments (0)