 package org.apache.spark.sql
 
-import java.util.concurrent.{CountDownLatch, TimeUnit}
-
 import scala.concurrent.duration._
 import scala.math.abs
 import scala.util.Random
 
 import org.scalatest.concurrent.Eventually
 
-import org.apache.spark.{SparkContext, SparkException}
-import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart}
+import org.apache.spark.{SparkException, TaskContext}
+import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
@@ -154,53 +152,39 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventually {
   }
 
   test("Cancelling stage in a query with Range.") {
-    // Save and restore the value because SparkContext is shared
-    val savedInterruptOnCancel = sparkContext
-      .getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL)
-
-    try {
-      sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "true")
-
-      for (codegen <- Seq(true, false)) {
-        // This countdown latch used to make sure with all the stages cancelStage called in listener
-        val latch = new CountDownLatch(2)
-
-        val listener = new SparkListener {
-          override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
-            sparkContext.cancelStage(taskStart.stageId)
-            latch.countDown()
-          }
+    val listener = new SparkListener {
+      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+        eventually(timeout(10.seconds), interval(1.millis)) {
+          assert(DataFrameRangeSuite.stageToKill > 0)
         }
+        sparkContext.cancelStage(DataFrameRangeSuite.stageToKill)
+      }
+    }
 
-        sparkContext.addSparkListener(listener)
-        withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) {
-          val ex = intercept[SparkException] {
-            sparkContext.range(0, 10000L, numSlices = 10).mapPartitions { x =>
-              x.synchronized {
-                x.wait()
-              }
-              x
-            }.toDF("id").agg(sum("id")).collect()
-          }
-          ex.getCause() match {
-            case null =>
-              assert(ex.getMessage().contains("cancelled"))
-            case cause: SparkException =>
-              assert(cause.getMessage().contains("cancelled"))
-            case cause: Throwable =>
-              fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.")
-          }
+    sparkContext.addSparkListener(listener)
+    for (codegen <- Seq(true, false)) {
+      withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) {
+        DataFrameRangeSuite.stageToKill = -1
+        val ex = intercept[SparkException] {
+          spark.range(0, 100000000000L, 1, 1).map { x =>
+            DataFrameRangeSuite.stageToKill = TaskContext.get().stageId()
+            x
+          }.toDF("id").agg(sum("id")).collect()
         }
-        latch.await(20, TimeUnit.SECONDS)
-        eventually(timeout(20.seconds)) {
-          assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0)
+        ex.getCause() match {
+          case null =>
+            assert(ex.getMessage().contains("cancelled"))
+          case cause: SparkException =>
+            assert(cause.getMessage().contains("cancelled"))
+          case cause: Throwable =>
+            fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.")
         }
-        sparkContext.removeSparkListener(listener)
       }
-    } finally {
-      sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL,
-        savedInterruptOnCancel)
+      eventually(timeout(20.seconds)) {
+        assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0)
+      }
     }
+    sparkContext.removeSparkListener(listener)
   }
 
   test("SPARK-20430 Initialize Range parameters in a driver side") {
@@ -220,3 +204,7 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventually {
     }
   }
 }
+
+object DataFrameRangeSuite {
+  @volatile var stageToKill = -1
+}
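
The patch replaces the CountDownLatch/onTaskStart handshake with a @volatile field: the running task publishes its stage id, and the job-start listener polls (via Eventually) until the id appears before calling cancelStage. Below is a minimal, Spark-free sketch of that handshake pattern; the object name and thread setup are illustrative, not part of the patch.

object VolatileHandshakeSketch {
  // Shared flag, written by the "task" thread and read by the "driver" thread.
  // @volatile makes the write visible across threads without locking,
  // mirroring DataFrameRangeSuite.stageToKill in the patch.
  @volatile private var stageToKill: Int = -1

  def main(args: Array[String]): Unit = {
    // Stand-in for the Spark task: publish the stage id once the task runs
    // (the patch does this with TaskContext.get().stageId()).
    val task = new Thread(() => {
      Thread.sleep(100) // simulate scheduling delay before the task starts
      stageToKill = 42
    })
    task.start()

    // Stand-in for the listener's eventually(timeout(...), interval(...)) loop:
    // poll until the id is published or the deadline passes.
    val deadlineNanos = System.nanoTime() + 10L * 1000 * 1000 * 1000 // 10 seconds
    while (stageToKill <= 0 && System.nanoTime() < deadlineNanos) {
      Thread.sleep(1) // analogous to interval(1.millis)
    }
    assert(stageToKill > 0, "stage id was never published")
    println(s"would call sparkContext.cancelStage($stageToKill)")
    task.join()
  }
}

Compared with the old latch, this cancels exactly the stage id the task observed instead of counting down a hard-coded number of stages, and since the new task body no longer blocks in wait(), the SPARK_JOB_INTERRUPT_ON_CANCEL save/restore presumably becomes unnecessary, which is why the diff drops it.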