Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 0 additions & 14 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,6 @@ class KVCacheBlock

[[nodiscard]] bool isFull() const;

//! \brief Whether the block is considered shared.
//! \details True when the reference count is greater than one or a previous block is set.
[[nodiscard]] bool isShared() const;

[[nodiscard]] bool isLeaf() const;

void setPriority(executor::RetentionPriority priority);
Expand Down Expand Up @@ -608,8 +606,6 @@ class WindowBlockManager
//! \details Might free cached blocks if no free blocks are available.
void allocateBlock(GenerationRequest& sequence, bool shareAmongBeams);

//! \brief Replace the block at \p blockIdx of \p sequence with newly obtained block(s), one per beam, if it is shared.
//! \details Leaves the sequence untouched when the block at \p blockIdx is not shared.
void replaceSharedBlock(GenerationRequest& sequence, SizeType32 blockIdx);

[[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false);

Expand Down Expand Up @@ -1008,10 +1004,6 @@ class BlockManager
void addSequence(
GenerationRequest& sequence, SizeType32 numContextBlocks, SizeType32 windowSize, bool isShareLastContextBlock);

//! \brief Allocate a new block for \p sequence in the manager registered for \p windowSize.
//! \details Forwards to WindowBlockManager::allocateBlock with sharing among beams disabled.
void allocateBlock(GenerationRequest& sequence, SizeType32 windowSize);

//! \brief Replace the shared block at \p blockIdx of \p sequence in the manager registered for \p windowSize.
//! \details Forwards to WindowBlockManager::replaceSharedBlock.
void replaceSharedBlock(GenerationRequest& sequence, SizeType32 windowSize, SizeType32 blockIdx);

std::optional<KVCacheBlock::IdType> releaseBlocks(
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinBlocks = false);

Expand All @@ -1028,9 +1020,6 @@ class BlockManager

void releaseLastBlock(GenerationRequest& sequence, SizeType32 windowSize);

//! \brief Set offsets for a single (beam, block) entry using the manager registered for \p windowSize.
//! \details Forwards every argument except \p windowSize to WindowBlockManager::setOffsets.
void setOffsets(kernels::KVCacheIndex* offsetsPtr, nvinfer1::Dims const& offsetsShape, SizeType32 beamIdx,
SizeType32 blockIdx, KVCacheBlock::IdType blockId, SizeType32 windowSize) const;

// WILL NOT WORK FOR VARIABLE WINDOW ATTENTION
[[nodiscard]] std::optional<BlockKey> findNewContextBlock(
VecUniqueTokens const& uniqueTokens, LlmRequest const& llmRequest) const;
Expand Down Expand Up @@ -1262,9 +1251,6 @@ class BlockManager
//! \brief Update cache offsets for blocks initiated from sequence
void updateSequenceCacheBlockOffsets(GenerationRequest& seq, SizeType32 windowSize);

//! \brief Update cache offsets for block at index
//! \details For every beam of \p seq, reads the block ID stored at \p blockIdx and passes it to
//! WindowBlockManager::setOffsets together with the sequence's cache block indices for \p windowSize.
void updateCacheBlockOffsetsAtIdx(GenerationRequest& seq, SizeType32 windowSize, SizeType32 blockIdx);

//! \brief Add/detach block(s) to/from the sequence if needed
//! \details When we need a new block, we add it. For sliding window
//! attention (SWA), when a block goes out-of-window (OOW), we detach it
Expand Down
118 changes: 0 additions & 118 deletions cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,39 +56,6 @@ inline uint8_t getNthByte(SizeType32 hashPart, uint8_t byteIdx) noexcept
return static_cast<uint8_t>((hashPart >> (24 - byteIdx * 8)) & 0xFF);
}

//! \brief Collect every block of a sequence by following the previous-block links from its tail.
//! \param lastBlock BlockPtr to the final block of the sequence; the walk starts here.
//! \return Vector of BlockPtr-s in sequence order, ending with \p lastBlock. Empty when \p lastBlock
//!         is null or is the cached-blocks root.
std::vector<BlockPtr> getAllSequenceBlocks(BlockPtr lastBlock)
{
    // The walk ends at a null link or at the cached-blocks root.
    auto const belongsToSequence = [](BlockPtr const& candidate)
    { return candidate != nullptr && candidate->getBlockId() != KVCacheBlock::kCachedBlocksRootId; };

    // Pass 1: measure the chain so the result can be sized up front.
    size_t numBlocks = 0;
    for (auto cursor = lastBlock; belongsToSequence(cursor); cursor = cursor->getPrevBlockInSeq())
    {
        ++numBlocks;
    }

    std::vector<BlockPtr> orderedBlocks(numBlocks);

    // Pass 2: walk the chain again, writing from the back so the result ends up head-first.
    auto slot = orderedBlocks.rbegin();
    for (auto cursor = lastBlock; belongsToSequence(cursor); cursor = cursor->getPrevBlockInSeq())
    {
        *slot = cursor;
        ++slot;
    }

    return orderedBlocks;
}

} // namespace

namespace tensorrt_llm::batch_manager::kv_cache_manager
Expand Down Expand Up @@ -317,12 +284,6 @@ bool KVCacheBlock::hasRefs() const
return mRefCount > 0;
}

//! \brief Whether the block is considered shared, i.e. ready for reuse.
//! \return true when more than one reference is held or a previous block is set.
bool KVCacheBlock::isShared() const
{
    if (mRefCount > 1)
    {
        return true;
    }
    return mPrevBlock != nullptr;
}

bool KVCacheBlock::hasSchedulingRefs() const
{
return mSchedulingRefCount > 0;
Expand Down Expand Up @@ -1019,12 +980,6 @@ void WindowBlockManager::setOffsets(tk::KVCacheIndex* offsetsPtr, nvinfer1::Dims
}
}

//! \brief Forward an offsets update to the window block manager registered for \p windowSize.
void BlockManager::setOffsets(tk::KVCacheIndex* offsetsPtr, nvinfer1::Dims const& offsetsShape, SizeType32 beamIdx,
    SizeType32 blockIdx, KVCacheBlock::IdType blockId, SizeType32 windowSize) const
{
    auto const& windowManager = mWindowBlockManagers.at(windowSize);
    windowManager.setOffsets(offsetsPtr, offsetsShape, beamIdx, blockIdx, blockId);
}

void BlockManager::onboardBlock(GenerationRequest& sequence, BlockPtr const& offloadBlock, SizeType32 windowSize,
executor::KvCacheTransferMode mode, std::string const& directory)
{
Expand Down Expand Up @@ -1459,11 +1414,6 @@ void WindowBlockManager::addBlockToAllBeams(BlockPtr& block, GenerationRequest&
}
}

//! \brief Allocate a block for \p sequence in the window block manager registered for \p windowSize.
void BlockManager::allocateBlock(GenerationRequest& sequence, SizeType32 windowSize)
{
    // Blocks allocated through this entry point are never shared among beams.
    constexpr bool kShareAmongBeams = false;
    auto& windowManager = mWindowBlockManagers.at(windowSize);
    windowManager.allocateBlock(sequence, kShareAmongBeams);
}

void WindowBlockManager::allocateBlock(GenerationRequest& sequence, bool shareAmongBeams)
{
auto const beamWidth = sequence.getBeamWidth();
Expand Down Expand Up @@ -1567,57 +1517,6 @@ std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> WindowBlockManager::s
return {numBlocksStoredForReuse, lastStoredId};
}

//! \brief Forward a shared-block replacement to the window block manager registered for \p windowSize.
void BlockManager::replaceSharedBlock(GenerationRequest& sequence, SizeType32 windowSize, SizeType32 blockIdx)
{
    auto& windowManager = mWindowBlockManagers.at(windowSize);
    windowManager.replaceSharedBlock(sequence, blockIdx);
}

//! \brief Swap the block(s) at \p blockIdx of \p sequence for newly obtained ones when the current block is shared.
//! \details Returns without changes when the block is not shared. Otherwise each per-beam entry is
//! dereferenced (and released to the eviction policy once it has no references left), then one block per
//! beam is obtained via getFreeBlock, given the old block's key and fullness flag, and recorded both in
//! the sequence and in the per-request allocation list.
//! \param sequence Request whose block is replaced.
//! \param blockIdx Index of the block within the sequence.
void WindowBlockManager::replaceSharedBlock(GenerationRequest& sequence, SizeType32 blockIdx)
{
auto const requestId = sequence.getRequestId();
auto const beamWidth = sequence.getBeamWidth();
// Blocks held by this request; addressed below as blockIdx * beamWidth + beamIdx.
auto& allocatedBlocks = mAllocatedBlocksPerSeq.at(requestId);

// Only the last beam's entry for this block index decides whether a replacement is needed.
if (!allocatedBlocks.at((blockIdx + 1) * beamWidth - 1)->isShared())
{
return;
}
// Capture key and fullness from the first beam's entry so the replacement blocks can inherit them.
BlockKey blockKey = allocatedBlocks.at(blockIdx * beamWidth)->getBlockKey();
bool isFull = allocatedBlocks.at(blockIdx * beamWidth)->isFull();

// Drop this sequence's reference on each beam's entry; a block left without references
// is handed back to the eviction policy.
for (auto beamIdx = 0; beamIdx < beamWidth; ++beamIdx)
{
auto block = allocatedBlocks.at(blockIdx * beamWidth + beamIdx);
block->decRefCount();
if (!block->hasRefs())
{
mEvictionPolicy->releaseBlock(block);
}
}

// Obtain one replacement block per beam; requires beamWidth free blocks to be available.
TLLM_CHECK_WITH_INFO(hasFreeBlocks(beamWidth), "Can't allocate new blocks. No free blocks left.");
for (auto beamIdx = 0; beamIdx < beamWidth; ++beamIdx)
{
auto block = getFreeBlock(sequence, executor::KvCacheRetentionConfig::kDefaultRetentionPriority, std::nullopt,
sequence.getTransferMode(), sequence.getDirectory());
block->incRefCount();
// Link the new block to a predecessor: none when the beam has no block IDs yet,
// otherwise the block whose ID is last in the beam's ID list.
// NOTE(review): the predecessor is the last ID of the beam, not the entry before blockIdx - confirm intended.
if (sequence.getCacheBlockIds(mWindowSize).at(beamIdx).size() == 0)
{
block->setPrevBlockInSeq(nullptr);
}
else
{
block->setPrevBlockInSeq(mAllBlocksById.at(sequence.getCacheBlockIds(mWindowSize)[beamIdx].back()));
}
block->setBlockKey(blockKey, isFull);
block->setHash();
// Record the replacement in the sequence and in the per-request allocation list.
sequence.changeCacheBlock(mWindowSize, beamIdx, blockIdx, block->getBlockId());
allocatedBlocks.at(blockIdx * beamWidth + beamIdx) = block;
}
}

void BlockManager::releaseLastBlock(GenerationRequest& sequence, SizeType32 windowSize)
{
mWindowBlockManagers.at(windowSize).releaseLastBlock(sequence);
Expand Down Expand Up @@ -2194,23 +2093,6 @@ void WindowBlockManager::updateLastCacheBlockOffsets(GenerationRequest& sequence
}
}

//! \brief Refresh the cache offsets of the block at \p blockIdx for every beam of \p sequence.
//! \details Each beam's block ID at \p blockIdx is passed to the window block manager registered
//! for \p windowSize, together with the sequence's cache block indices tensor.
void BlockManager::updateCacheBlockOffsetsAtIdx(GenerationRequest& sequence, SizeType32 windowSize, SizeType32 blockIdx)
{
    auto const& blockIdsPerBeam = sequence.getCacheBlockIds(windowSize);
    auto& offsetsTensor = sequence.getCacheBlockIndices(windowSize);
    auto const numBeams = sequence.getBeamWidth();

    auto* rawOffsets = bufferCast<tk::KVCacheIndex>(offsetsTensor);
    auto const& tensorShape = offsetsTensor.getShape();

    SizeType32 beam = 0;
    while (beam < numBeams)
    {
        // Resolve the block ID first, then the per-window manager that writes its offsets.
        auto const blockId = blockIdsPerBeam[beam].at(blockIdx);
        auto& windowManager = mWindowBlockManagers.at(windowSize);
        windowManager.setOffsets(rawOffsets, tensorShape, beam, blockIdx, blockId);
        ++beam;
    }
}

void KVCacheManager::addToken(RequestIdType requestId)
{
// TODO: add streamLLM support
Expand Down
Loading