@@ -28,6 +28,7 @@ import scala.concurrent.duration.Duration
28
28
import scala .reflect .ClassTag
29
29
import scala .util .control .NonFatal
30
30
31
+ import org .apache .spark .ExecutorShuffleStatus .ExecutorShuffleStatus
31
32
import org .apache .spark .broadcast .{Broadcast , BroadcastManager }
32
33
import org .apache .spark .internal .Logging
33
34
import org .apache .spark .internal .config ._
@@ -61,6 +62,13 @@ private class ShuffleStatus(numPartitions: Int) {
61
62
// Exposed for testing
62
63
val mapStatuses = new Array [MapStatus ](numPartitions)
63
64
65
+ /**
66
+ * Whether an active job in the [[org.apache.spark.scheduler.DAGScheduler ]] depends on this.
67
+ * If dynamic allocation is enabled, then executors that do not contain active shuffles may
68
+ * eventually be surrendered by the [[ExecutorAllocationManager ]].
69
+ */
70
+ var isActive = true
71
+
64
72
/**
65
73
* The cached result of serializing the map statuses array. This cache is lazily populated when
66
74
* [[serializedMapStatus ]] is called. The cache is invalidated when map outputs are removed.
@@ -80,17 +88,24 @@ private class ShuffleStatus(numPartitions: Int) {
80
88
/**
81
89
* Counter tracking the number of partitions that have output. This is a performance optimization
82
90
* to avoid having to count the number of non-null entries in the `mapStatuses` array and should
83
- * be equivalent to`mapStatuses.count(_ ne null)`.
91
+ * be equivalent to `mapStatuses.count(_ ne null)`.
84
92
*/
85
93
private [this ] var _numAvailableOutputs : Int = 0
86
94
95
+ /**
96
+ * Cached set of executorIds on which outputs exist. This is a performance optimization to avoid
97
+ * having to repeatedly iterate over every element in the `mapStatuses` array and should be
98
+ * equivalent to `mapStatuses.map(_.location.executorId).groupBy(x => x).mapValues(_.length)`.
99
+ */
100
+ private [this ] val _numOutputsPerExecutorId = HashMap [String , Int ]().withDefaultValue(0 )
101
+
87
102
/**
88
103
* Register a map output. If there is already a registered location for the map output then it
89
104
* will be replaced by the new location.
90
105
*/
91
106
def addMapOutput (mapId : Int , status : MapStatus ): Unit = synchronized {
92
107
if (mapStatuses(mapId) == null ) {
93
- _numAvailableOutputs += 1
108
+ incrementNumAvailableOutputs(status.location)
94
109
invalidateSerializedMapOutputStatusCache()
95
110
}
96
111
mapStatuses(mapId) = status
@@ -103,7 +118,7 @@ private class ShuffleStatus(numPartitions: Int) {
103
118
*/
104
119
def removeMapOutput (mapId : Int , bmAddress : BlockManagerId ): Unit = synchronized {
105
120
if (mapStatuses(mapId) != null && mapStatuses(mapId).location == bmAddress) {
106
- _numAvailableOutputs -= 1
121
+ decrementNumAvailableOutputs(bmAddress)
107
122
mapStatuses(mapId) = null
108
123
invalidateSerializedMapOutputStatusCache()
109
124
}
@@ -133,13 +148,21 @@ private class ShuffleStatus(numPartitions: Int) {
133
148
def removeOutputsByFilter (f : (BlockManagerId ) => Boolean ): Unit = synchronized {
134
149
for (mapId <- 0 until mapStatuses.length) {
135
150
if (mapStatuses(mapId) != null && f(mapStatuses(mapId).location)) {
136
- _numAvailableOutputs -= 1
151
+ decrementNumAvailableOutputs(mapStatuses(mapId).location)
137
152
mapStatuses(mapId) = null
138
153
invalidateSerializedMapOutputStatusCache()
139
154
}
140
155
}
141
156
}
142
157
158
+ def hasOutputsOnExecutor (execId : String ): Boolean = synchronized {
159
+ _numOutputsPerExecutorId(execId) > 0
160
+ }
161
+
162
+ def executorsWithOutputs (): Set [String ] = synchronized {
163
+ _numOutputsPerExecutorId.keySet.toSet
164
+ }
165
+
143
166
/**
144
167
* Number of partitions that have shuffle outputs.
145
168
*/
@@ -192,6 +215,22 @@ private class ShuffleStatus(numPartitions: Int) {
192
215
f(mapStatuses)
193
216
}
194
217
218
+ private [this ] def incrementNumAvailableOutputs (bmAddress : BlockManagerId ): Unit = synchronized {
219
+ _numOutputsPerExecutorId(bmAddress.executorId) += 1
220
+ _numAvailableOutputs += 1
221
+ }
222
+
223
+ private [this ] def decrementNumAvailableOutputs (bmAddress : BlockManagerId ): Unit = synchronized {
224
+ assert(_numOutputsPerExecutorId(bmAddress.executorId) >= 1 ,
225
+ s " Tried to remove non-existent output from ${bmAddress.executorId}" )
226
+ if (_numOutputsPerExecutorId(bmAddress.executorId) == 1 ) {
227
+ _numOutputsPerExecutorId.remove(bmAddress.executorId)
228
+ } else {
229
+ _numOutputsPerExecutorId(bmAddress.executorId) -= 1
230
+ }
231
+ _numAvailableOutputs -= 1
232
+ }
233
+
195
234
/**
196
235
* Clears the cached serialized map output statuses.
197
236
*/
@@ -306,6 +345,11 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
306
345
def stop () {}
307
346
}
308
347
348
+ private [spark] object ExecutorShuffleStatus extends Enumeration {
349
+ type ExecutorShuffleStatus = Value
350
+ val Active, Inactive, Unknown = Value
351
+ }
352
+
309
353
/**
310
354
* Driver-side class that keeps track of the location of the map output of a stage.
311
355
*
@@ -453,6 +497,26 @@ private[spark] class MapOutputTrackerMaster(
453
497
}
454
498
}
455
499
500
+ def markShuffleInactive (shuffleId : Int ): Unit = {
501
+ shuffleStatuses.get(shuffleId) match {
502
+ case Some (shuffleStatus) =>
503
+ shuffleStatus.isActive = false
504
+ case None =>
505
+ throw new SparkException (
506
+ s " markShuffleInactive called for nonexistent shuffle ID $shuffleId. " )
507
+ }
508
+ }
509
+
510
+ def markShuffleActive (shuffleId : Int ): Unit = {
511
+ shuffleStatuses.get(shuffleId) match {
512
+ case Some (shuffleStatus) =>
513
+ shuffleStatus.isActive = true
514
+ case None =>
515
+ throw new SparkException (
516
+ s " markShuffleActive called for nonexistent shuffle ID $shuffleId. " )
517
+ }
518
+ }
519
+
456
520
/**
457
521
* Removes all shuffle outputs associated with this host. Note that this will also remove
458
522
* outputs which are served by an external shuffle server (if one exists).
@@ -472,6 +536,12 @@ private[spark] class MapOutputTrackerMaster(
472
536
incrementEpoch()
473
537
}
474
538
539
+ def hasOutputsOnExecutor (execId : String , activeOnly : Boolean = false ): Boolean = {
540
+ shuffleStatuses.valuesIterator.exists { status =>
541
+ status.hasOutputsOnExecutor(execId) && (! activeOnly || status.isActive)
542
+ }
543
+ }
544
+
475
545
/** Check if the given shuffle is being tracked */
476
546
def containsShuffle (shuffleId : Int ): Boolean = shuffleStatuses.contains(shuffleId)
477
547
@@ -577,6 +647,20 @@ private[spark] class MapOutputTrackerMaster(
577
647
}
578
648
}
579
649
650
+ /**
651
+ * Return the set of executors that contain tracked shuffle files, with a status of
652
+ * [[ExecutorShuffleStatus.Inactive ]] iff all shuffle files on that executor are marked inactive.
653
+ *
654
+ * @return a map of executor IDs to their corresponding [[ExecutorShuffleStatus ]]
655
+ */
656
+ def getExecutorShuffleStatus : scala.collection.Map [String , ExecutorShuffleStatus ] = {
657
+ shuffleStatuses.values
658
+ .flatMap(status => status.executorsWithOutputs().map(_ -> status.isActive))
659
+ .groupBy(_._1)
660
+ .mapValues(_.exists(_._2))
661
+ .mapValues(if (_) ExecutorShuffleStatus .Active else ExecutorShuffleStatus .Inactive )
662
+ }
663
+
580
664
/**
581
665
* Return a list of locations that each have fraction of map output greater than the specified
582
666
* threshold.
0 commit comments