diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala
index 69250e2475732..8cde8759e0597 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala
@@ -93,7 +93,51 @@ private[storage] class BlockInfo(
   checkInvariants()
 }
 
-private class BlockInfoWrapper(
+/**
+ * Group of blocks that share a common trait (e.g. the same broadcast or the same RDD). If there
+ * is no obvious grouping, a block gets its own unique group. Grouping blocks lets us perform
+ * group-level operations, which is especially useful when a whole group has to be cleaned up.
+ */
+private[storage] abstract class BlockInfoGroup {
+  /**
+   * Get all the block-to-info entries contained by this group.
+   *
+   * This method is thread safe. However, when you want to modify a [[BlockInfo]] instance, you
+   * need to acquire the lock for the block it maps to.
+   */
+  def infos: Seq[(BlockId, BlockInfoWrapper)]
+
+  /**
+   * Get the number of blocks inside this group.
+   *
+   * This method is thread safe.
+   */
+  def size: Int
+
+  /**
+   * Get the (wrapped) [[BlockInfo]] associated with `blockId`, if any.
+   */
+  def get(blockId: BlockId): Option[BlockInfoWrapper]
+
+  /**
+   * Clear all entries from this group.
+   */
+  def clear(): Unit
+
+  /**
+   * Associate `blockInfo` with `blockId` if no mapping exists. This method returns `true` if the
+   * mapping was added, `false` otherwise.
+   */
+  def putIfAbsent(blockId: BlockId, blockInfo: BlockInfoWrapper): Boolean
+
+  /**
+   * Remove `blockId` from this group. This method returns `true` when the group is empty after
+   * removal, `false` otherwise.
+   */
+  def remove(blockId: BlockId): Boolean
+}
+
+private[storage] class BlockInfoWrapper(
   val info: BlockInfo,
   private val lock: Lock,
   private val condition: Condition) {
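To make the contract concrete, here is a minimal sketch of how an implementation behaves, using the `Singleton` variant defined further down. It is illustrative only: `BlockInfo` and `BlockInfoWrapper` are `private[storage]`, so this only compiles inside `org.apache.spark.storage`, and `TestBlockId` is the catch-all id from `BlockId.scala`.

```scala
import java.util.concurrent.locks.ReentrantLock
import scala.reflect.classTag

// Build a wrapper the same way lockNewBlockForWriting does.
val lock = new ReentrantLock()
def newWrapper(): BlockInfoWrapper = new BlockInfoWrapper(
  new BlockInfo(StorageLevel.MEMORY_ONLY, classTag[Array[Byte]], tellMaster = false),
  lock, lock.newCondition())

val group = new BlockInfoGroup.Singleton
assert(group.putIfAbsent(TestBlockId("a"), newWrapper()))   // empty slot: added
assert(!group.putIfAbsent(TestBlockId("b"), newWrapper()))  // occupied: rejected
assert(!group.remove(TestBlockId("b")))  // wrong id: nothing removed, still non-empty
assert(group.remove(TestBlockId("a")))   // now empty: caller may drop the group
assert(group.size == 0)
```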
@@ -115,6 +159,80 @@ private class BlockInfoWrapper(
   }
 }
 
+object BlockInfoGroup {
+  /**
+   * Group of zero or one blocks, for blocks that have no natural grouping. The fields are
+   * volatile because [[BlockInfoManager.entries]] and `size` may read them without holding the
+   * block's lock.
+   */
+  class Singleton extends BlockInfoGroup {
+    @volatile private var blockId: BlockId = _
+    @volatile private var info: BlockInfoWrapper = _
+    override def size: Int = {
+      if (blockId != null) 1
+      else 0
+    }
+    override def infos: Seq[(BlockId, BlockInfoWrapper)] = {
+      if (blockId != null) blockId -> info :: Nil
+      else Nil
+    }
+    override def clear(): Unit = {
+      blockId = null
+      info = null
+    }
+    override def get(blockIdToGet: BlockId): Option[BlockInfoWrapper] = {
+      if (blockId == blockIdToGet) Option(info)
+      else None
+    }
+    override def putIfAbsent(blockIdToAdd: BlockId, infoToAdd: BlockInfoWrapper): Boolean = {
+      if (blockId == null) {
+        blockId = blockIdToAdd
+        info = infoToAdd
+        true
+      } else {
+        false
+      }
+    }
+    override def remove(blockIdToRemove: BlockId): Boolean = {
+      if (blockId == blockIdToRemove) {
+        blockId = null
+        info = null
+      }
+      blockId == null
+    }
+  }
+
+  /**
+   * Group of zero or more blocks, for blocks that have a natural grouping.
+   */
+  class Collection extends BlockInfoGroup {
+    private val infoMap = new ConcurrentHashMap[BlockId, BlockInfoWrapper]()
+
+    override def infos: Seq[(BlockId, BlockInfoWrapper)] = {
+      infoMap.asScala.toSeq
+    }
+
+    override def size: Int = infoMap.size()
+
+    override def get(blockId: BlockId): Option[BlockInfoWrapper] = {
+      Option(infoMap.get(blockId))
+    }
+
+    override def clear(): Unit = {
+      infoMap.clear()
+    }
+
+    override def putIfAbsent(blockId: BlockId, blockInfo: BlockInfoWrapper): Boolean = {
+      infoMap.putIfAbsent(blockId, blockInfo) == null
+    }
+
+    override def remove(blockId: BlockId): Boolean = {
+      infoMap.remove(blockId)
+      infoMap.isEmpty
+    }
+  }
+}
+
 private[storage] object BlockInfo {
 
   /**
@@ -147,7 +263,7 @@ private[storage] class BlockInfoManager(trackingCacheVisibility: Boolean = false
    * set-if-not-exists operation ([[lockNewBlockForWriting()]]) and are removed
    * by [[removeBlock()]].
    */
-  private[this] val blockInfoWrappers = new ConcurrentHashMap[BlockId, BlockInfoWrapper]
+  private[this] val blockInfoGroups = new ConcurrentHashMap[BlockId, BlockInfoGroup]
 
   // Cache mappings to avoid O(n) scans in remove operations.
   private[this] val rddToBlockIds =
@@ -179,6 +295,20 @@ private[storage] class BlockInfoManager(trackingCacheVisibility: Boolean = false
    */
   private[this] val writeLocksByTask = new ConcurrentHashMap[TaskAttemptId, util.Set[BlockId]]
 
+  /** Get the group id that `blockId` belongs to. */
+  protected def getGroupId(blockId: BlockId): BlockId = blockId match {
+    case BroadcastBlockId(broadcastId, _) => BroadcastBlockId(broadcastId, "group")
+    case RDDBlockId(rddId, _) => RDDBlockId(rddId, -1)
+    case CacheId(sessionUUID, _) => CacheId(sessionUUID, "group")
+    case _ => blockId
+  }
+
+  /** Create the [[BlockInfoGroup]] implementation that matches `groupId`. */
+  def createBlockInfoGroup(groupId: BlockId): BlockInfoGroup = groupId match {
+    case _: BroadcastBlockId | _: RDDBlockId | _: CacheId => new BlockInfoGroup.Collection
+    case _ => new BlockInfoGroup.Singleton
+  }
+
   /**
    * Tracks the set of blocks that each task has locked for reading, along with the number of times
    * that a block has been locked (since our read locks are re-entrant).
@@ -203,7 +332,10 @@ private[storage] class BlockInfoManager(trackingCacheVisibility: Boolean = false
   private[spark] def isRDDBlockVisible(blockId: RDDBlockId): Boolean = {
     if (trackingCacheVisibility) {
       invisibleRDDBlocks.synchronized {
-        blockInfoWrappers.containsKey(blockId) && !invisibleRDDBlocks.contains(blockId)
+        val groupId = getGroupId(blockId)
+        val group = blockInfoGroups.get(groupId)
+        val blockExists = Option(group).exists(_.get(blockId).isDefined)
+        blockExists && !invisibleRDDBlocks.contains(blockId)
       }
     } else {
       // Always be visible if the feature flag is disabled.
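For reference, a sketch of how the grouping scheme plays out for concrete ids. The sentinel values `-1` and `"group"` are exactly the ones `getGroupId` uses; since `getGroupId` is `protected`, imagine this running inside a `BlockInfoManager` subclass or a test in this package.

```scala
// All partitions of an RDD share one group key (sentinel split index -1):
getGroupId(RDDBlockId(42, 0))              // => RDDBlockId(42, -1)
getGroupId(RDDBlockId(42, 7))              // => RDDBlockId(42, -1)

// All pieces of a broadcast share one group key (sentinel field "group"):
getGroupId(BroadcastBlockId(5))            // => BroadcastBlockId(5, "group")
getGroupId(BroadcastBlockId(5, "piece0"))  // => BroadcastBlockId(5, "group")

// Ungrouped blocks form their own group, backed by a Singleton:
getGroupId(ShuffleBlockId(1, 2, 3))        // => ShuffleBlockId(1, 2, 3)
assert(createBlockInfoGroup(ShuffleBlockId(1, 2, 3)).isInstanceOf[BlockInfoGroup.Singleton])
```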
@@ -245,20 +377,26 @@ private[storage] class BlockInfoManager(trackingCacheVisibility: Boolean = false
       f: BlockInfo => Boolean): Option[BlockInfo] = {
     var done = false
     var result: Option[BlockInfo] = None
+    val groupId = getGroupId(blockId)
     while (!done) {
-      val wrapper = blockInfoWrappers.get(blockId)
-      if (wrapper == null) {
+      val group = blockInfoGroups.get(groupId)
+      if (group == null) {
         done = true
       } else {
-        wrapper.withLock { (info, condition) =>
-          if (f(info)) {
-            result = Some(info)
-            done = true
-          } else if (!blocking) {
+        group.get(blockId) match {
+          case Some(wrapper) =>
+            wrapper.withLock { (info, condition) =>
+              if (f(info)) {
+                result = Some(info)
+                done = true
+              } else if (!blocking) {
+                done = true
+              } else {
+                condition.await()
+              }
+            }
+          case None =>
             done = true
-          } else {
-            condition.await()
-          }
         }
       }
     }
@@ -271,11 +409,14 @@ private[storage] class BlockInfoManager(trackingCacheVisibility: Boolean = false
    * was not registered, an error will be thrown.
    */
   private def blockInfo[T](blockId: BlockId)(f: (BlockInfo, Condition) => T): T = {
-    val wrapper = blockInfoWrappers.get(blockId)
-    if (wrapper == null) {
+    val group = blockInfoGroups.get(getGroupId(blockId))
+    if (group == null) {
       throw SparkCoreErrors.blockDoesNotExistError(blockId)
     }
-    wrapper.withLock(f)
+    group.get(blockId) match {
+      case Some(wrapper) => wrapper.withLock(f)
+      case None => throw SparkCoreErrors.blockDoesNotExistError(blockId)
+    }
   }
 
   /**
@@ -360,12 +501,11 @@ private[storage] class BlockInfoManager(trackingCacheVisibility: Boolean = false
    * [[BlockManager.getStatus()]] and should not be called by other code outside of this class.
    */
   private[storage] def get(blockId: BlockId): Option[BlockInfo] = {
-    val wrapper = blockInfoWrappers.get(blockId)
-    if (wrapper != null) {
-      Some(wrapper.info)
-    } else {
-      None
+    val group = blockInfoGroups.get(getGroupId(blockId))
+    if (group == null) {
+      return None
     }
+    group.get(blockId).map(_.info)
   }
 
   /**
@@ -438,25 +578,27 @@ private[storage] class BlockInfoManager(trackingCacheVisibility: Boolean = false
     // duration of this operation. This way we prevent race conditions when two threads try to write
     // the same block at the same time.
     val lock = locks.get(blockId)
+    val groupId = getGroupId(blockId)
     lock.lock()
     try {
       val wrapper = new BlockInfoWrapper(newBlockInfo, lock)
       while (true) {
-        val previous = if (trackingCacheVisibility) {
-          invisibleRDDBlocks.synchronized {
-            val res = blockInfoWrappers.putIfAbsent(blockId, wrapper)
-            if (res == null) {
+        // Insert through compute() so that a concurrent removeBlock() of another
+        // block in the same group cannot drop the group between our lookup and
+        // our insertion.
+        var inserted = false
+        blockInfoGroups.compute(groupId, (_, existing) => {
+          val group = if (existing == null) createBlockInfoGroup(groupId) else existing
+          inserted = group.putIfAbsent(blockId, wrapper)
+          group
+        })
+        if (inserted) {
+          if (trackingCacheVisibility) {
+            invisibleRDDBlocks.synchronized {
               // Added to invisible blocks if it doesn't exist before.
               blockId.asRDDId.foreach(invisibleRDDBlocks.add)
             }
-            res
           }
-        } else {
-          blockInfoWrappers.putIfAbsent(blockId, wrapper)
-        }
-
-        if (previous == null) {
-          addToMapping(blockId)
           // New block lock it for writing.
           val result = lockForWriting(blockId, blocking = false)
           assert(result.isDefined)
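A sketch of how callers are expected to drive `lockNewBlockForWriting`, mirroring what `BlockManager` does when materializing a block. `level`, `classTag`, `tellMaster`, and the write itself are assumed bound; error handling is elided.

```scala
val info = new BlockInfo(level, classTag, tellMaster)
if (blockInfoManager.lockNewBlockForWriting(blockId, info)) {
  // We won the race: the block is now registered in its group and this task
  // holds the write lock.
  try {
    // ... materialize the block ...
  } finally {
    blockInfoManager.unlock(blockId)
  }
} else {
  // Another writer registered the block first; per the method's existing
  // contract, this task now holds a read lock on the winner's block instead.
}
```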
@@ -479,6 +615,17 @@ private[storage] class BlockInfoManager(trackingCacheVisibility: Boolean = false
     }
   }
 
+  /**
+   * Get all the block ids currently tracked for the group that `blockId` belongs to.
+   */
+  def getBlockIdsForGroup(blockId: BlockId): Seq[BlockId] = {
+    val group = blockInfoGroups.get(getGroupId(blockId))
+    if (group == null) {
+      return Nil
+    }
+    group.infos.map(_._1)
+  }
+
   /**
    * Release all lock held by the given task, clearing that task's pin bookkeeping
    * structures and updating the global pin counts. This method should be called at the
@@ -534,7 +681,7 @@ private[storage] class BlockInfoManager(trackingCacheVisibility: Boolean = false
   /**
    * Returns the number of blocks tracked.
    */
-  def size: Int = blockInfoWrappers.size
+  def size: Int = blockInfoGroups.values().asScala.iterator.map(_.size).sum
 
   /**
    * Return the number of map entries in this pin counter's internal data structures.
@@ -554,26 +701,9 @@ private[storage] class BlockInfoManager(trackingCacheVisibility: Boolean = false
    * is being traversed.
    */
   def entries: Iterator[(BlockId, BlockInfo)] = {
-    blockInfoWrappers.entrySet().iterator().asScala.map(kv => kv.getKey -> kv.getValue.info)
+    blockInfoGroups.values().asScala.flatMap(_.infos.map(kv => kv._1 -> kv._2.info)).iterator
   }
 
-  /**
-   * Return all blocks belonging to the given RDD.
-   */
-  def rddBlockIds(rddId: Int): Seq[BlockId] = getBlockIdsFromMapping(rddToBlockIds, rddId)
-
-  /**
-   * Return all blocks belonging to the given broadcast.
-   */
-  def broadcastBlockIds(broadcastId: Long): Seq[BlockId] =
-    getBlockIdsFromMapping(broadcastToBlockIds, broadcastId)
-
-  /**
-   * Return cache blocks that might be related to cached local relations.
-   */
-  def sessionBlockIds(sessionUUID: String): Seq[BlockId] =
-    getBlockIdsFromMapping(sessionToBlockIds, sessionUUID)
-
   /**
    * Removes the given block and releases the write lock on it.
    *
@@ -582,6 +712,11 @@ private[storage] class BlockInfoManager(trackingCacheVisibility: Boolean = false
   def removeBlock(blockId: BlockId): Unit = {
     val taskAttemptId = currentTaskAttemptId
     logTrace(s"Task $taskAttemptId trying to remove block $blockId")
+    val groupId = getGroupId(blockId)
+    val group = blockInfoGroups.get(groupId)
+    if (group == null) {
+      throw SparkCoreErrors.blockDoesNotExistError(blockId)
+    }
     blockInfo(blockId) { (info, condition) =>
       if (info.writerTask != taskAttemptId) {
         throw SparkException.internalError(
@@ -589,10 +724,12 @@ private[storage] class BlockInfoManager(trackingCacheVisibility: Boolean = false
           category = "STORAGE")
       } else {
         invisibleRDDBlocks.synchronized {
-          blockInfoWrappers.remove(blockId)
-          removeFromMapping(blockId)
           blockId.asRDDId.foreach(invisibleRDDBlocks.remove)
         }
+        // Atomically drop the group entry when its last block is removed, so a
+        // concurrent insert into the same group cannot be lost.
+        blockInfoGroups.computeIfPresent(
+          groupId, (_, g) => if (g.remove(blockId)) null else g)
         info.readerCount = 0
         info.writerTask = BlockInfo.NO_WRITER
         writeLocksByTask.get(taskAttemptId).remove(blockId)
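The group bookkeeping around removal, as a sketch. It assumes the current task is registered with the hypothetical `manager` and, per `removeBlock`'s contract, holds the write lock on each block it removes; `info0`/`info1` are hypothetical `BlockInfo`s.

```scala
// Two cached partitions of RDD 1 end up in the same Collection group.
manager.lockNewBlockForWriting(RDDBlockId(1, 0), info0)
manager.lockNewBlockForWriting(RDDBlockId(1, 1), info1)
assert(manager.getBlockIdsForGroup(RDDBlockId(1, 0)).size == 2)

manager.removeBlock(RDDBlockId(1, 0))   // group still has one member
assert(manager.getBlockIdsForGroup(RDDBlockId(1, 1)).size == 1)

manager.removeBlock(RDDBlockId(1, 1))   // last member: the group entry is dropped
assert(manager.getBlockIdsForGroup(RDDBlockId(1, 1)).isEmpty)
```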
@@ -605,14 +741,18 @@ private[storage] class BlockInfoManager(trackingCacheVisibility: Boolean = false
    * Delete all state. Called during shutdown.
    */
   def clear(): Unit = {
-    blockInfoWrappers.values().forEach { wrapper =>
-      wrapper.tryLock { (info, condition) =>
-        info.readerCount = 0
-        info.writerTask = BlockInfo.NO_WRITER
-        condition.signalAll()
+    blockInfoGroups.values().forEach { group =>
+      group.infos.foreach {
+        case (_, wrapper) =>
+          wrapper.tryLock { (info, condition) =>
+            info.readerCount = 0
+            info.writerTask = BlockInfo.NO_WRITER
+            condition.signalAll()
+          }
       }
+      group.clear()
     }
-    blockInfoWrappers.clear()
+    blockInfoGroups.clear()
     rddToBlockIds.clear()
     broadcastToBlockIds.clear()
     sessionToBlockIds.clear()
@@ -622,66 +762,4 @@ private[storage] class BlockInfoManager(trackingCacheVisibility: Boolean = false
       invisibleRDDBlocks.clear()
     }
   }
-
-  /**
-   * Return all blocks in the cache mapping for a given key.
-   */
-  private def getBlockIdsFromMapping[K](
-      map: ConcurrentHashMap[K, ConcurrentHashMap.KeySetView[BlockId, java.lang.Boolean]],
-      key: K): Seq[BlockId] = {
-    Option(map.get(key)).map(_.asScala.toSeq).getOrElse(Seq.empty)
-  }
-
-  /**
-   * Add a block ID to the corresponding cache mapping based on its type.
-   */
-  private def addToMapping(blockId: BlockId): Unit = {
-    blockId match {
-      case rddBlockId: RDDBlockId =>
-        rddToBlockIds
-          .computeIfAbsent(rddBlockId.rddId, _ => ConcurrentHashMap.newKeySet())
-          .add(blockId)
-      case broadcastBlockId: BroadcastBlockId =>
-        broadcastToBlockIds
-          .computeIfAbsent(broadcastBlockId.broadcastId, _ => ConcurrentHashMap.newKeySet())
-          .add(blockId)
-      case cacheId: CacheId =>
-        sessionToBlockIds
-          .computeIfAbsent(cacheId.sessionUUID, _ => ConcurrentHashMap.newKeySet())
-          .add(blockId)
-      case _ => // Do nothing for other block types
-    }
-  }
-
-  /**
-   * Remove a block ID from the corresponding cache mapping based on its type.
-   */
-  private def removeFromMapping(blockId: BlockId): Unit = {
-    def doRemove[K](
-        map: ConcurrentHashMap[K, ConcurrentHashMap.KeySetView[BlockId, java.lang.Boolean]],
-        key: K,
-        block: BlockId): Unit = {
-      map.compute(key,
-        (_, set) => {
-          if (null != set) {
-            set.remove(block)
-            if (set.isEmpty) null else set
-          } else {
-            // missing
-            null
-          }
-        }
-      )
-    }
-
-    blockId match {
-      case rddBlockId: RDDBlockId =>
-        doRemove(rddToBlockIds, rddBlockId.rddId, rddBlockId)
-      case broadcastBlockId: BroadcastBlockId =>
-        doRemove(broadcastToBlockIds, broadcastBlockId.broadcastId, broadcastBlockId)
-      case cacheId: CacheId =>
-        doRemove(sessionToBlockIds, cacheId.sessionUUID, cacheId)
-      case _ => // Do nothing for other block types
-    }
-  }
 }
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 0d675a3abd120..315fa8081bb60 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -2048,6 +2048,15 @@ private[spark] class BlockManager(
     status.storageLevel
   }
 
+  /**
+   * Remove all blocks belonging to the same group as `groupId`.
+   */
+  private[spark] def removeBlockGroup(groupId: BlockId, tellMaster: Boolean): Int = {
+    val blocksToRemove = blockInfoManager.getBlockIdsForGroup(groupId)
+    blocksToRemove.foreach(blockId => removeBlock(blockId, tellMaster))
+    blocksToRemove.size
+  }
+
   /**
    * Remove all blocks belonging to the given RDD.
    *
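The three public removal paths below all funnel into `removeBlockGroup`. For orientation, the group keys they resolve to (a sketch; `bm` is a hypothetical `BlockManager`):

```scala
bm.removeRdd(42)                          // group key: RDDBlockId(42, -1)
bm.removeBroadcast(5, tellMaster = true)  // group key: BroadcastBlockId(5, "group")
bm.removeCache(sessionUUID)               // group key: CacheId(sessionUUID, "group")
```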
@@ -2055,9 +2064,7 @@ private[spark] class BlockManager(
    */
   def removeRdd(rddId: Int): Int = {
     logInfo(log"Removing RDD ${MDC(RDD_ID, rddId)}")
-    val blocksToRemove = blockInfoManager.rddBlockIds(rddId)
-    blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster = false) }
-    blocksToRemove.size
+    removeBlockGroup(RDDBlockId(rddId, -1), tellMaster = false)
   }
 
   def decommissionBlockManager(): Unit = storageEndpoint.ask(DecommissionBlockManager)
@@ -2089,9 +2096,7 @@ private[spark] class BlockManager(
    */
   def removeBroadcast(broadcastId: Long, tellMaster: Boolean): Int = {
     logDebug(s"Removing broadcast $broadcastId")
-    val blocksToRemove = blockInfoManager.broadcastBlockIds(broadcastId)
-    blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster) }
-    blocksToRemove.size
+    removeBlockGroup(BroadcastBlockId(broadcastId), tellMaster)
   }
 
   /**
@@ -2101,9 +2106,7 @@ private[spark] class BlockManager(
    */
   def removeCache(sessionUUID: String): Int = {
     logDebug(s"Removing cache of spark session with UUID: $sessionUUID")
-    val blocksToRemove = blockInfoManager.sessionBlockIds(sessionUUID)
-    blocksToRemove.foreach { blockId => removeBlock(blockId) }
-    blocksToRemove.size
+    removeBlockGroup(CacheId(sessionUUID, "group"), tellMaster = true)
   }
 
   /**
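Because `getBlockIdsForGroup` normalizes its argument through `getGroupId`, `removeBlockGroup` accepts any member id of a group, not just the sentinel key; a sketch (hypothetical `bm` again):

```scala
// Drop every piece of broadcast 5 tracked on this executor; the sentinel
// group key BroadcastBlockId(5, "group") is derived internally.
val removed = bm.removeBlockGroup(BroadcastBlockId(5, "piece0"), tellMaster = false)
```

Since `removeBlockGroup` is `private[spark]`, only internal callers such as the three wrappers above can use it directly.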
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index f4bb5b7cf7cb8..597b2942841cd 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -2634,26 +2634,27 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with PrivateMethodTe
     assert(logBlockIds.contains(logBlockId1) && logBlockIds.contains(logBlockId2))
   }
 
-  test("SPARK-53446: Optimize BlockManager remove operations with cached block mappings") {
+  test(
+    "SPARK-53446/SPARK-54947: Optimize BlockManager remove operations with cached block mappings") {
     val store = makeBlockManager(8000, "executor1")
-    val broadcastId = 0
-    val rddId = 1
-    val sessionId = UUID.randomUUID.toString
+    val broadcastBlockId = BroadcastBlockId(0)
+    val rddBlockId = RDDBlockId(1, 3)
+    val cacheBlockId = CacheId(UUID.randomUUID.toString, "abc")
     val data = new Array[Byte](100)
-    store.putSingle(BroadcastBlockId(broadcastId), data, StorageLevel.MEMORY_ONLY)
-    assert(store.blockInfoManager.broadcastBlockIds(broadcastId).nonEmpty)
-    store.putSingle(rdd(rddId, 3), data, StorageLevel.MEMORY_ONLY)
-    assert(store.blockInfoManager.rddBlockIds(rddId).nonEmpty)
-    store.putSingle(CacheId(sessionId, "abc"), data, StorageLevel.MEMORY_ONLY)
-    assert(store.blockInfoManager.sessionBlockIds(sessionId).nonEmpty)
-
-    store.removeBroadcast(broadcastId, false)
-    assert(store.blockInfoManager.broadcastBlockIds(broadcastId).isEmpty)
-    store.removeRdd(rddId)
-    assert(store.blockInfoManager.rddBlockIds(rddId).isEmpty)
-    store.removeCache(sessionId)
-    assert(store.blockInfoManager.sessionBlockIds(sessionId).isEmpty)
+    store.putSingle(broadcastBlockId, data, StorageLevel.MEMORY_ONLY)
+    assert(store.blockInfoManager.getBlockIdsForGroup(broadcastBlockId).nonEmpty)
+    store.putSingle(rddBlockId, data, StorageLevel.MEMORY_ONLY)
+    assert(store.blockInfoManager.getBlockIdsForGroup(rddBlockId).nonEmpty)
+    store.putSingle(cacheBlockId, data, StorageLevel.MEMORY_ONLY)
+    assert(store.blockInfoManager.getBlockIdsForGroup(cacheBlockId).nonEmpty)
+
+    store.removeBroadcast(broadcastBlockId.broadcastId, tellMaster = false)
+    assert(store.blockInfoManager.getBlockIdsForGroup(broadcastBlockId).isEmpty)
+    store.removeRdd(rddBlockId.rddId)
+    assert(store.blockInfoManager.getBlockIdsForGroup(rddBlockId).isEmpty)
+    store.removeCache(cacheBlockId.sessionUUID)
+    assert(store.blockInfoManager.getBlockIdsForGroup(cacheBlockId).isEmpty)
   }
 
   private def createKryoSerializerWithDiskCorruptedInputStream(): KryoSerializer = {
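A natural follow-on assertion for this test (hypothetical, not part of this patch) is that removing one group leaves unrelated groups untouched:

```scala
store.putSingle(RDDBlockId(2, 0), data, StorageLevel.MEMORY_ONLY)
store.removeRdd(1)  // removes only RDD 1's group
assert(store.blockInfoManager.getBlockIdsForGroup(RDDBlockId(2, 0)).nonEmpty)
```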