
Commit efd0036

GuoChenzhao authored and cloud-fan committed
[SPARK-22537][CORE] Aggregation of map output statistics on driver faces single point bottleneck
## What changes were proposed in this pull request?

In adaptive execution, the map output statistics of all mappers are aggregated after the previous stage completes successfully. The driver performs this aggregation on a single thread, so it becomes a bottleneck when the number of `mappers * shuffle partitions` is large. This PR parallelizes the aggregation across multiple threads to remove this single-point bottleneck.

## How was this patch tested?

Test cases are in `MapOutputTrackerSuite.scala`.

Author: GuoChenzhao <[email protected]>
Author: gczsjdy <[email protected]>

Closes #19763 from gczsjdy/single_point_mapstatistics.
1 parent 449e26e commit efd0036

File tree

3 files changed: 91 additions & 3 deletions

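To make the approach in the description above concrete before reading the diffs, here is a minimal standalone sketch (not the Spark code itself; the object name, the toy `sizes` matrix, and the fixed pool cap are made up for illustration). It splits the reduce partition ids across a small thread pool so that each thread sums sizes for a disjoint slice of `totalSizes`, which is why no locking is needed.

```scala
import java.util.concurrent.Executors
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

object ParallelAggSketch {
  def main(args: Array[String]): Unit = {
    val numMappers = 1000
    val numPartitions = 2000
    // sizes(m)(p) stands in for statuses(m).getSizeForBlock(p) in the real code.
    val sizes = Array.fill(numMappers, numPartitions)(1L)

    val parallelism = math.min(Runtime.getRuntime.availableProcessors(), 4)
    val pool = Executors.newFixedThreadPool(parallelism)
    implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(pool)

    val totalSizes = new Array[Long](numPartitions)
    val groupSize = numPartitions / parallelism + 1
    // Each future owns a disjoint slice of reduce partition ids, so no two
    // threads ever write the same slot of totalSizes.
    val tasks = (0 until numPartitions).grouped(groupSize).toSeq.map { reduceIds =>
      Future {
        for (m <- 0 until numMappers; p <- reduceIds) {
          totalSizes(p) += sizes(m)(p)
        }
      }
    }
    Await.result(Future.sequence(tasks), Duration.Inf)
    pool.shutdown()

    // Every partition received one block of size 1 from each mapper.
    assert(totalSizes.forall(_ == numMappers))
    println(s"aggregated ${totalSizes.length} partitions over $parallelism threads")
  }
}
```

The actual patch follows the same shape but derives the thread count from `spark.shuffle.mapOutput.parallelAggregationThreshold` and reads block sizes from `MapStatus.getSizeForBlock`, as shown in the diffs below.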
core/src/main/scala/org/apache/spark/MapOutputTracker.scala

Lines changed: 57 additions & 3 deletions
```diff
@@ -23,11 +23,14 @@ import java.util.zip.{GZIPInputStream, GZIPOutputStream}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
+import scala.concurrent.{ExecutionContext, Future}
+import scala.concurrent.duration.Duration
 import scala.reflect.ClassTag
 import scala.util.control.NonFatal
 
 import org.apache.spark.broadcast.{Broadcast, BroadcastManager}
 import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
 import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv}
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.shuffle.MetadataFetchFailedException
@@ -472,15 +475,66 @@ private[spark] class MapOutputTrackerMaster(
     shuffleStatuses.get(shuffleId).map(_.findMissingPartitions())
   }
 
+  /**
+   * Grouped function of Range, this is to avoid traverse of all elements of Range using
+   * IterableLike's grouped function.
+   */
+  def rangeGrouped(range: Range, size: Int): Seq[Range] = {
+    val start = range.start
+    val step = range.step
+    val end = range.end
+    for (i <- start.until(end, size * step)) yield {
+      i.until(i + size * step, step)
+    }
+  }
+
+  /**
+   * To equally divide n elements into m buckets, basically each bucket should have n/m elements,
+   * for the remaining n%m elements, add one more element to the first n%m buckets each.
+   */
+  def equallyDivide(numElements: Int, numBuckets: Int): Seq[Seq[Int]] = {
+    val elementsPerBucket = numElements / numBuckets
+    val remaining = numElements % numBuckets
+    val splitPoint = (elementsPerBucket + 1) * remaining
+    if (elementsPerBucket == 0) {
+      rangeGrouped(0.until(splitPoint), elementsPerBucket + 1)
+    } else {
+      rangeGrouped(0.until(splitPoint), elementsPerBucket + 1) ++
+        rangeGrouped(splitPoint.until(numElements), elementsPerBucket)
+    }
+  }
+
   /**
    * Return statistics about all of the outputs for a given shuffle.
    */
   def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = {
     shuffleStatuses(dep.shuffleId).withMapStatuses { statuses =>
       val totalSizes = new Array[Long](dep.partitioner.numPartitions)
-      for (s <- statuses) {
-        for (i <- 0 until totalSizes.length) {
-          totalSizes(i) += s.getSizeForBlock(i)
+      val parallelAggThreshold = conf.get(
+        SHUFFLE_MAP_OUTPUT_PARALLEL_AGGREGATION_THRESHOLD)
+      val parallelism = math.min(
+        Runtime.getRuntime.availableProcessors(),
+        statuses.length.toLong * totalSizes.length / parallelAggThreshold + 1).toInt
+      if (parallelism <= 1) {
+        for (s <- statuses) {
+          for (i <- 0 until totalSizes.length) {
+            totalSizes(i) += s.getSizeForBlock(i)
+          }
+        }
+      } else {
+        val threadPool = ThreadUtils.newDaemonFixedThreadPool(parallelism, "map-output-aggregate")
+        try {
+          implicit val executionContext = ExecutionContext.fromExecutor(threadPool)
+          val mapStatusSubmitTasks = equallyDivide(totalSizes.length, parallelism).map {
+            reduceIds => Future {
+              for (s <- statuses; i <- reduceIds) {
+                totalSizes(i) += s.getSizeForBlock(i)
+              }
+            }
+          }
+          ThreadUtils.awaitResult(Future.sequence(mapStatusSubmitTasks), Duration.Inf)
+        } finally {
+          threadPool.shutdown()
         }
       }
       new MapOutputStatistics(dep.shuffleId, totalSizes)
```

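To see what the two new helpers produce, the snippet below copies `rangeGrouped` and `equallyDivide` verbatim from the diff above into a standalone object (the object name and the 17-partitions-over-5-threads example are only for illustration). Splitting 17 reduce ids over 5 buckets gives bucket sizes 4, 4, 3, 3, 3, matching the expectation for the `(17, 5)` case in the test suite below.

```scala
object EquallyDivideDemo {
  // Copied from the MapOutputTrackerMaster additions above so the demo runs outside Spark.
  def rangeGrouped(range: Range, size: Int): Seq[Range] = {
    val start = range.start
    val step = range.step
    val end = range.end
    for (i <- start.until(end, size * step)) yield {
      i.until(i + size * step, step)
    }
  }

  def equallyDivide(numElements: Int, numBuckets: Int): Seq[Seq[Int]] = {
    val elementsPerBucket = numElements / numBuckets
    val remaining = numElements % numBuckets
    val splitPoint = (elementsPerBucket + 1) * remaining
    if (elementsPerBucket == 0) {
      rangeGrouped(0.until(splitPoint), elementsPerBucket + 1)
    } else {
      rangeGrouped(0.until(splitPoint), elementsPerBucket + 1) ++
        rangeGrouped(splitPoint.until(numElements), elementsPerBucket)
    }
  }

  def main(args: Array[String]): Unit = {
    // 17 reduce partitions over 5 aggregation threads: the first 17 % 5 = 2 buckets
    // get 4 ids each, the remaining 3 buckets get 3 ids each.
    equallyDivide(17, 5).foreach(bucket => println(bucket.mkString(", ")))
    // Prints:
    // 0, 1, 2, 3
    // 4, 5, 6, 7
    // 8, 9, 10
    // 11, 12, 13
    // 14, 15, 16
  }
}
```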
core/src/main/scala/org/apache/spark/internal/config/package.scala

Lines changed: 11 additions & 0 deletions
```diff
@@ -499,4 +499,15 @@ package object config {
       "array in the sorter.")
     .intConf
     .createWithDefault(Integer.MAX_VALUE)
+
+  private[spark] val SHUFFLE_MAP_OUTPUT_PARALLEL_AGGREGATION_THRESHOLD =
+    ConfigBuilder("spark.shuffle.mapOutput.parallelAggregationThreshold")
+      .internal()
+      .doc("Multi-thread is used when the number of mappers * shuffle partitions is greater than " +
+        "or equal to this threshold. Note that the actual parallelism is calculated by number of " +
+        "mappers * shuffle partitions / this threshold + 1, so this threshold should be positive.")
+      .intConf
+      .checkValue(v => v > 0, "The threshold should be positive.")
+      .createWithDefault(10000000)
+
 }
```

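As a worked example of the formula described in the config doc above (the mapper and partition counts here are made up): with the default threshold of 10,000,000, a stage with 5,000 mappers and 10,000 shuffle partitions gives `min(availableProcessors, 5000 * 10000 / 10000000 + 1) = min(cores, 6)` aggregation threads, while any stage below the threshold stays on the original single-threaded path.

```scala
object ParallelismFormulaDemo {
  def main(args: Array[String]): Unit = {
    // Hypothetical stage dimensions; only the 10000000 default comes from the config above.
    val numMappers = 5000
    val numShufflePartitions = 10000
    val parallelAggThreshold = 10000000L

    // Same expression as in MapOutputTrackerMaster.getStatistics in the first diff.
    val parallelism = math.min(
      Runtime.getRuntime.availableProcessors().toLong,
      numMappers.toLong * numShufflePartitions / parallelAggThreshold + 1).toInt

    // 5000 * 10000 = 5e7 map-output cells, so 5e7 / 1e7 + 1 = 6 threads (capped by core count).
    println(s"parallelism = $parallelism")
  }
}
```

Since the entry is internal but still an ordinary Spark configuration, it can be tuned like any other setting, e.g. `new SparkConf().set("spark.shuffle.mapOutput.parallelAggregationThreshold", "1000000")`.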
core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala

Lines changed: 23 additions & 0 deletions
```diff
@@ -275,4 +275,27 @@ class MapOutputTrackerSuite extends SparkFunSuite {
     }
   }
 
+  test("equally divide map statistics tasks") {
+    val func = newTrackerMaster().equallyDivide _
+    val cases = Seq((0, 5), (4, 5), (15, 5), (16, 5), (17, 5), (18, 5), (19, 5), (20, 5))
+    val expects = Seq(
+      Seq(0, 0, 0, 0, 0),
+      Seq(1, 1, 1, 1, 0),
+      Seq(3, 3, 3, 3, 3),
+      Seq(4, 3, 3, 3, 3),
+      Seq(4, 4, 3, 3, 3),
+      Seq(4, 4, 4, 3, 3),
+      Seq(4, 4, 4, 4, 3),
+      Seq(4, 4, 4, 4, 4))
+    cases.zip(expects).foreach { case ((num, divisor), expect) =>
+      val answer = func(num, divisor).toSeq
+      var wholeSplit = (0 until num)
+      answer.zip(expect).foreach { case (split, expectSplitLength) =>
+        val (currentSplit, rest) = wholeSplit.splitAt(expectSplitLength)
+        assert(currentSplit.toSet == split.toSet)
+        wholeSplit = rest
+      }
+    }
+  }
+
 }
```
