Commit aea5f50

Liupengcheng authored and cloud-fan committed
[SPARK-26525][SHUFFLE] Fast release ShuffleBlockFetcherIterator on completion of the iteration
## What changes were proposed in this pull request?

Currently, Spark does not release a `ShuffleBlockFetcherIterator` until the whole task finishes, and under some conditions this causes a memory leak. An example is `rdd.repartition(m).coalesce(n, shuffle = false).save`: each `ShuffleBlockFetcherIterator` holds metadata about map statuses (`blocksByAddress`), and each result task keeps multiple ShuffleBlockFetcherIterators (up to the number of shuffle partitions). None of that memory is released until the task completes, because the iterators are referenced by the completion callbacks of the TaskContext. In some cases this consumes a huge amount of memory and causes an OOM. We can instead release each `ShuffleBlockFetcherIterator` as soon as it has been fully consumed. This PR resolves that problem.

## How was this patch tested?

Unit test.

Closes apache#23438 from liupc/Fast-release-shuffleblockfetcheriterator.

Lead-authored-by: Liupengcheng <[email protected]>
Co-authored-by: liupengcheng <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 8f968b4 · commit aea5f50
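
To make the failure mode concrete, the leaky job shape described above can be sketched as follows. This is a hedged illustration, not code from the PR: the sizes `m` and `n`, the app name, and the output path are hypothetical.

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Hedged sketch of the job shape described in the commit message.
object ShuffleIteratorLeakSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("leak-sketch").setMaster("local[2]"))
    val m = 10000 // many shuffle partitions (hypothetical)
    val n = 10    // few output partitions (hypothetical)
    sc.parallelize(1 to 1000000)
      .repartition(m)               // shuffle write into m partitions
      .coalesce(n, shuffle = false) // each result task reads ~m/n shuffle
                                    // partitions, one fetcher iterator each
      .saveAsTextFile("/tmp/leak-sketch-output") // hypothetical path
    sc.stop()
  }
}
```

Before this patch, every fetcher iterator created by such a task stayed reachable through its TaskContext completion callback until the task finished; after it, each one is released as soon as it is exhausted.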

File tree: 2 files changed (+32 / -4 lines)


core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala (1 addition, 1 deletion)

```diff
@@ -55,7 +55,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
       SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
       SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
       SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT),
-      readMetrics)
+      readMetrics).toCompletionIterator

     val serializerInstance = dep.serializer.newInstance()
```
core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala (31 additions, 3 deletions)

```diff
@@ -31,7 +31,7 @@ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.network.shuffle._
 import org.apache.spark.network.util.TransportConf
 import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils}
 import org.apache.spark.util.io.ChunkedByteBufferOutputStream

 /**
```

```diff
@@ -160,6 +160,8 @@ final class ShuffleBlockFetcherIterator(
   @GuardedBy("this")
   private[this] val shuffleFilesSet = mutable.HashSet[DownloadFile]()

+  private[this] val onCompleteCallback = new ShuffleFetchCompletionListener(this)
+
   initialize()

   // Decrements the buffer reference count.
```

```diff
@@ -192,7 +194,7 @@ final class ShuffleBlockFetcherIterator(
   /**
    * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
    */
-  private[this] def cleanup() {
+  private[storage] def cleanup() {
     synchronized {
       isZombie = true
     }
```
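
Widening `cleanup()` from `private[this]` to `private[storage]` is what lets the new `ShuffleFetchCompletionListener`, a separate class later in the same file (and thus in the `org.apache.spark.storage` package), call it from outside the iterator instance. A minimal standalone illustration of Scala's access qualifiers, using hypothetical names:

```scala
package org.example.storage

class Holder {
  // Callable only on this particular instance, not even from another Holder.
  private[this] def objectPrivate(): Unit = ()

  // Callable from any code inside the org.example.storage package.
  private[storage] def packagePrivate(): Unit = ()
}

class Peer {
  def poke(h: Holder): Unit = h.packagePrivate() // compiles: same package
  // h.objectPrivate()                           // would not compile here
}
```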

```diff
@@ -364,7 +366,7 @@ final class ShuffleBlockFetcherIterator(

   private[this] def initialize(): Unit = {
     // Add a task completion callback (called in both success case and failure case) to cleanup.
-    context.addTaskCompletionListener[Unit](_ => cleanup())
+    context.addTaskCompletionListener(onCompleteCallback)

     // Split local and remote blocks.
     val remoteRequests = splitLocalRemoteBlocks()
```
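
Registering a named listener object instead of the previous anonymous closure lets the same instance be registered with the `TaskContext` here and also be invoked eagerly from `toCompletionIterator` in the next hunk. For context, `org.apache.spark.util.TaskCompletionListener` (added to the imports above) has roughly this shape; this is a paraphrase, so see the Spark source for the authoritative definition:

```scala
import java.util.EventListener
import org.apache.spark.TaskContext

// Paraphrased shape of org.apache.spark.util.TaskCompletionListener.
trait TaskCompletionListener extends EventListener {
  // Called when the task completes, whether it succeeded or failed.
  def onTaskCompletion(context: TaskContext): Unit
}
```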

```diff
@@ -509,6 +511,11 @@ final class ShuffleBlockFetcherIterator(
     (currentResult.blockId, new BufferReleasingInputStream(input, this))
   }

+  def toCompletionIterator: Iterator[(BlockId, InputStream)] = {
+    CompletionIterator[(BlockId, InputStream), this.type](this,
+      onCompleteCallback.onComplete(context))
+  }
+
   private def fetchUpToMaxBytes(): Unit = {
     // Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host
     // immediately, defer the request until the next time it can be processed.
```
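
`CompletionIterator` is what makes the release eager: it delegates to the wrapped iterator and runs a completion function exactly once, as soon as `hasNext` first returns false. A simplified standalone sketch of the idea, not Spark's actual implementation:

```scala
// Simplified sketch of the CompletionIterator idea; the real class in
// org.apache.spark.util differs in detail.
class SimpleCompletionIterator[A](sub: Iterator[A], completion: => Unit)
    extends Iterator[A] {
  private[this] var completed = false

  override def next(): A = sub.next()

  override def hasNext: Boolean = {
    val more = sub.hasNext
    if (!more && !completed) {
      completed = true
      completion // fire cleanup as soon as the data is exhausted
    }
    more
  }
}
```

In the hunk above, the completion function is `onCompleteCallback.onComplete(context)`, so draining the fetch iterator triggers the same cleanup that would otherwise only run at task completion.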

```diff
@@ -609,6 +616,27 @@ private class BufferReleasingInputStream(
   override def reset(): Unit = delegate.reset()
 }

+/**
+ * A listener to be called at the completion of the ShuffleBlockFetcherIterator.
+ * @param data the ShuffleBlockFetcherIterator to process
+ */
+private class ShuffleFetchCompletionListener(var data: ShuffleBlockFetcherIterator)
+  extends TaskCompletionListener {
+
+  override def onTaskCompletion(context: TaskContext): Unit = {
+    if (data != null) {
+      data.cleanup()
+      // Null out the referent here to make sure we don't keep a reference to this
+      // ShuffleBlockFetcherIterator, after we're done reading from it, to let it be
+      // collected during GC. Otherwise we can hold metadata on block locations (blocksByAddress).
+      data = null
+    }
+  }
+
+  // Just an alias for onTaskCompletion, to avoid confusion.
+  def onComplete(context: TaskContext): Unit = this.onTaskCompletion(context)
+}
+
 private[storage]
 object ShuffleBlockFetcherIterator {
```
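
Note that `onTaskCompletion` can run twice in the normal flow: once eagerly through the `CompletionIterator`, and again when the `TaskContext` fires its registered listeners at task end. Nulling out `data` makes the second call a no-op and drops the last reference to the iterator, so its `blocksByAddress` metadata becomes garbage-collectible. A minimal standalone illustration of that release-and-null pattern, with hypothetical names:

```scala
// Hypothetical stand-ins for the iterator and its completion listener.
class Resource {
  def close(): Unit = println("released")
}

class EagerReleaseListener(private var resource: Resource) {
  def onComplete(): Unit = {
    if (resource != null) {
      resource.close() // release exactly once
      resource = null  // drop the reference so the Resource can be GC'd;
                       // any later invocation falls through as a no-op
    }
  }
}

object Demo extends App {
  val listener = new EagerReleaseListener(new Resource)
  listener.onComplete() // prints "released" (the eager call)
  listener.onComplete() // no-op (the call at task completion)
}
```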
