
Commit 43f7936

zml1206 authored and cloud-fan committed
[SPARK-53446][CORE] Optimize BlockManager remove operations with cached block mappings
### What changes were proposed in this pull request?

This continues #52210. It introduces three concurrent hash maps that track block ID associations, so BlockManager remove operations can use cached mappings instead of O(n) linear scans.

### Why are the changes needed?

Previously, removeRdd(), removeBroadcast(), and removeCache() had to scan all blocks in blockInfoManager.entries to find matches. This becomes a serious bottleneck when:

1. Block counts are large: in production deployments with millions or even tens of millions of cached blocks, linear scans can be prohibitively slow.
2. Cleanup is frequent: workloads that repeatedly create and discard RDDs or broadcast variables accumulate this overhead quickly.

The original removeRdd() method already contained a TODO noting that an additional mapping would be needed to avoid linear scans. This PR implements that improvement.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New UT.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #52646 from zml1206/SPARK-53446.

Authored-by: zml1206 <zhuml1206@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
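As an illustration of the approach described above, here is a minimal, self-contained sketch of the reverse-mapping pattern (the names and the `String` block IDs are simplified stand-ins, not the actual `BlockInfoManager` fields, which key three such maps by rddId, broadcastId, and sessionUUID):

```scala
import java.util.concurrent.ConcurrentHashMap
import scala.jdk.CollectionConverters._

object ReverseMappingSketch {
  // One reverse index: owner ID -> concurrent set of its block IDs.
  private val rddToBlocks =
    new ConcurrentHashMap[Int, ConcurrentHashMap.KeySetView[String, java.lang.Boolean]]()

  // Called whenever a new block is registered.
  def add(rddId: Int, blockId: String): Unit =
    rddToBlocks
      .computeIfAbsent(rddId, _ => ConcurrentHashMap.newKeySet())
      .add(blockId)

  // Called when a block is removed; compute() runs atomically per key,
  // and returning null drops the map entry once its set is empty.
  def remove(rddId: Int, blockId: String): Unit =
    rddToBlocks.compute(rddId, (_, set) => {
      if (set != null) {
        set.remove(blockId)
        if (set.isEmpty) null else set
      } else {
        null
      }
    })

  // O(1) keyed lookup instead of scanning every block entry.
  def blocksFor(rddId: Int): Seq[String] =
    Option(rddToBlocks.get(rddId)).map(_.asScala.toSeq).getOrElse(Seq.empty)
}
```

`ConcurrentHashMap.newKeySet()` yields a concurrent `Set` view backed by a `ConcurrentHashMap`, which is why the value type carries the `java.lang.Boolean` parameter.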
1 parent 8446b6d commit 43f7936

File tree: 3 files changed (+116, −8)

core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala

Lines changed: 91 additions & 0 deletions
```diff
@@ -149,6 +149,14 @@ private[storage] class BlockInfoManager(trackingCacheVisibility: Boolean = false
    */
   private[this] val blockInfoWrappers = new ConcurrentHashMap[BlockId, BlockInfoWrapper]
 
+  // Cache mappings to avoid O(n) scans in remove operations.
+  private[this] val rddToBlockIds =
+    new ConcurrentHashMap[Int, ConcurrentHashMap.KeySetView[BlockId, java.lang.Boolean]]()
+  private[this] val broadcastToBlockIds =
+    new ConcurrentHashMap[Long, ConcurrentHashMap.KeySetView[BlockId, java.lang.Boolean]]()
+  private[this] val sessionToBlockIds =
+    new ConcurrentHashMap[String, ConcurrentHashMap.KeySetView[BlockId, java.lang.Boolean]]()
+
   /**
    * Record invisible rdd blocks stored in the block manager, entries will be removed when blocks
    * are marked as visible or blocks are removed by [[removeBlock()]].
@@ -445,6 +453,7 @@ private[storage] class BlockInfoManager(trackingCacheVisibility: Boolean = false
     }
 
     if (previous == null) {
+      addToMapping(blockId)
       // New block lock it for writing.
       val result = lockForWriting(blockId, blocking = false)
       assert(result.isDefined)
@@ -535,6 +544,23 @@ private[storage] class BlockInfoManager(trackingCacheVisibility: Boolean = false
     blockInfoWrappers.entrySet().iterator().asScala.map(kv => kv.getKey -> kv.getValue.info)
   }
 
+  /**
+   * Return all blocks belonging to the given RDD.
+   */
+  def rddBlockIds(rddId: Int): Seq[BlockId] = getBlockIdsFromMapping(rddToBlockIds, rddId)
+
+  /**
+   * Return all blocks belonging to the given broadcast.
+   */
+  def broadcastBlockIds(broadcastId: Long): Seq[BlockId] =
+    getBlockIdsFromMapping(broadcastToBlockIds, broadcastId)
+
+  /**
+   * Return cache blocks that might be related to cached local relations.
+   */
+  def sessionBlockIds(sessionUUID: String): Seq[BlockId] =
+    getBlockIdsFromMapping(sessionToBlockIds, sessionUUID)
+
   /**
    * Removes the given block and releases the write lock on it.
    *
@@ -551,6 +577,7 @@ private[storage] class BlockInfoManager(trackingCacheVisibility: Boolean = false
       } else {
         invisibleRDDBlocks.synchronized {
           blockInfoWrappers.remove(blockId)
+          removeFromMapping(blockId)
           blockId.asRDDId.foreach(invisibleRDDBlocks.remove)
         }
         info.readerCount = 0
@@ -573,11 +600,75 @@ private[storage] class BlockInfoManager(trackingCacheVisibility: Boolean = false
       }
     }
     blockInfoWrappers.clear()
+    rddToBlockIds.clear()
+    broadcastToBlockIds.clear()
+    sessionToBlockIds.clear()
     readLocksByTask.clear()
     writeLocksByTask.clear()
     invisibleRDDBlocks.synchronized {
       invisibleRDDBlocks.clear()
     }
   }
 
+  /**
+   * Return all blocks in the cache mapping for a given key.
+   */
+  private def getBlockIdsFromMapping[K](
+      map: ConcurrentHashMap[K, ConcurrentHashMap.KeySetView[BlockId, java.lang.Boolean]],
+      key: K): Seq[BlockId] = {
+    Option(map.get(key)).map(_.asScala.toSeq).getOrElse(Seq.empty)
+  }
+
+  /**
+   * Add a block ID to the corresponding cache mapping based on its type.
+   */
+  private def addToMapping(blockId: BlockId): Unit = {
+    blockId match {
+      case rddBlockId: RDDBlockId =>
+        rddToBlockIds
+          .computeIfAbsent(rddBlockId.rddId, _ => ConcurrentHashMap.newKeySet())
+          .add(blockId)
+      case broadcastBlockId: BroadcastBlockId =>
+        broadcastToBlockIds
+          .computeIfAbsent(broadcastBlockId.broadcastId, _ => ConcurrentHashMap.newKeySet())
+          .add(blockId)
+      case cacheId: CacheId =>
+        sessionToBlockIds
+          .computeIfAbsent(cacheId.sessionUUID, _ => ConcurrentHashMap.newKeySet())
+          .add(blockId)
+      case _ => // Do nothing for other block types
+    }
+  }
+
+  /**
+   * Remove a block ID from the corresponding cache mapping based on its type.
+   */
+  private def removeFromMapping(blockId: BlockId): Unit = {
+    def doRemove[K](
+        map: ConcurrentHashMap[K, ConcurrentHashMap.KeySetView[BlockId, java.lang.Boolean]],
+        key: K,
+        block: BlockId): Unit = {
+      map.compute(key,
+        (_, set) => {
+          if (null != set) {
+            set.remove(block)
+            if (set.isEmpty) null else set
+          } else {
+            // missing
+            null
+          }
+        }
+      )
+    }
+
+    blockId match {
+      case rddBlockId: RDDBlockId =>
+        doRemove(rddToBlockIds, rddBlockId.rddId, rddBlockId)
+      case broadcastBlockId: BroadcastBlockId =>
+        doRemove(broadcastToBlockIds, broadcastBlockId.broadcastId, broadcastBlockId)
+      case cacheId: CacheId =>
+        doRemove(sessionToBlockIds, cacheId.sessionUUID, cacheId)
+      case _ => // Do nothing for other block types
+    }
+  }
 }
```
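One point worth flagging in `removeFromMapping` above: the empty-set pruning happens inside `ConcurrentHashMap.compute`, whose remapping function runs atomically for that key, so it cannot interleave with a concurrent `computeIfAbsent` on the same key. For contrast, a sketch of the naive check-then-act variant this avoids (`racyRemove` is a hypothetical helper, not code from the patch):

```scala
// NOT the patch's approach: get / remove as separate steps is racy.
// Another thread's computeIfAbsent may fetch the set between the
// isEmpty check and the map.remove call, add a fresh block to it,
// and then lose that block when the whole entry is dropped.
private def racyRemove(blockId: RDDBlockId): Unit = {
  val set = rddToBlockIds.get(blockId.rddId)
  if (set != null) {
    set.remove(blockId)
    if (set.isEmpty) {
      rddToBlockIds.remove(blockId.rddId) // race window
    }
  }
}
```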

core/src/main/scala/org/apache/spark/storage/BlockManager.scala

Lines changed: 3 additions & 8 deletions
```diff
@@ -2054,9 +2054,8 @@ private[spark] class BlockManager(
    * @return The number of blocks removed.
    */
   def removeRdd(rddId: Int): Int = {
-    // TODO: Avoid a linear scan by creating another mapping of RDD.id to blocks.
     logInfo(log"Removing RDD ${MDC(RDD_ID, rddId)}")
-    val blocksToRemove = blockInfoManager.entries.flatMap(_._1.asRDDId).filter(_.rddId == rddId)
+    val blocksToRemove = blockInfoManager.rddBlockIds(rddId)
     blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster = false) }
     blocksToRemove.size
   }
@@ -2090,9 +2089,7 @@ private[spark] class BlockManager(
    */
   def removeBroadcast(broadcastId: Long, tellMaster: Boolean): Int = {
     logDebug(s"Removing broadcast $broadcastId")
-    val blocksToRemove = blockInfoManager.entries.map(_._1).collect {
-      case bid @ BroadcastBlockId(`broadcastId`, _) => bid
-    }
+    val blocksToRemove = blockInfoManager.broadcastBlockIds(broadcastId)
     blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster) }
     blocksToRemove.size
   }
@@ -2104,9 +2101,7 @@ private[spark] class BlockManager(
    */
   def removeCache(sessionUUID: String): Int = {
     logDebug(s"Removing cache of spark session with UUID: $sessionUUID")
-    val blocksToRemove = blockInfoManager.entries.map(_._1).collect {
-      case cid: CacheId if cid.sessionUUID == sessionUUID => cid
-    }
+    val blocksToRemove = blockInfoManager.sessionBlockIds(sessionUUID)
     blocksToRemove.foreach { blockId => removeBlock(blockId) }
     blocksToRemove.size
   }
```
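With the three call sites switched over, each remove resolves its targets with a keyed lookup instead of iterating `blockInfoManager.entries`. A sketch of the user-visible path that exercises `removeRdd` (assumes a live `SparkContext` named `sc`; `unpersist()` is the public API whose executor-side handling lands in `removeRdd`):

```scala
// Driver side: unpersist() asks the BlockManagerMaster to drop the RDD's
// blocks; each executor's BlockManager.removeRdd now finds them through
// the cached mapping rather than a full scan of its block registry.
val cached = sc.parallelize(1 to 1000000).cache()
cached.count()                     // materializes rdd_<id>_<partition> blocks
cached.unpersist(blocking = true)  // cleanup cost now scales with this RDD's blocks
```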

core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala

Lines changed: 22 additions & 0 deletions
```diff
@@ -2634,6 +2634,28 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with PrivateMethodTe
     assert(logBlockIds.contains(logBlockId1) && logBlockIds.contains(logBlockId2))
   }
 
+  test("SPARK-53446: Optimize BlockManager remove operations with cached block mappings") {
+    val store = makeBlockManager(8000, "executor1")
+    val broadcastId = 0
+    val rddId = 1
+    val sessionId = UUID.randomUUID.toString
+    val data = new Array[Byte](100)
+
+    store.putSingle(BroadcastBlockId(broadcastId), data, StorageLevel.MEMORY_ONLY)
+    assert(store.blockInfoManager.broadcastBlockIds(broadcastId).nonEmpty)
+    store.putSingle(rdd(rddId, 3), data, StorageLevel.MEMORY_ONLY)
+    assert(store.blockInfoManager.rddBlockIds(rddId).nonEmpty)
+    store.putSingle(CacheId(sessionId, "abc"), data, StorageLevel.MEMORY_ONLY)
+    assert(store.blockInfoManager.sessionBlockIds(sessionId).nonEmpty)
+
+    store.removeBroadcast(broadcastId, false)
+    assert(store.blockInfoManager.broadcastBlockIds(broadcastId).isEmpty)
+    store.removeRdd(rddId)
+    assert(store.blockInfoManager.rddBlockIds(rddId).isEmpty)
+    store.removeCache(sessionId)
+    assert(store.blockInfoManager.sessionBlockIds(sessionId).isEmpty)
+  }
+
   private def createKryoSerializerWithDiskCorruptedInputStream(): KryoSerializer = {
     class TestDiskCorruptedInputStream extends InputStream {
       override def read(): Int = throw new IOException("Input/output error")
```
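To run just the new test locally, the usual Spark SBT workflow should apply (a standard dev invocation, not part of the patch; `-z` is ScalaTest's substring test filter):

```
build/sbt "core/testOnly *BlockManagerSuite -- -z SPARK-53446"
```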
