@@ -1556,7 +1556,7 @@ void WindowBlockManager::allocateBlock(GenerationRequest& sequence, bool shareAm
15561556 }
15571557}
15581558
1559- std::pair<SizeType32, std::optional <KVCacheBlock::IdType>> WindowBlockManager::storeBlocks (
1559+ std::pair<SizeType32, std::vector <KVCacheBlock::IdType>> WindowBlockManager::storeBlocks (
15601560 std::vector<BlockKey> const & blockKeys, std::vector<KVCacheBlock::IdType> const & blockIds, bool pinBlocks)
15611561{
15621562 SizeType32 numBlocksStoredForReuse = 0 ;
@@ -1569,7 +1569,7 @@ std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> WindowBlockManager::s
15691569
15701570 auto numBlocks = blockKeys.size ();
15711571 std::vector<BlockPtr> storedBlocks;
1572- std::optional <KVCacheBlock::IdType> lastStoredId = std:: nullopt ;
1572+ std::vector <KVCacheBlock::IdType> pinnedBlockIds ;
15731573 for (std::size_t blockCnt = 0 ; blockCnt < numBlocks; ++blockCnt)
15741574 {
15751575 auto const bid = blockIds[blockCnt];
@@ -1620,14 +1620,14 @@ std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> WindowBlockManager::s
16201620 if (pinBlocks)
16211621 {
16221622 searchRoot->incRefCount ();
1623+ pinnedBlockIds.push_back (searchRoot->getBlockId ());
16231624 }
1624- lastStoredId = searchRoot->getBlockId ();
16251625 }
16261626 if (mEventManager )
16271627 {
16281628 mEventManager ->enqueueStoredEvent (storedBlocks, mWindowSize );
16291629 }
1630- return {numBlocksStoredForReuse, lastStoredId };
1630+ return {numBlocksStoredForReuse, pinnedBlockIds };
16311631}
16321632
16331633void BlockManager::replaceSharedBlock (GenerationRequest& sequence, SizeType32 windowSize, SizeType32 blockIdx)
@@ -1715,15 +1715,15 @@ std::deque<tle::KVCacheEvent> BlockManager::getLatestEvents(std::optional<std::c
17151715 return mEventManager ? mEventManager ->getEvents (timeout) : std::deque<tle::KVCacheEvent>{};
17161716}
17171717
1718- std::optional <KVCacheBlock::IdType> BlockManager::storeBlocksForReuse (
1718+ std::vector <KVCacheBlock::IdType> BlockManager::storeBlocksForReuse (
17191719 GenerationRequest& sequence, OptionalRef<LlmRequest const > llmRequest, bool pinBlocks)
17201720{
1721- std::optional <KVCacheBlock::IdType> lastStoredId = std:: nullopt ;
1721+ std::vector <KVCacheBlock::IdType> pinnedBlockIds ;
17221722 for (auto & [_, manager] : mWindowBlockManagers )
17231723 {
1724- lastStoredId = manager.storeBlocksForReuse (sequence, llmRequest, pinBlocks);
1724+ pinnedBlockIds = manager.storeBlocksForReuse (sequence, llmRequest, pinBlocks);
17251725 }
1726- return lastStoredId ;
1726+ return pinnedBlockIds ;
17271727}
17281728
17291729std::optional<KVCacheBlock::IdType> BlockManager::releaseBlocks (
@@ -1767,15 +1767,15 @@ void BlockManager::pinBlocks(GenerationRequest& sequence)
17671767 }
17681768}
17691769
1770- void BlockManager::unpinBlocksById (KVCacheBlock::IdType blockId )
1770+ void BlockManager::unpinBlocksById (std::vector< KVCacheBlock::IdType> const & blockIds )
17711771{
17721772 // Use the first window size
17731773 if (mWindowBlockManagers .empty ())
17741774 {
17751775 return ;
17761776 }
17771777 auto & firstManager = mWindowBlockManagers .begin ()->second ;
1778- firstManager.unpinBlocksById (blockId );
1778+ firstManager.unpinBlocksById (blockIds );
17791779}
17801780
17811781void WindowBlockManager::pinBlocks (GenerationRequest& sequence)
@@ -1788,21 +1788,28 @@ void WindowBlockManager::pinBlocks(GenerationRequest& sequence)
17881788 }
17891789}
17901790
1791- void WindowBlockManager::unpinBlocksById (KVCacheBlock::IdType blockId )
1791+ void WindowBlockManager::unpinBlocksById (std::vector< KVCacheBlock::IdType> const & blockIds )
17921792{
1793- if (blockId < 0 || static_cast < size_t >(blockId) >= mAllBlocksById . size ())
1793+ if (blockIds. empty ())
17941794 {
17951795 return ;
17961796 }
1797- auto block = mAllBlocksById [blockId];
1798- while (block && block-> getBlockId () != KVCacheBlock:: kCachedBlocksRootId )
1797+
1798+ for ( auto const & blockId : blockIds )
17991799 {
1800- block->decRefCount ();
1801- if (!block->hasRefs ())
1800+ if (blockId < 0 || static_cast <size_t >(blockId) >= mAllBlocksById .size ())
18021801 {
1803- mEvictionPolicy ->releaseBlock (block);
1802+ continue ;
1803+ }
1804+ auto block = mAllBlocksById [blockId];
1805+ if (block && block->getBlockId () != KVCacheBlock::kCachedBlocksRootId )
1806+ {
1807+ block->decRefCount ();
1808+ if (!block->hasRefs ())
1809+ {
1810+ mEvictionPolicy ->releaseBlock (block);
1811+ }
18041812 }
1805- block = std::move (block->getPrevBlock ());
18061813 }
18071814}
18081815
@@ -1870,7 +1877,7 @@ void WindowBlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef<
18701877 (void ) storeBlocks (std::move (blockKeys), cacheBlockIds[beamIdx]);
18711878}
18721879
1873- std::optional <KVCacheBlock::IdType> WindowBlockManager::storeBlocksForReuse (
1880+ std::vector <KVCacheBlock::IdType> WindowBlockManager::storeBlocksForReuse (
18741881 GenerationRequest& sequence, OptionalRef<LlmRequest const > llmRequest, bool pinBlocks)
18751882{
18761883 auto constexpr beamIdx = 0 ;
@@ -1883,7 +1890,10 @@ std::optional<KVCacheBlock::IdType> WindowBlockManager::storeBlocksForReuse(
18831890 auto const usableSize = static_cast <runtime::SizeType32>(uniqueTokens.size ()) - 1 ;
18841891 auto blockedUniqueTokens = chopVectorIntoBlocks<UniqueToken>(uniqueTokens, usableSize, mTokensPerBlock , true );
18851892 auto blockKeys = buildBlockKeys (blockedUniqueTokens, *llmRequest);
1886- return storeBlocks (std::move (blockKeys), cacheBlockIds[beamIdx], pinBlocks).second ;
1893+
1894+ auto [numStored, pinnedBlockIds] = storeBlocks (std::move (blockKeys), cacheBlockIds[beamIdx], pinBlocks);
1895+
1896+ return pinnedBlockIds;
18871897}
18881898
18891899std::optional<KVCacheBlock::IdType> WindowBlockManager::releaseBlocks (
@@ -1922,7 +1932,7 @@ std::optional<KVCacheBlock::IdType> WindowBlockManager::releaseBlocks(
19221932 std::transform (allocatedBlocks.begin (), allocatedBlocks.end (), cacheBlockIds.begin (),
19231933 [](BlockPtr const & block) { return block->getBlockId (); });
19241934
1925- auto [numBlocksStoredForReuse, lastStoredId ] = storeBlocks (std::move (blockKeys), cacheBlockIds);
1935+ auto [numBlocksStoredForReuse, pinnedBlockIds ] = storeBlocks (std::move (blockKeys), cacheBlockIds);
19261936 TLLM_LOG_DEBUG (" %s::releaseBlocks Request %lu, %d blocks stored for reuse" , mLogPrefix .c_str (),
19271937 sequence.getRequestId (), numBlocksStoredForReuse);
19281938 }
@@ -2499,15 +2509,14 @@ std::optional<KVCacheBlock::IdType> KVCacheManager::removeSequence(
24992509 return lastStoredId;
25002510}
25012511
2502- std::optional <KVCacheBlock::IdType> KVCacheManager::storeBlocksForReuse (
2512+ std::vector <KVCacheBlock::IdType> KVCacheManager::storeBlocksForReuse (
25032513 RequestIdType requestId, OptionalRef<LlmRequest const > llmRequest, bool pinBlocks)
25042514{
25052515 TLLM_LOG_TRACE (" [%s]::%s start" , isCrossKv () ? " CROSS" : " SELF" , __PRETTY_FUNCTION__);
25062516 auto & sequence = getSequence (requestId);
2507- std::optional<KVCacheBlock::IdType> lastStoredId
2508- = mBlockManager .storeBlocksForReuse (sequence, llmRequest, pinBlocks);
2517+ auto pinnedBlockIds = mBlockManager .storeBlocksForReuse (sequence, llmRequest, pinBlocks);
25092518 TLLM_LOG_TRACE (" [%s]::%s stop" , isCrossKv () ? " CROSS" : " SELF" , __PRETTY_FUNCTION__);
2510- return lastStoredId ;
2519+ return pinnedBlockIds ;
25112520}
25122521
25132522void KVCacheManager::schedulingRemoveSequence (RequestIdType requestId)
@@ -2522,9 +2531,9 @@ void KVCacheManager::pinBlocks(RequestIdType requestId)
25222531 mBlockManager .pinBlocks (sequence);
25232532}
25242533
2525- void KVCacheManager::unpinBlocksById (KVCacheBlock::IdType blockId )
2534+ void KVCacheManager::unpinBlocksById (std::vector< KVCacheBlock::IdType> const & blockIds )
25262535{
2527- mBlockManager .unpinBlocksById (blockId );
2536+ mBlockManager .unpinBlocksById (blockIds );
25282537}
25292538
25302539SizeType32 KVCacheManager::copyBlockOffsets (ITensor& output, SizeType32 outputSlotOffset, RequestIdType requestId) const
0 commit comments