@@ -29,25 +29,30 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest {
     val spark = this.spark
     import spark.implicits._
 
-    val datasetSize = 100000
+    val datasetSize = 30000
     val numBuckets = 5
-    val df = sc.parallelize(1 to datasetSize).map(_.toDouble).map(Tuple1.apply).toDF("input")
+    val df = sc.parallelize(1 to datasetSize).map(_.toDouble).toDF("input")
     val discretizer = new QuantileDiscretizer()
       .setInputCol("input")
       .setOutputCol("result")
       .setNumBuckets(numBuckets)
     val model = discretizer.fit(df)
 
-    testTransformerByGlobalCheckFunc[(Double)](df, model, "result") { rows =>
-      val result = rows.map { r => Tuple1(r.getDouble(0)) }.toDF("result")
-      val observedNumBuckets = result.select("result").distinct.count
-      assert(observedNumBuckets === numBuckets,
-        "Observed number of buckets does not equal expected number of buckets.")
-      val relativeError = discretizer.getRelativeError
-      val numGoodBuckets = result.groupBy("result").count
-        .filter(s"abs(count - ${datasetSize / numBuckets}) <= ${relativeError * datasetSize}").count
-      assert(numGoodBuckets === numBuckets,
-        "Bucket sizes are not within expected relative error tolerance.")
+    testTransformerByGlobalCheckFunc[Double](df, model, "result") { rows =>
+      val result = rows.map(_.getDouble(0)).toDF("result").cache()
+      try {
+        val observedNumBuckets = result.select("result").distinct().count()
+        assert(observedNumBuckets === numBuckets,
+          "Observed number of buckets does not equal expected number of buckets.")
+        val relativeError = discretizer.getRelativeError
+        val numGoodBuckets = result.groupBy("result").count()
+          .filter(s"abs(count - ${datasetSize / numBuckets}) <= ${relativeError * datasetSize}")
+          .count()
+        assert(numGoodBuckets === numBuckets,
+          "Bucket sizes are not within expected relative error tolerance.")
+      } finally {
+        result.unpersist()
+      }
     }
   }
 
@@ -162,10 +167,10 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest {
     val spark = this.spark
     import spark.implicits._
 
-    val datasetSize = 100000
+    val datasetSize = 30000
     val numBuckets = 5
-    val data1 = Array.range(1, 100001, 1).map(_.toDouble)
-    val data2 = Array.range(1, 200000, 2).map(_.toDouble)
+    val data1 = Array.range(1, datasetSize + 1, 1).map(_.toDouble)
+    val data2 = Array.range(1, 2 * datasetSize, 2).map(_.toDouble)
     val df = data1.zip(data2).toSeq.toDF("input1", "input2")
 
     val discretizer = new QuantileDiscretizer()
@@ -175,20 +180,24 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest {
     val model = discretizer.fit(df)
     testTransformerByGlobalCheckFunc[(Double, Double)](df, model, "result1", "result2") { rows =>
       val result =
-        rows.map { r => Tuple2(r.getDouble(0), r.getDouble(1)) }.toDF("result1", "result2")
-      val relativeError = discretizer.getRelativeError
-      for (i <- 1 to 2) {
-        val observedNumBuckets = result.select("result" + i).distinct.count
-        assert(observedNumBuckets === numBuckets,
-          "Observed number of buckets does not equal expected number of buckets.")
-
-        val numGoodBuckets = result
-          .groupBy("result" + i)
-          .count
-          .filter(s"abs(count - ${datasetSize / numBuckets}) <= ${relativeError * datasetSize}")
-          .count
-        assert(numGoodBuckets === numBuckets,
-          "Bucket sizes are not within expected relative error tolerance.")
+        rows.map(r => (r.getDouble(0), r.getDouble(1))).toDF("result1", "result2").cache()
+      try {
+        val relativeError = discretizer.getRelativeError
+        for (i <- 1 to 2) {
+          val observedNumBuckets = result.select("result" + i).distinct().count()
+          assert(observedNumBuckets === numBuckets,
+            "Observed number of buckets does not equal expected number of buckets.")
+
+          val numGoodBuckets = result
+            .groupBy("result" + i)
+            .count()
+            .filter(s"abs(count - ${datasetSize / numBuckets}) <= ${relativeError * datasetSize}")
+            .count()
+          assert(numGoodBuckets === numBuckets,
+            "Bucket sizes are not within expected relative error tolerance.")
+        }
+      } finally {
+        result.unpersist()
       }
     }
   }
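For context, here is a minimal standalone sketch of the pattern this change adopts: cache the collected result once, run the several actions that reuse it, and unpersist in a `finally` block. Everything outside the `QuantileDiscretizer` calls (session setup, object name, printed output) is illustrative and not taken from the suite:

```scala
// Hypothetical standalone example (not part of the patch): shows the
// cache-then-unpersist pattern this change introduces in the test suite.
import org.apache.spark.ml.feature.QuantileDiscretizer
import org.apache.spark.sql.SparkSession

object QuantileDiscretizerSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("quantile-discretizer-sketch")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // Same shape of input as the test: 30000 evenly spaced doubles.
    val df = (1 to 30000).map(_.toDouble).toDF("input")
    val discretizer = new QuantileDiscretizer()
      .setInputCol("input")
      .setOutputCol("result")
      .setNumBuckets(5)

    // Cache once: the two actions below would otherwise each recompute the plan.
    val result = discretizer.fit(df).transform(df).cache()
    try {
      // Approximate quantiles should yield (close to) 5 distinct bucket ids.
      println(result.select("result").distinct().count())
      result.groupBy("result").count().show()
    } finally {
      result.unpersist()
    }
    spark.stop()
  }
}
```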