diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index 62fc4fcb301..94717307b64 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -288,6 +288,9 @@ class KVCacheBlock void removeNextBlock(BlockKey const& blockKey); + void freeDescendantsRecursively(); + void freeBlockAndAllDescendants(); + //! \brief Find block matching blockKey. If allowPartial is true, the returned block may match only a prefix of //! blockKey. //! @return tuple of [partialMatch, numMatched, block], partialMatch is true if not all the tokens of the block were diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index c0bf858cc9e..1ce9a08a91e 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -477,6 +477,32 @@ void KVCacheBlock::removeNextBlock(BlockKey const& blockKey) mNextBlocks.erase(blockKey); } +void KVCacheBlock::freeDescendantsRecursively() +{ + bool hasChildren = !mNextBlocks.empty(); + if (hasChildren) + { + for (auto it = mNextBlocks.begin(); it != mNextBlocks.end();) + { + it->second->freeDescendantsRecursively(); + TLLM_LOG_DEBUG("KVCacheBlock::freeDescendantsRecursively - Freeing block %d", it->second->getBlockId()); + it = mNextBlocks.erase(it); + } + } + mPrevBlock = nullptr; +} + +void KVCacheBlock::freeBlockAndAllDescendants() +{ + // Detach this block from its previous (parent) block + if (mPrevBlock != nullptr) + { + mPrevBlock->removeNextBlock(mBlockKey); + mPrevBlock = nullptr; + } + freeDescendantsRecursively(); +} + bool KVCacheBlock::isFull() const { return mIsFull; @@ -956,19 +982,14 @@ void WindowBlockManager::freeLeafBlock(BlockPtr const& block) void WindowBlockManager::freeChildren(BlockPtr const& block) { - // Free all descendants of block - for (auto const& p : block->getNextBlocks()) - { - auto 
childBlock = p.second; - freeChildren(childBlock); - } - - // Free block + // Tell event manager we are freeing block if (mEventManager && blockInRadixTree(block)) { mEventManager->enqueueRemovedEvent(block, mWindowSize); } - freeLeafBlock(block); + + // Free block and all its descendants from the radix tree. NOTE(review): only `block` itself gets a removed event above; descendants detached here do not - confirm this is intended. + block->freeBlockAndAllDescendants(); } BlockPtr WindowBlockManager::getFreeBlock(GenerationRequest& sequence, executor::RetentionPriority priority, @@ -1567,60 +1588,80 @@ std::pair> WindowBlockManager::sto auto searchRoot = mCachedBlocksRoot; bool needMatch = true; - auto numBlocks = blockKeys.size(); + // There is no guarantee that these vectors will be the same length. + // Only iterate as long as we have valid blockKey and blockId. + auto numBlocks = std::min(blockKeys.size(), blockIds.size()); std::vector storedBlocks; std::vector pinnedBlockIds; for (std::size_t blockCnt = 0; blockCnt < numBlocks; ++blockCnt) { - auto const bid = blockIds[blockCnt]; - TLLM_LOG_DEBUG("%s::storeBlocks - Searching match for block %d", mLogPrefix.c_str(), bid); - auto& block = mAllBlocksById[bid]; - auto const& blockKey = blockKeys[blockCnt]; - - auto [partialMatch, numMatched, matchedBlock] - = needMatch ? 
searchRoot->findMatchingBlock(blockKey, false, false) : std::make_tuple(false, 0, nullptr); - if (matchedBlock != nullptr) - { - // Found match - TLLM_LOG_DEBUG( - "%s::storeBlocks - Found matching block %d, traverse", mLogPrefix.c_str(), matchedBlock->getBlockId()); - searchRoot = matchedBlock; - // TODO possible optimization: if bid != matchedBlock->getBlockId(), - // block can be freed and inserted at mFreePrimaryBlocks.begin() - } - else - { - // No match - TLLM_LOG_DEBUG("%s::storeBlocks - No match, inserting block %d into search structure", mLogPrefix.c_str(), - block->getBlockId()); - TLLM_CHECK_WITH_INFO(block->getBlockId() == bid, - "Block id mismatch " + std::to_string(block->getBlockId()) + " != " + std::to_string(bid)); - needMatch = false; // no matching needed for following blocks - block->setBlockKey(blockKey, static_cast(blockKey.uniqueTokens.size()) == mTokensPerBlock); - block->setPrevBlock(searchRoot); - block->setPrevBlockInSeq(searchRoot); - searchRoot->addNextBlock(blockKey, block); - - // Sanity check. The list of stored blocks should be connected. - TLLM_CHECK(storedBlocks.empty() || block->getPrevBlock() == storedBlocks.back()); - - storedBlocks.push_back(block); - TLLM_CHECK(block->getPrevBlockInSeq() == nullptr - || block->getPrevBlockInSeq()->getHash() == searchRoot->getHash()); - auto oldHash = block->getHash(); - auto newHash = BlockKeyHasher()(blockKey, searchRoot->getHash()); - if (oldHash != newHash) + try + { + // Protect against blockIds being shorter than blockKeys (defensive: numBlocks is already clamped to the shorter vector). + auto const bid = blockIds.at(blockCnt); + TLLM_LOG_DEBUG("%s::storeBlocks - Searching match for block %d", mLogPrefix.c_str(), bid); + // We set blockId to an invalid value to indicate that a block has been released early for a limited + // attention layer. Make sure we don't store an invalid block because of this. + auto& block = mAllBlocksById.at(bid); + // Protect against blockKeys being shorter than blockIds (defensive: numBlocks is already clamped to the shorter vector). 
+ auto const& blockKey = blockKeys.at(blockCnt); + + // If any of the above error conditions occurs, std::vector::at will throw an exception, which is caught + // further down. This will prevent an invalid block from being stored for reuse. The catch clause exits the loop + // early, preventing blocks following an invalid block from being reused. + + auto [partialMatch, numMatched, matchedBlock] = needMatch + ? searchRoot->findMatchingBlock(blockKey, false, false) + : std::make_tuple(false, 0, nullptr); + if (matchedBlock != nullptr) + { + // Found match + TLLM_LOG_DEBUG("%s::storeBlocks - Found matching block %d, traverse", mLogPrefix.c_str(), + matchedBlock->getBlockId()); + searchRoot = matchedBlock; + // TODO possible optimization: if bid != matchedBlock->getBlockId(), + // block can be freed and inserted at mFreePrimaryBlocks.begin() + } + else + { + // No match + TLLM_LOG_DEBUG("%s::storeBlocks - No match, inserting block %d into search structure", + mLogPrefix.c_str(), block->getBlockId()); + TLLM_CHECK_WITH_INFO(block->getBlockId() == bid, + "Block id mismatch " + std::to_string(block->getBlockId()) + " != " + std::to_string(bid)); + needMatch = false; // no matching needed for following blocks + block->setBlockKey(blockKey, static_cast(blockKey.uniqueTokens.size()) == mTokensPerBlock); + block->setPrevBlock(searchRoot); + block->setPrevBlockInSeq(searchRoot); + searchRoot->addNextBlock(blockKey, block); + + // Sanity check. The list of stored blocks should be connected. 
+ TLLM_CHECK(storedBlocks.empty() || block->getPrevBlock() == storedBlocks.back()); + + storedBlocks.push_back(block); + TLLM_CHECK(block->getPrevBlockInSeq() == nullptr + || block->getPrevBlockInSeq()->getHash() == searchRoot->getHash()); + auto oldHash = block->getHash(); + auto newHash = BlockKeyHasher()(blockKey, searchRoot->getHash()); + if (oldHash != newHash) + { + TLLM_LOG_DEBUG("#%d block hash %zx -> %zx", block->getBlockId(), oldHash, newHash); + block->setHash(newHash); + } + searchRoot = block; + numBlocksStoredForReuse++; + } + if (pinBlocks) { - TLLM_LOG_DEBUG("#%d block hash %zx -> %zx", block->getBlockId(), oldHash, newHash); - block->setHash(newHash); + searchRoot->incRefCount(); + pinnedBlockIds.push_back(searchRoot->getBlockId()); } - searchRoot = block; - numBlocksStoredForReuse++; } - if (pinBlocks) + catch (std::out_of_range const& ex) { - searchRoot->incRefCount(); - pinnedBlockIds.push_back(searchRoot->getBlockId()); + TLLM_LOG_WARNING("Out of range access, terminating storeBlocks early."); + // Prevent blocks following an invalid block from being reused. + break; } } if (mEventManager)