Skip to content

Commit 92fe984

Browse files
rynorrisRobert Kruszewski
authored andcommitted
Track active shuffles by stage (apache-spark-on-k8s#446)
1 parent 3f435fd commit 92fe984

File tree

3 files changed

+142
-2
lines changed

3 files changed

+142
-2
lines changed

core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -677,6 +677,7 @@ private[spark] class ExecutorAllocationManager(
677677
override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
678678
// At the end of a job, trigger the callbacks for idle executors again to clean up executors
679679
// which we were keeping around only because they held active shuffle blocks.
680+
logDebug("Checking for idle executors at end of job")
680681
allocationManager.checkForIdleExecutors()
681682
}
682683

@@ -727,6 +728,11 @@ private[spark] class ExecutorAllocationManager(
727728
if (stageIdToNumTasks.isEmpty && stageIdToNumSpeculativeTasks.isEmpty) {
728729
allocationManager.onSchedulerQueueEmpty()
729730
}
731+
732+
// Trigger the callbacks for idle executors again to clean up executors
733+
// which we were keeping around only because they held active shuffle blocks.
734+
logDebug("Checking for idle executors at end of stage")
735+
allocationManager.checkForIdleExecutors()
730736
}
731737
}
732738

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import java.util.function.BiFunction
2525

2626
import scala.annotation.tailrec
2727
import scala.collection.Map
28-
import scala.collection.mutable.{ArrayStack, HashMap, HashSet}
28+
import scala.collection.mutable.{ArrayStack, HashMap, HashSet, Set}
2929
import scala.concurrent.duration._
3030
import scala.language.existentials
3131
import scala.language.postfixOps
@@ -149,6 +149,13 @@ private[spark] class DAGScheduler(
149149
* the shuffle data will be in the MapOutputTracker).
150150
*/
151151
private[scheduler] val shuffleIdToMapStage = new HashMap[Int, ShuffleMapStage]
152+
153+
/**
154+
* Mapping from shuffle dependency ID to the IDs of the stages which depend on the shuffle data.
155+
* Used to track when shuffle data becomes no longer active.
156+
*/
157+
private[scheduler] val shuffleIdToDependentStages = new HashMap[Int, Set[Int]]
158+
152159
private[scheduler] val jobIdToActiveJob = new HashMap[Int, ActiveJob]
153160

154161
// Stages we need to run whose parents aren't done
@@ -396,6 +403,7 @@ private[spark] class DAGScheduler(
396403
stageIdToStage(id) = stage
397404
shuffleIdToMapStage(shuffleDep.shuffleId) = stage
398405
updateJobIdStageIdMaps(jobId, stage)
406+
updateShuffleDependenciesMap(stage)
399407

400408
if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
401409
mapOutputTracker.markShuffleActive(shuffleDep.shuffleId)
@@ -455,6 +463,7 @@ private[spark] class DAGScheduler(
455463
val stage = new ResultStage(id, rdd, func, partitions, parents, jobId, callSite)
456464
stageIdToStage(id) = stage
457465
updateJobIdStageIdMaps(jobId, stage)
466+
updateShuffleDependenciesMap(stage)
458467
stage
459468
}
460469

@@ -601,6 +610,17 @@ private[spark] class DAGScheduler(
601610
updateJobIdStageIdMapsList(List(stage))
602611
}
603612

613+
/**
614+
* Registers the shuffle dependencies of the given stage.
615+
*/
616+
private def updateShuffleDependenciesMap(stage: Stage): Unit = {
617+
getShuffleDependencies(stage.rdd).foreach { shuffleDep =>
618+
val shuffleId = shuffleDep.shuffleId
619+
logDebug("Tracking that stage " + stage.id + " depends on shuffle " + shuffleId)
620+
shuffleIdToDependentStages.getOrElseUpdate(shuffleId, Set.empty) += stage.id
621+
}
622+
}
623+
604624
/**
605625
* Removes state for job and any stages that are not needed by any other job. Does not
606626
* handle cancelling tasks or notifying the SparkListener about finished jobs/stages/tasks.
@@ -1893,6 +1913,24 @@ private[spark] class DAGScheduler(
18931913
case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0)
18941914
case _ => "Unknown"
18951915
}
1916+
1917+
getShuffleDependencies(stage.rdd).foreach { shuffleDep =>
1918+
val shuffleId = shuffleDep.shuffleId
1919+
if (!shuffleIdToDependentStages.contains(shuffleId)) {
1920+
logDebug("Stage finished with untracked shuffle dependency " + shuffleId)
1921+
} else {
1922+
var dependentStages = shuffleIdToDependentStages(shuffleId)
1923+
dependentStages -= stage.id;
1924+
logDebug("Stage " + stage.id + " finished. " +
1925+
"Shuffle " + shuffleId + " now has dependencies " + dependentStages)
1926+
if (dependentStages.isEmpty) {
1927+
logDebug("Shuffle " + shuffleId + " is no longer needed. Marking it inactive.")
1928+
shuffleIdToDependentStages.remove(shuffleId)
1929+
mapOutputTracker.markShuffleInactive(shuffleId)
1930+
}
1931+
}
1932+
}
1933+
18961934
if (errorMessage.isEmpty) {
18971935
logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
18981936
stage.latestInfo.completionTime = Some(clock.getTimeMillis())

core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ import org.apache.spark.rdd.{DeterministicLevel, RDD}
3636
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
3737
import org.apache.spark.shuffle.{FetchFailedException, MetadataFetchFailedException}
3838
import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
39-
import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, CallSite, LongAccumulator, Utils}
39+
import org.apache.spark.util._
4040

4141
class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler)
4242
extends DAGSchedulerEventProcessLoop(dagScheduler) {
@@ -2195,6 +2195,102 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
21952195
assertDataStructuresEmpty()
21962196
}
21972197

2198+
test("stage level active shuffle tracking") {
2199+
// We will have 3 stages depending on each other.
2200+
// The second stage is composed of 2 RDDs to check we're tracking shuffle up the chain.
2201+
val shuffleMapRdd1 = new MyRDD(sc, 2, Nil)
2202+
val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(1))
2203+
val shuffleId1 = shuffleDep1.shuffleId
2204+
val shuffleMapRdd2 = new MyRDD(sc, 2, List(shuffleDep1), tracker = mapOutputTracker)
2205+
val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(1))
2206+
val shuffleId2 = shuffleDep2.shuffleId
2207+
val intermediateRdd = new MyRDD(sc, 1, List(shuffleDep2), tracker = mapOutputTracker)
2208+
val intermediateDep = new OneToOneDependency(intermediateRdd)
2209+
val reduceRdd = new MyRDD(sc, 1, List(intermediateDep), tracker = mapOutputTracker)
2210+
2211+
// Submit the job.
2212+
// Both shuffles should become active.
2213+
submit(reduceRdd, Array(0))
2214+
assert(mapOutputTracker.shuffleStatuses(shuffleId1).isActive === true)
2215+
assert(mapOutputTracker.shuffleStatuses(shuffleId2).isActive === true)
2216+
2217+
// Complete the first stage.
2218+
// Both shuffles remain active.
2219+
completeShuffleMapStageSuccessfully(0, 0, 2)
2220+
assert(mapOutputTracker.shuffleStatuses(shuffleId1).isActive === true)
2221+
assert(mapOutputTracker.shuffleStatuses(shuffleId2).isActive === true)
2222+
2223+
// Complete the second stage.
2224+
// Shuffle 1 is no longer needed and should become inactive.
2225+
completeShuffleMapStageSuccessfully(1, 0, 1)
2226+
assert(mapOutputTracker.shuffleStatuses(shuffleId1).isActive === false)
2227+
assert(mapOutputTracker.shuffleStatuses(shuffleId2).isActive === true)
2228+
2229+
// Complete the results stage.
2230+
// Both shuffles are no longer needed and should become inactive.
2231+
completeNextResultStageWithSuccess(2, 0)
2232+
assert(mapOutputTracker.shuffleStatuses(shuffleId1).isActive === false)
2233+
assert(mapOutputTracker.shuffleStatuses(shuffleId2).isActive === false)
2234+
2235+
// Double check results.
2236+
assert(results === Map(0 -> 42))
2237+
results.clear()
2238+
assertDataStructuresEmpty()
2239+
}
2240+
2241+
test("stage level active shuffle tracking with multiple dependents") {
2242+
// We will have a diamond shape dependency.
2243+
val shuffleMapRdd = new MyRDD(sc, 2, Nil)
2244+
val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1))
2245+
val shuffleId = shuffleDep.shuffleId
2246+
val intermediateRdd1 = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker)
2247+
val intermediateRdd2 = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker)
2248+
val intermediateDep1 = new ShuffleDependency(intermediateRdd1, new HashPartitioner(1))
2249+
val intermediateDep2 = new ShuffleDependency(intermediateRdd2, new HashPartitioner(1))
2250+
val reduceRdd =
2251+
new MyRDD(sc, 1, List(intermediateDep1, intermediateDep2), tracker = mapOutputTracker)
2252+
2253+
// Submit the job.
2254+
// Shuffle becomes active.
2255+
submit(reduceRdd, Array(0))
2256+
assert(mapOutputTracker.shuffleStatuses(shuffleId).isActive === true)
2257+
2258+
// Complete the shuffle stage.
2259+
// Shuffle remains active.
2260+
completeShuffleMapStageSuccessfully(0, 0, 2)
2261+
assert(mapOutputTracker.shuffleStatuses(shuffleId).isActive === true)
2262+
2263+
// Complete first intermediate stage.
2264+
// Shuffle is still active.
2265+
val stageAttempt = taskSets(1)
2266+
checkStageId(1, 0, stageAttempt)
2267+
complete(stageAttempt, stageAttempt.tasks.zipWithIndex.map {
2268+
case (task, idx) =>
2269+
(Success, makeMapStatus("host" + ('A' + idx).toChar, 1))
2270+
}.toSeq)
2271+
assert(mapOutputTracker.shuffleStatuses(shuffleId).isActive === true)
2272+
2273+
// Complete second intermediate stage.
2274+
// Shuffle is no longer active.
2275+
val stageAttempt2 = taskSets(2)
2276+
checkStageId(2, 0, stageAttempt2)
2277+
complete(stageAttempt2, stageAttempt2.tasks.zipWithIndex.map {
2278+
case (task, idx) =>
2279+
(Success, makeMapStatus("host" + ('A' + idx).toChar, 1))
2280+
}.toSeq)
2281+
assert(mapOutputTracker.shuffleStatuses(shuffleId).isActive === false)
2282+
2283+
// Complete the results stage.
2284+
// Shuffle is still inactive.
2285+
completeNextResultStageWithSuccess(3, 0)
2286+
assert(mapOutputTracker.shuffleStatuses(shuffleId).isActive === false)
2287+
2288+
// Double check results.
2289+
assert(results === Map(0 -> 42))
2290+
results.clear()
2291+
assertDataStructuresEmpty()
2292+
}
2293+
21982294
test("map stage submission with fetch failure") {
21992295
val shuffleMapRdd = new MyRDD(sc, 2, Nil)
22002296
val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))

0 commit comments

Comments
 (0)