
Commit d2a8723

lwwmanning authored and Robert Kruszewski committed
Support dynamic allocation without external shuffle service (apache-spark-on-k8s#427)
Allows dynamically scaling executors up and down without the external shuffle service. Tracks shuffle output locations to determine whether executors can be safely scaled down.
1 parent: 84d8b05 · commit: d2a8723
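
For context, a minimal configuration sketch showing how a job might opt into this behavior. The inactive-shuffle timeout key is taken from the diff below; the other dynamic-allocation and shuffle-service keys are standard Spark settings, and the concrete values are illustrative only, not recommendations from this commit.

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

// Illustrative values only: with this change, dynamic allocation no longer
// requires spark.shuffle.service.enabled=true. Executors whose shuffle files
// all belong to inactive shuffles become eligible for removal after the
// inactive-shuffle idle timeout introduced in this commit.
val conf = new SparkConf()
  .set("spark.dynamicAllocation.enabled", "true")
  .set("spark.shuffle.service.enabled", "false")
  .set("spark.dynamicAllocation.minExecutors", "1")
  .set("spark.dynamicAllocation.maxExecutors", "20")
  .set("spark.dynamicAllocation.executorIdleTimeout", "60s")
  .set("spark.dynamicAllocation.inactiveShuffleExecutorIdleTimeout", "120s")

val spark = SparkSession.builder().config(conf).getOrCreate()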

File tree

11 files changed: +478 −98 lines

core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala

Lines changed: 38 additions & 20 deletions
@@ -25,7 +25,7 @@ import scala.util.control.{ControlThrowable, NonFatal}
 
 import com.codahale.metrics.{Gauge, MetricRegistry}
 
-import org.apache.spark.internal.{config, Logging}
+import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config._
 import org.apache.spark.metrics.source.Source
 import org.apache.spark.scheduler._
@@ -87,6 +87,7 @@ private[spark] class ExecutorAllocationManager(
     client: ExecutorAllocationClient,
     listenerBus: LiveListenerBus,
     conf: SparkConf,
+    mapOutputTracker: MapOutputTrackerMaster,
     blockManagerMaster: BlockManagerMaster)
   extends Logging {
 
@@ -114,6 +115,9 @@ private[spark] class ExecutorAllocationManager(
   private val cachedExecutorIdleTimeoutS = conf.getTimeAsSeconds(
     "spark.dynamicAllocation.cachedExecutorIdleTimeout", s"${Integer.MAX_VALUE}s")
 
+  private val inactiveShuffleExecutorIdleTimeoutS = conf.getTimeAsSeconds(
+    "spark.dynamicAllocation.inactiveShuffleExecutorIdleTimeout", s"${Integer.MAX_VALUE}s")
+
   // During testing, the methods to actually kill and add executors are mocked out
   private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false)
 
@@ -210,12 +214,6 @@ private[spark] class ExecutorAllocationManager(
     if (cachedExecutorIdleTimeoutS < 0) {
       throw new SparkException("spark.dynamicAllocation.cachedExecutorIdleTimeout must be >= 0!")
     }
-    // Require external shuffle service for dynamic allocation
-    // Otherwise, we may lose shuffle files when killing executors
-    if (!conf.get(config.SHUFFLE_SERVICE_ENABLED) && !testing) {
-      throw new SparkException("Dynamic allocation of executors requires the external " +
-        "shuffle service. You may enable this through spark.shuffle.service.enabled.")
-    }
     if (tasksPerExecutorForFullParallelism == 0) {
       throw new SparkException("spark.executor.cores must not be < spark.task.cpus.")
     }
@@ -546,7 +544,7 @@ private[spark] class ExecutorAllocationManager(
       // has been reached, it will no longer be marked as idle. When new executors join,
       // however, we are no longer at the lower bound, and so we must mark executor X
       // as idle again so as not to forget that it is a candidate for removal. (see SPARK-4951)
-      executorIds.filter(listener.isExecutorIdle).foreach(onExecutorIdle)
+      checkForIdleExecutors()
       logInfo(s"New executor $executorId has registered (new total is ${executorIds.size})")
     } else {
       logWarning(s"Duplicate executor $executorId has registered")
@@ -601,30 +599,44 @@ private[spark] class ExecutorAllocationManager(
    */
   private def onExecutorIdle(executorId: String): Unit = synchronized {
     if (executorIds.contains(executorId)) {
-      if (!removeTimes.contains(executorId) && !executorsPendingToRemove.contains(executorId)) {
+      val hasActiveShuffleBlocks =
+        mapOutputTracker.hasOutputsOnExecutor(executorId, activeOnly = true)
+      if (!removeTimes.contains(executorId)
+        && !executorsPendingToRemove.contains(executorId)
+        && !hasActiveShuffleBlocks) {
         // Note that it is not necessary to query the executors since all the cached
         // blocks we are concerned with are reported to the driver. Note that this
         // does not include broadcast blocks.
         val hasCachedBlocks = blockManagerMaster.hasCachedBlocks(executorId)
+        val hasAnyShuffleBlocks = mapOutputTracker.hasOutputsOnExecutor(executorId)
         val now = clock.getTimeMillis()
-        val timeout = {
-          if (hasCachedBlocks) {
-            // Use a different timeout if the executor has cached blocks.
-            now + cachedExecutorIdleTimeoutS * 1000
-          } else {
-            now + executorIdleTimeoutS * 1000
-          }
-        }
-        val realTimeout = if (timeout <= 0) Long.MaxValue else timeout // overflow
-        removeTimes(executorId) = realTimeout
+
+        // Use the maximum of all the timeouts that apply.
+        val timeoutS = List(
+          executorIdleTimeoutS,
+          if (hasCachedBlocks) cachedExecutorIdleTimeoutS else 0,
+          if (hasAnyShuffleBlocks) inactiveShuffleExecutorIdleTimeoutS else 0)
+          .max
+
+        val expiryTime = now + timeoutS * 1000;
+        val realExpiryTime = if (expiryTime <= 0) Long.MaxValue else expiryTime
+
+        removeTimes(executorId) = realExpiryTime
         logDebug(s"Starting idle timer for $executorId because there are no more tasks " +
-          s"scheduled to run on the executor (to expire in ${(realTimeout - now)/1000} seconds)")
+          s"scheduled to run on the executor (to expire in ${(realExpiryTime - now)/1000} seconds)")
       }
     } else {
       logWarning(s"Attempted to mark unknown executor $executorId idle")
     }
   }
 
+  /**
+   * Check if any executors are now idle, and call the idle callback for them.
+   */
+  private def checkForIdleExecutors(): Unit = synchronized {
+    executorIds.filter(listener.isExecutorIdle).foreach(onExecutorIdle)
+  }
+
   /**
    * Callback invoked when the specified executor is now running a task.
    * This resets all variables used for removing this executor.
@@ -659,6 +671,12 @@ private[spark] class ExecutorAllocationManager(
     // place the executors.
     private val stageIdToExecutorPlacementHints = new mutable.HashMap[Int, (Int, Map[String, Int])]
 
+    override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
+      // At the end of a job, trigger the callbacks for idle executors again to clean up executors
+      // which we were keeping around only because they held active shuffle blocks.
+      allocationManager.checkForIdleExecutors()
+    }
+
     override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
       initializing = false
       val stageId = stageSubmitted.stageInfo.stageId
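
A standalone sketch of the timeout-selection rule that the onExecutorIdle hunk above introduces; the function name, parameter list, and example values are illustrative, not the class's actual fields.

// Minimal sketch of the "max of all applicable timeouts" rule from the hunk above.
def idleExpiryMillis(
    now: Long,
    executorIdleTimeoutS: Long,
    cachedExecutorIdleTimeoutS: Long,
    inactiveShuffleExecutorIdleTimeoutS: Long,
    hasCachedBlocks: Boolean,
    hasAnyShuffleBlocks: Boolean): Long = {
  val timeoutS = List(
    executorIdleTimeoutS,
    if (hasCachedBlocks) cachedExecutorIdleTimeoutS else 0L,
    if (hasAnyShuffleBlocks) inactiveShuffleExecutorIdleTimeoutS else 0L).max
  val expiry = now + timeoutS * 1000
  if (expiry <= 0) Long.MaxValue else expiry // guard against overflow
}

// e.g. idleExpiryMillis(0L, 60, 600, 120, hasCachedBlocks = false, hasAnyShuffleBlocks = true)
// returns 120000: the inactive-shuffle timeout wins over the plain idle timeout.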

core/src/main/scala/org/apache/spark/MapOutputTracker.scala

Lines changed: 88 additions & 4 deletions
@@ -28,6 +28,7 @@ import scala.concurrent.duration.Duration
 import scala.reflect.ClassTag
 import scala.util.control.NonFatal
 
+import org.apache.spark.ExecutorShuffleStatus.ExecutorShuffleStatus
 import org.apache.spark.broadcast.{Broadcast, BroadcastManager}
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config._
@@ -61,6 +62,13 @@ private class ShuffleStatus(numPartitions: Int) {
   // Exposed for testing
   val mapStatuses = new Array[MapStatus](numPartitions)
 
+  /**
+   * Whether an active job in the [[org.apache.spark.scheduler.DAGScheduler]] depends on this.
+   * If dynamic allocation is enabled, then executors that do not contain active shuffles may
+   * eventually be surrendered by the [[ExecutorAllocationManager]].
+   */
+  var isActive = true
+
   /**
    * The cached result of serializing the map statuses array. This cache is lazily populated when
    * [[serializedMapStatus]] is called. The cache is invalidated when map outputs are removed.
@@ -80,17 +88,24 @@ private class ShuffleStatus(numPartitions: Int) {
   /**
    * Counter tracking the number of partitions that have output. This is a performance optimization
    * to avoid having to count the number of non-null entries in the `mapStatuses` array and should
-   * be equivalent to`mapStatuses.count(_ ne null)`.
+   * be equivalent to `mapStatuses.count(_ ne null)`.
    */
   private[this] var _numAvailableOutputs: Int = 0
 
+  /**
+   * Cached set of executorIds on which outputs exist. This is a performance optimization to avoid
+   * having to repeatedly iterate over ever element in the `mapStatuses` array and should be
+   * equivalent to `mapStatuses.map(_.location.executorId).groupBy(x => x).mapValues(_.length)`.
+   */
+  private[this] val _numOutputsPerExecutorId = HashMap[String, Int]().withDefaultValue(0)
+
   /**
    * Register a map output. If there is already a registered location for the map output then it
    * will be replaced by the new location.
   */
   def addMapOutput(mapId: Int, status: MapStatus): Unit = synchronized {
     if (mapStatuses(mapId) == null) {
-      _numAvailableOutputs += 1
+      incrementNumAvailableOutputs(status.location)
       invalidateSerializedMapOutputStatusCache()
     }
     mapStatuses(mapId) = status
@@ -103,7 +118,7 @@ private class ShuffleStatus(numPartitions: Int) {
    */
   def removeMapOutput(mapId: Int, bmAddress: BlockManagerId): Unit = synchronized {
     if (mapStatuses(mapId) != null && mapStatuses(mapId).location == bmAddress) {
-      _numAvailableOutputs -= 1
+      decrementNumAvailableOutputs(bmAddress)
       mapStatuses(mapId) = null
       invalidateSerializedMapOutputStatusCache()
     }
@@ -133,13 +148,21 @@ private class ShuffleStatus(numPartitions: Int) {
   def removeOutputsByFilter(f: (BlockManagerId) => Boolean): Unit = synchronized {
     for (mapId <- 0 until mapStatuses.length) {
       if (mapStatuses(mapId) != null && f(mapStatuses(mapId).location)) {
-        _numAvailableOutputs -= 1
+        decrementNumAvailableOutputs(mapStatuses(mapId).location)
         mapStatuses(mapId) = null
         invalidateSerializedMapOutputStatusCache()
       }
     }
   }
 
+  def hasOutputsOnExecutor(execId: String): Boolean = synchronized {
+    _numOutputsPerExecutorId(execId) > 0
+  }
+
+  def executorsWithOutputs(): Set[String] = synchronized {
+    _numOutputsPerExecutorId.keySet.toSet
+  }
+
   /**
    * Number of partitions that have shuffle outputs.
   */
@@ -192,6 +215,22 @@ private class ShuffleStatus(numPartitions: Int) {
     f(mapStatuses)
   }
 
+  private[this] def incrementNumAvailableOutputs(bmAddress: BlockManagerId): Unit = synchronized {
+    _numOutputsPerExecutorId(bmAddress.executorId) += 1
+    _numAvailableOutputs += 1
+  }
+
+  private[this] def decrementNumAvailableOutputs(bmAddress: BlockManagerId): Unit = synchronized {
+    assert(_numOutputsPerExecutorId(bmAddress.executorId) >= 1,
+      s"Tried to remove non-existent output from ${bmAddress.executorId}")
+    if (_numOutputsPerExecutorId(bmAddress.executorId) == 1) {
+      _numOutputsPerExecutorId.remove(bmAddress.executorId)
+    } else {
+      _numOutputsPerExecutorId(bmAddress.executorId) -= 1
+    }
+    _numAvailableOutputs -= 1
+  }
+
   /**
    * Clears the cached serialized map output statuses.
   */
@@ -306,6 +345,11 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
   def stop() {}
 }
 
+private[spark] object ExecutorShuffleStatus extends Enumeration {
+  type ExecutorShuffleStatus = Value
+  val Active, Inactive, Unknown = Value
+}
+
 /**
  * Driver-side class that keeps track of the location of the map output of a stage.
  *
@@ -453,6 +497,26 @@ private[spark] class MapOutputTrackerMaster(
     }
   }
 
+  def markShuffleInactive(shuffleId: Int): Unit = {
+    shuffleStatuses.get(shuffleId) match {
+      case Some(shuffleStatus) =>
+        shuffleStatus.isActive = false
+      case None =>
+        throw new SparkException(
+          s"markShuffleInactive called for nonexistent shuffle ID $shuffleId.")
+    }
+  }
+
+  def markShuffleActive(shuffleId: Int): Unit = {
+    shuffleStatuses.get(shuffleId) match {
+      case Some(shuffleStatus) =>
+        shuffleStatus.isActive = true
+      case None =>
+        throw new SparkException(
+          s"markShuffleActive called for nonexistent shuffle ID $shuffleId.")
+    }
+  }
+
   /**
    * Removes all shuffle outputs associated with this host. Note that this will also remove
    * outputs which are served by an external shuffle server (if one exists).
@@ -472,6 +536,12 @@ private[spark] class MapOutputTrackerMaster(
     incrementEpoch()
   }
 
+  def hasOutputsOnExecutor(execId: String, activeOnly: Boolean = false): Boolean = {
+    shuffleStatuses.valuesIterator.exists { status =>
+      status.hasOutputsOnExecutor(execId) && (!activeOnly || status.isActive)
+    }
+  }
+
   /** Check if the given shuffle is being tracked */
   def containsShuffle(shuffleId: Int): Boolean = shuffleStatuses.contains(shuffleId)
 
@@ -577,6 +647,20 @@ private[spark] class MapOutputTrackerMaster(
     }
   }
 
+  /**
+   * Return the set of executors that contain tracked shuffle files, with a status of
+   * [[ExecutorShuffleStatus.Inactive]] iff all shuffle files on that executor are marked inactive.
+   *
+   * @return a map of executor IDs to their corresponding [[ExecutorShuffleStatus]]
+   */
+  def getExecutorShuffleStatus: scala.collection.Map[String, ExecutorShuffleStatus] = {
+    shuffleStatuses.values
+      .flatMap(status => status.executorsWithOutputs().map(_ -> status.isActive))
+      .groupBy(_._1)
+      .mapValues(_.exists(_._2))
+      .mapValues(if (_) ExecutorShuffleStatus.Active else ExecutorShuffleStatus.Inactive)
+  }
+
   /**
    * Return a list of locations that each have fraction of map output greater than the specified
    * threshold.
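
A hypothetical caller-side sketch (not part of this commit) of how the new MapOutputTrackerMaster queries could be used to pick executors that are safe to scale down; the safeToRemove helper and candidateExecutors name are assumptions for illustration.

import org.apache.spark.{ExecutorShuffleStatus, MapOutputTrackerMaster}

// Hypothetical helper, for illustration only: an executor is considered safe to
// remove when it holds no shuffle output at all, or only output from shuffles
// that the DAGScheduler has marked inactive (no running job depends on them).
def safeToRemove(tracker: MapOutputTrackerMaster, candidateExecutors: Seq[String]): Seq[String] = {
  val statuses = tracker.getExecutorShuffleStatus
  candidateExecutors.filter { execId =>
    statuses.get(execId) match {
      case None => true                                  // no tracked shuffle files
      case Some(ExecutorShuffleStatus.Inactive) => true  // only inactive shuffle files
      case _ => false                                    // active (or unknown) shuffle files
    }
  }
}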

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 1 addition & 0 deletions
@@ -546,6 +546,7 @@ class SparkContext(config: SparkConf) extends Logging {
       case b: ExecutorAllocationClient =>
         Some(new ExecutorAllocationManager(
           schedulerBackend.asInstanceOf[ExecutorAllocationClient], listenerBus, _conf,
+          _env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster],
           _env.blockManager.master))
       case _ =>
         None

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 30 additions & 11 deletions
@@ -397,7 +397,9 @@ private[spark] class DAGScheduler(
     shuffleIdToMapStage(shuffleDep.shuffleId) = stage
     updateJobIdStageIdMaps(jobId, stage)
 
-    if (!mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
+    if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
+      mapOutputTracker.markShuffleActive(shuffleDep.shuffleId)
+    } else {
       // Kind of ugly: need to register RDDs with the cache and map output tracker here
       // since we can't do it in the RDD constructor because # of partitions is unknown
       logInfo("Registering RDD " + rdd.id + " (" + rdd.getCreationSite + ")")
@@ -627,6 +629,7 @@ private[spark] class DAGScheduler(
     }
     for ((k, v) <- shuffleIdToMapStage.find(_._2 == stage)) {
       shuffleIdToMapStage.remove(k)
+      mapOutputTracker.markShuffleInactive(k)
     }
     if (waitingStages.contains(stage)) {
       logDebug("Removing stage %d from waiting set.".format(stageId))
@@ -1367,6 +1370,31 @@ private[spark] class DAGScheduler(
       case _: ExceptionFailure | _: TaskKilled => updateAccumulators(event)
       case _ =>
     }
+
+    // Make sure shuffle outputs are registered before we post the event so that
+    // handlers can act on up-to-date shuffle information.
+    event.reason match {
+      case Success =>
+        task match {
+          case smt: ShuffleMapTask =>
+            val shuffleStage = stage.asInstanceOf[ShuffleMapStage]
+            val status = event.result.asInstanceOf[MapStatus]
+            val execId = status.location.executorId
+            logDebug("Registering shuffle output on executor " + execId)
+            if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
+              logInfo(s"Ignoring possibly bogus $smt completion from executor $execId")
+            } else {
+              // The epoch of the task is acceptable (i.e., the task was launched after the most
+              // recent failure we're aware of for the executor), so mark the task's output as
+              // available.
+              mapOutputTracker.registerMapOutput(
+                shuffleStage.shuffleDep.shuffleId, smt.partitionId, status)
+            }
+          case _ =>
+        }
+      case _ =>
+    }
+
     postTaskEnd(event)
 
     event.reason match {
@@ -1418,21 +1446,12 @@ private[spark] class DAGScheduler(
           logInfo("Ignoring result from " + rt + " because its job has finished")
         }
 
-      case smt: ShuffleMapTask =>
+      case _: ShuffleMapTask =>
         val shuffleStage = stage.asInstanceOf[ShuffleMapStage]
         shuffleStage.pendingPartitions -= task.partitionId
         val status = event.result.asInstanceOf[MapStatus]
         val execId = status.location.executorId
         logDebug("ShuffleMapTask finished on " + execId)
-        if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
-          logInfo(s"Ignoring possibly bogus $smt completion from executor $execId")
-        } else {
-          // The epoch of the task is acceptable (i.e., the task was launched after the most
-          // recent failure we're aware of for the executor), so mark the task's output as
-          // available.
-          mapOutputTracker.registerMapOutput(
-            shuffleStage.shuffleDep.shuffleId, smt.partitionId, status)
-        }
 
         if (runningStages.contains(shuffleStage) && shuffleStage.pendingPartitions.isEmpty) {
           markStageAsFinished(shuffleStage)
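
To summarize the lifecycle wired up in the DAGScheduler hunks above: a shuffle is (re)marked active when a stage that depends on it is registered, and marked inactive when its map stage is cleaned up; executors whose tracked shuffles are all inactive then become removal candidates. Below is a small, self-contained toy model of that bookkeeping; all names are illustrative and the real logic lives in MapOutputTrackerMaster and DAGScheduler as shown in the diffs.

import scala.collection.mutable

// Toy model of the active/inactive shuffle bookkeeping introduced by this commit.
object ShuffleLifecycleModel {
  private val shuffleActive = mutable.Map[Int, Boolean]()         // shuffleId -> isActive
  private val shuffleExecutors = mutable.Map[Int, Set[String]]()  // shuffleId -> executors with output

  // Called when a stage depending on this shuffle is created or reused.
  def registerShuffle(shuffleId: Int, executors: Set[String]): Unit = {
    shuffleActive(shuffleId) = true
    shuffleExecutors(shuffleId) = executors
  }

  // Called when the map stage producing this shuffle is cleaned up.
  def onStageCleanup(shuffleId: Int): Unit = shuffleActive(shuffleId) = false

  // An executor is a removal candidate if every shuffle it holds output for is inactive.
  def removalCandidates(): Set[String] = {
    val all = shuffleExecutors.values.flatten.toSet
    val activeExecs = shuffleExecutors.collect {
      case (id, execs) if shuffleActive.getOrElse(id, false) => execs
    }.flatten.toSet
    all -- activeExecs
  }
}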
