Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
8836bb7
Prevent out-of-bounds read
thorjohnsen Dec 10, 2025
920aebd
precommit run
thorjohnsen Dec 10, 2025
972a084
Merge remote-tracking branch 'upstream/main' into user/tjohnsen/fix_5…
thorjohnsen Dec 16, 2025
5d9baa4
Merge branch 'main' into user/tjohnsen/fix_5721661
thorjohnsen Jan 5, 2026
4b1faf6
Further strengthen protection against reuse of invalid blocks
thorjohnsen Jan 7, 2026
3ce0fcd
Merge remote-tracking branch 'upstream/main' into user/tjohnsen/fix_5…
thorjohnsen Jan 7, 2026
d634693
precommit run
thorjohnsen Jan 7, 2026
ffd6178
Merge branch 'main' into user/tjohnsen/fix_5721661
thorjohnsen Jan 9, 2026
ceef915
Merge branch 'main' into user/tjohnsen/fix_5721661
thorjohnsen Jan 9, 2026
7b256e4
Merge branch 'main' into user/tjohnsen/fix_5721661
thorjohnsen Jan 9, 2026
fda698d
Merge branch 'main' into user/tjohnsen/fix_5721661
thorjohnsen Jan 9, 2026
3543975
Fix undefined behavior (implicit erase of object pointed to by itera…
thorjohnsen Jan 12, 2026
05b1329
Merge branch 'user/tjohnsen/fix_5721661' of github.com:thorjohnsen/Te…
thorjohnsen Jan 12, 2026
b94ffc2
Merge remote-tracking branch 'upstream/main' into user/tjohnsen/fix_5…
thorjohnsen Jan 12, 2026
4085bf5
Bug fix
thorjohnsen Jan 13, 2026
3beb25c
Resolve merge conflicts
thorjohnsen Jan 13, 2026
0795caa
precommit run
thorjohnsen Jan 13, 2026
c47503f
Bug fix
thorjohnsen Jan 13, 2026
2510050
Merge remote-tracking branch 'upstream/main' into user/tjohnsen/fix_5…
thorjohnsen Jan 13, 2026
233a6a8
Merge branch 'main' into user/tjohnsen/fix_5721661
thorjohnsen Jan 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,9 @@ class KVCacheBlock

void removeNextBlock(BlockKey const& blockKey);

//! \brief Detach every descendant of this block. Each next block first detaches its own descendants and is then
//! erased from this block's next-block map; finally this block's previous-block pointer is reset to nullptr.
void freeDescendantsRecursively();
//! \brief Unlink this block from its previous block (if it has one) by removing the entry stored under this
//! block's key, then detach all descendants via freeDescendantsRecursively().
void freeBlockAndAllDescendants();

//! \brief Find block matching blockKey. If allowPartial is true, the returned block may match only a prefix of
//! blockKey.
//! @return tuple of [partialMatch, numMatched, block], partialMatch is true if not all the tokens of the block were
Expand Down
151 changes: 96 additions & 55 deletions cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,32 @@ void KVCacheBlock::removeNextBlock(BlockKey const& blockKey)
mNextBlocks.erase(blockKey);
}

//! Detach the whole subtree that hangs below this block.
//!
//! Children are handled one at a time: a child first detaches its own descendants, is then logged,
//! and is finally erased from this block's next-block map. Once no children remain, this block's
//! previous-block pointer is cleared.
void KVCacheBlock::freeDescendantsRecursively()
{
    auto childIt = mNextBlocks.begin();
    while (childIt != mNextBlocks.end())
    {
        // Read through the iterator before it is erased below.
        auto const& child = childIt->second;
        child->freeDescendantsRecursively();
        TLLM_LOG_DEBUG("KVCacheBlock::freeDescendantsRecursively - Freeing block %d", child->getBlockId());
        // erase() hands back the iterator to the next entry, so the loop never advances a stale iterator.
        childIt = mNextBlocks.erase(childIt);
    }

    mPrevBlock = nullptr;
}

void KVCacheBlock::freeBlockAndAllDescendants()
{
// free from previous block
if (mPrevBlock != nullptr)
{
mPrevBlock->removeNextBlock(mBlockKey);
mPrevBlock = nullptr;
}
freeDescendantsRecursively();
}

bool KVCacheBlock::isFull() const
{
return mIsFull;
Expand Down Expand Up @@ -956,19 +982,14 @@ void WindowBlockManager::freeLeafBlock(BlockPtr const& block)

// Frees `block` and the blocks reachable from it through getNextBlocks(), notifying the event manager
// (when one is set and the block is in the radix tree) before the block itself is released.
// NOTE(review): this span looks like a diff rendered without +/- markers, so the recursive
// freeChildren/freeLeafBlock path and the freeBlockAndAllDescendants call may be the old and the new
// implementation shown together — confirm against the actual file before relying on it.
void WindowBlockManager::freeChildren(BlockPtr const& block)
{
// Free all descendants of block
for (auto const& p : block->getNextBlocks())
{
auto childBlock = p.second;
freeChildren(childBlock);
}

// Free block
// Tell event manager we are freeing block
if (mEventManager && blockInRadixTree(block))
{
mEventManager->enqueueRemovedEvent(block, mWindowSize);
}
freeLeafBlock(block);

// Free block and all its descendants from radix tree
block->freeBlockAndAllDescendants();
}

BlockPtr WindowBlockManager::getFreeBlock(GenerationRequest& sequence, executor::RetentionPriority priority,
Expand Down Expand Up @@ -1567,60 +1588,80 @@ std::pair<SizeType32, std::vector<KVCacheBlock::IdType>> WindowBlockManager::sto
auto searchRoot = mCachedBlocksRoot;
bool needMatch = true;

auto numBlocks = blockKeys.size();
// There is no guarantee that these vectors will be the same length.
// Only iterate as long as we have valid blockKey and blockId.
auto numBlocks = std::min(blockKeys.size(), blockIds.size());
std::vector<BlockPtr> storedBlocks;
std::vector<KVCacheBlock::IdType> pinnedBlockIds;
for (std::size_t blockCnt = 0; blockCnt < numBlocks; ++blockCnt)
{
auto const bid = blockIds[blockCnt];
TLLM_LOG_DEBUG("%s::storeBlocks - Searching match for block %d", mLogPrefix.c_str(), bid);
auto& block = mAllBlocksById[bid];
auto const& blockKey = blockKeys[blockCnt];

auto [partialMatch, numMatched, matchedBlock]
= needMatch ? searchRoot->findMatchingBlock(blockKey, false, false) : std::make_tuple(false, 0, nullptr);
if (matchedBlock != nullptr)
{
// Found match
TLLM_LOG_DEBUG(
"%s::storeBlocks - Found matching block %d, traverse", mLogPrefix.c_str(), matchedBlock->getBlockId());
searchRoot = matchedBlock;
// TODO possible optimization: if bid != matchedBlock->getBlockId(),
// block can be freed and inserted at mFreePrimaryBlocks.begin()
}
else
{
// No match
TLLM_LOG_DEBUG("%s::storeBlocks - No match, inserting block %d into search structure", mLogPrefix.c_str(),
block->getBlockId());
TLLM_CHECK_WITH_INFO(block->getBlockId() == bid,
"Block id mismatch " + std::to_string(block->getBlockId()) + " != " + std::to_string(bid));
needMatch = false; // no matching needed for following blocks
block->setBlockKey(blockKey, static_cast<SizeType32>(blockKey.uniqueTokens.size()) == mTokensPerBlock);
block->setPrevBlock(searchRoot);
block->setPrevBlockInSeq(searchRoot);
searchRoot->addNextBlock(blockKey, block);

// Sanity check. The list of stored blocks should be connected.
TLLM_CHECK(storedBlocks.empty() || block->getPrevBlock() == storedBlocks.back());

storedBlocks.push_back(block);
TLLM_CHECK(block->getPrevBlockInSeq() == nullptr
|| block->getPrevBlockInSeq()->getHash() == searchRoot->getHash());
auto oldHash = block->getHash();
auto newHash = BlockKeyHasher()(blockKey, searchRoot->getHash());
if (oldHash != newHash)
try
{
// Protect against blockIds being shorter than blockKeys.
auto const bid = blockIds.at(blockCnt);
TLLM_LOG_DEBUG("%s::storeBlocks - Searching match for block %d", mLogPrefix.c_str(), bid);
// We set blockId to an invalid value to indicate that a block has been released early for a limited
// attention layer. Make sure we don't store an invalid block because of this.
auto& block = mAllBlocksById.at(bid);
// Protect against blockKeys being shorter than blockIds.
auto const& blockKey = blockKeys.at(blockCnt);

// If either of the above error conditions occur, std::vector::at will throw an exception, which is caught
// further down. This will prevent an invalid block from being stored for reuse. The catch clause exits loop
// early, preventing blocks following an invalid block from being reused.

auto [partialMatch, numMatched, matchedBlock] = needMatch
? searchRoot->findMatchingBlock(blockKey, false, false)
: std::make_tuple(false, 0, nullptr);
if (matchedBlock != nullptr)
{
// Found match
TLLM_LOG_DEBUG("%s::storeBlocks - Found matching block %d, traverse", mLogPrefix.c_str(),
matchedBlock->getBlockId());
searchRoot = matchedBlock;
// TODO possible optimization: if bid != matchedBlock->getBlockId(),
// block can be freed and inserted at mFreePrimaryBlocks.begin()
}
else
{
// No match
TLLM_LOG_DEBUG("%s::storeBlocks - No match, inserting block %d into search structure",
mLogPrefix.c_str(), block->getBlockId());
TLLM_CHECK_WITH_INFO(block->getBlockId() == bid,
"Block id mismatch " + std::to_string(block->getBlockId()) + " != " + std::to_string(bid));
needMatch = false; // no matching needed for following blocks
block->setBlockKey(blockKey, static_cast<SizeType32>(blockKey.uniqueTokens.size()) == mTokensPerBlock);
block->setPrevBlock(searchRoot);
block->setPrevBlockInSeq(searchRoot);
searchRoot->addNextBlock(blockKey, block);

// Sanity check. The list of stored blocks should be connected.
TLLM_CHECK(storedBlocks.empty() || block->getPrevBlock() == storedBlocks.back());

storedBlocks.push_back(block);
TLLM_CHECK(block->getPrevBlockInSeq() == nullptr
|| block->getPrevBlockInSeq()->getHash() == searchRoot->getHash());
auto oldHash = block->getHash();
auto newHash = BlockKeyHasher()(blockKey, searchRoot->getHash());
if (oldHash != newHash)
{
TLLM_LOG_DEBUG("#%d block hash %zx -> %zx", block->getBlockId(), oldHash, newHash);
block->setHash(newHash);
}
searchRoot = block;
numBlocksStoredForReuse++;
}
if (pinBlocks)
{
TLLM_LOG_DEBUG("#%d block hash %zx -> %zx", block->getBlockId(), oldHash, newHash);
block->setHash(newHash);
searchRoot->incRefCount();
pinnedBlockIds.push_back(searchRoot->getBlockId());
}
searchRoot = block;
numBlocksStoredForReuse++;
}
if (pinBlocks)
catch (std::out_of_range const& ex)
{
searchRoot->incRefCount();
pinnedBlockIds.push_back(searchRoot->getBlockId());
TLLM_LOG_WARNING("Out of range access, terminating storeBlocks early.");
// Prevent blocks following an invalid block from being reused.
break;
}
}
if (mEventManager)
Expand Down