@@ -20,7 +20,7 @@ package org.apache.spark
 import java.io.File
 import java.net.{MalformedURLException, URI}
 import java.nio.charset.StandardCharsets
-import java.util.concurrent.{Semaphore, TimeUnit}
+import java.util.concurrent.{CountDownLatch, Semaphore, TimeUnit}
 
 import scala.concurrent.duration._
 
@@ -498,45 +498,36 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventually
 
   test("Cancelling stages/jobs with custom reasons.") {
     sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
+    sc.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "true")
     val REASON = "You shall not pass"
-    val slices = 10
 
-    val listener = new SparkListener {
-      override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
-        if (SparkContextSuite.cancelStage) {
-          eventually(timeout(10.seconds)) {
-            assert(SparkContextSuite.isTaskStarted)
+    for (cancelWhat <- Seq("stage", "job")) {
+      // This countdown latch is used to make sure the stage or job gets cancelled in the listener
+      val latch = new CountDownLatch(1)
+
+      val listener = cancelWhat match {
+        case "stage" =>
+          new SparkListener {
+            override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
+              sc.cancelStage(taskStart.stageId, REASON)
+              latch.countDown()
+            }
           }
-          sc.cancelStage(taskStart.stageId, REASON)
-          SparkContextSuite.cancelStage = false
-          SparkContextSuite.semaphore.release(slices)
-        }
-      }
-
-      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
-        if (SparkContextSuite.cancelJob) {
-          eventually(timeout(10.seconds)) {
-            assert(SparkContextSuite.isTaskStarted)
+        case "job" =>
+          new SparkListener {
+            override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+              sc.cancelJob(jobStart.jobId, REASON)
+              latch.countDown()
+            }
           }
-          sc.cancelJob(jobStart.jobId, REASON)
-          SparkContextSuite.cancelJob = false
-          SparkContextSuite.semaphore.release(slices)
-        }
       }
-    }
-    sc.addSparkListener(listener)
-
-    for (cancelWhat <- Seq("stage", "job")) {
-      SparkContextSuite.semaphore.drainPermits()
-      SparkContextSuite.isTaskStarted = false
-      SparkContextSuite.cancelStage = (cancelWhat == "stage")
-      SparkContextSuite.cancelJob = (cancelWhat == "job")
+      sc.addSparkListener(listener)
 
       val ex = intercept[SparkException] {
-        sc.range(0, 10000L, numSlices = slices).mapPartitions { x =>
-          SparkContextSuite.isTaskStarted = true
-          // Block waiting for the listener to cancel the stage or job.
-          SparkContextSuite.semaphore.acquire()
+        sc.range(0, 10000L, numSlices = 10).mapPartitions { x =>
+          x.synchronized {
+            x.wait()
+          }
           x
         }.count()
       }
@@ -550,9 +541,11 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventually
           fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.")
       }
 
+      latch.await(20, TimeUnit.SECONDS)
       eventually(timeout(20.seconds)) {
         assert(sc.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0)
       }
+      sc.removeSparkListener(listener)
     }
   }
 
@@ -637,8 +630,6 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventually
 }
 
 object SparkContextSuite {
-  @volatile var cancelJob = false
-  @volatile var cancelStage = false
   @volatile var isTaskStarted = false
   @volatile var taskKilled = false
   @volatile var taskSucceeded = false
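
Reading note: the rewritten test drops the shared @volatile flags and the Semaphore handshake and instead creates one CountDownLatch per loop iteration. The listener issues cancelStage/cancelJob with the custom reason and counts the latch down, and the test thread blocks on latch.await before asserting that no tasks are still running and removing the listener. Below is a minimal, self-contained sketch of that coordination pattern in plain Scala (no SparkContext; the thread and object names are illustrative, not code from this PR):

// Sketch only: mirrors the latch-based handshake used in the new test.
// A plain thread stands in for the SparkListener callback; requires Scala 2.12+
// for the Runnable lambda conversion.
import java.util.concurrent.{CountDownLatch, TimeUnit}

object CancelLatchSketch {
  def main(args: Array[String]): Unit = {
    val latch = new CountDownLatch(1)

    // Stand-in for onTaskStart/onJobStart: issue the cancellation with a
    // custom reason, then release the waiting test thread.
    val listenerThread = new Thread(() => {
      println("listener: cancelling with reason 'You shall not pass'")
      latch.countDown()
    })
    listenerThread.start()

    // The test thread waits (bounded) until the cancellation has been issued,
    // instead of polling shared @volatile flags guarded by a Semaphore.
    assert(latch.await(20, TimeUnit.SECONDS), "listener never cancelled")
    println("main: cancellation observed, safe to assert no running tasks")
  }
}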