Skip to content

Commit 9cee32a

Browse files
authored
[https://nvbugs/5625990][fix] Respect VSWA scheme when doing block store for reuse and load block for reuse in KV cache manager (#10183)
Signed-off-by: eopXD <[email protected]>
1 parent 2f8d6d2 commit 9cee32a

File tree

4 files changed

+122
-79
lines changed

4 files changed

+122
-79
lines changed

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,7 @@ class GenerationRequest
380380
, mBeamWidth(beamWidth)
381381
, mKvCacheRetentionConfig(std::move(kvCacheRetentionConfig))
382382
, mNumFrontBlocksRemoved(0)
383+
, mCurrentPrepopulatedPromptLen(std::numeric_limits<SizeType32>::max())
383384
{
384385
auto const numWindowSizes = windowSizeToMetadata.size();
385386
mCacheBlockIds.reserve(numWindowSizes);
@@ -500,6 +501,20 @@ class GenerationRequest
500501
return mKvCacheRetentionConfig.getDirectory();
501502
}
502503

504+
[[nodiscard]] SizeType32 getCurrentPrepopulatedPromptLen() const
505+
{
506+
return mCurrentPrepopulatedPromptLen;
507+
}
508+
509+
void setCurrentPrepopulatedPromptLen(SizeType32 currentPrepopulatedPromptLen)
510+
{
511+
TLLM_CHECK_WITH_INFO(currentPrepopulatedPromptLen <= mCurrentPrepopulatedPromptLen,
512+
"currentPrepopulatedPromptLen must be updated non-increasingly due to the "
513+
"assumption that smaller window sizes have shorter or equal"
514+
"currentPrepopulatedPromptLen in WindowSizeManager::loadOrAllocateBlocks.");
515+
mCurrentPrepopulatedPromptLen = currentPrepopulatedPromptLen;
516+
}
517+
503518
private:
504519
// Request id of the sequence
505520
LlmRequest::RequestIdType mRequestId;
@@ -517,6 +532,8 @@ class GenerationRequest
517532
SizeType32 mNumFrontBlocksRemoved;
518533
// Set of used blocks by the sequence
519534
std::set<KVCacheBlock::IdType> mUsedBlocks;
535+
// Current prepopulated prompt length
536+
SizeType32 mCurrentPrepopulatedPromptLen;
520537
};
521538

522539
// attach metadata to a pool pointer

cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1224,7 +1224,7 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
12241224
auto [partialMatch, numMatched, matchingBlock] = searchRoot != nullptr && blockItr != blockKeys.end()
12251225
? searchRoot->findMatchingBlock(*blockItr, mEnablePartialReuse, mCopyOnPartialReuse)
12261226
: std::make_tuple(false, 0, nullptr);
1227-
if (matchingBlock != nullptr)
1227+
if (matchingBlock != nullptr && numMatchedTokens + numMatched <= sequence.getCurrentPrepopulatedPromptLen())
12281228
{
12291229
KVCacheBlock::IdType matchingBlockId = matchingBlock->getBlockId();
12301230

@@ -1338,6 +1338,7 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
13381338
}
13391339
}
13401340

1341+
sequence.setCurrentPrepopulatedPromptLen(numMatchedTokens);
13411342
return numMatchedTokens;
13421343
}
13431344

@@ -1731,9 +1732,22 @@ std::optional<KVCacheBlock::IdType> BlockManager::releaseBlocks(
17311732
// Released block will be stored when reuse is enabled.
17321733
// Reuse is implied to be enabled if llmRequest is provided.
17331734
std::optional<KVCacheBlock::IdType> lastStoredId = std::nullopt;
1735+
1736+
// For now, the attention kernel only accepts a single
1737+
// "prepopulatedPromptLen", that is, all window sizes will use the same
1738+
// prepopulated prompt length, so it is meaningless right now to save
1739+
// blocks only for a certain window size while blocks in the other
1740+
// window size are not valid for saving for reuse.
1741+
bool isAllWindowSizesValidForStoreForReuse = true;
1742+
for (auto& [windowSize, manager] : mWindowBlockManagers)
1743+
{
1744+
isAllWindowSizesValidForStoreForReuse &= manager.isSequenceValidForStoreForReuse(sequence.getRequestId());
1745+
}
1746+
17341747
for (auto& [_, manager] : mWindowBlockManagers)
17351748
{
1736-
if (!llmRequest.has_value() || llmRequest->isDummyRequest() || sequence.getBeamWidth() > 1)
1749+
if (!llmRequest.has_value() || llmRequest->isDummyRequest() || sequence.getBeamWidth() > 1
1750+
|| !isAllWindowSizesValidForStoreForReuse)
17371751
{
17381752
lastStoredId = manager.releaseBlocks(sequence, std::nullopt);
17391753
}

0 commit comments

Comments
 (0)