-
Notifications
You must be signed in to change notification settings - Fork 2k
[https://nvbugs/5689235][fix] Fix cancellation+chunked prefill+disagg #10111
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1556,7 +1556,7 @@ void WindowBlockManager::allocateBlock(GenerationRequest& sequence, bool shareAm | |
| } | ||
| } | ||
|
|
||
| std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> WindowBlockManager::storeBlocks( | ||
| std::pair<SizeType32, std::vector<KVCacheBlock::IdType>> WindowBlockManager::storeBlocks( | ||
| std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds, bool pinBlocks) | ||
| { | ||
| SizeType32 numBlocksStoredForReuse = 0; | ||
|
|
@@ -1569,7 +1569,7 @@ std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> WindowBlockManager::s | |
|
|
||
| auto numBlocks = blockKeys.size(); | ||
| std::vector<BlockPtr> storedBlocks; | ||
| std::optional<KVCacheBlock::IdType> lastStoredId = std::nullopt; | ||
| std::vector<KVCacheBlock::IdType> pinnedBlockIds; | ||
| for (std::size_t blockCnt = 0; blockCnt < numBlocks; ++blockCnt) | ||
| { | ||
| auto const bid = blockIds[blockCnt]; | ||
|
|
@@ -1620,14 +1620,14 @@ std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> WindowBlockManager::s | |
| if (pinBlocks) | ||
| { | ||
| searchRoot->incRefCount(); | ||
| pinnedBlockIds.push_back(searchRoot->getBlockId()); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we have a block already matched and in the search tree, do we need to pin it again?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Pin block is only used in disaggregated serving. The goal is to make sure the blocks that are needed for the generation server are not evicted. The cycle is:
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see. Thank you for explaining the mechanism. Since you mentioned the transmission between ctx and gen server, does the two servers hold its own copy of search tree? May you point me to code of such logic? I imagine that a block that has been transmitted from ctx to gen may be used by a new sequence in the ctx server and invalidating the block for reuse under ctx server, but the block copied to gen should still be usable. |
||
| } | ||
| lastStoredId = searchRoot->getBlockId(); | ||
| } | ||
| if (mEventManager) | ||
| { | ||
| mEventManager->enqueueStoredEvent(storedBlocks, mWindowSize); | ||
| } | ||
| return {numBlocksStoredForReuse, lastStoredId}; | ||
| return {numBlocksStoredForReuse, pinnedBlockIds}; | ||
| } | ||
|
|
||
| void BlockManager::replaceSharedBlock(GenerationRequest& sequence, SizeType32 windowSize, SizeType32 blockIdx) | ||
|
|
@@ -1715,15 +1715,15 @@ std::deque<tle::KVCacheEvent> BlockManager::getLatestEvents(std::optional<std::c | |
| return mEventManager ? mEventManager->getEvents(timeout) : std::deque<tle::KVCacheEvent>{}; | ||
| } | ||
|
|
||
| std::optional<KVCacheBlock::IdType> BlockManager::storeBlocksForReuse( | ||
| std::vector<KVCacheBlock::IdType> BlockManager::storeBlocksForReuse( | ||
| GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks) | ||
| { | ||
| std::optional<KVCacheBlock::IdType> lastStoredId = std::nullopt; | ||
| std::vector<KVCacheBlock::IdType> pinnedBlockIds; | ||
| for (auto& [_, manager] : mWindowBlockManagers) | ||
| { | ||
| lastStoredId = manager.storeBlocksForReuse(sequence, llmRequest, pinBlocks); | ||
| pinnedBlockIds = manager.storeBlocksForReuse(sequence, llmRequest, pinBlocks); | ||
| } | ||
| return lastStoredId; | ||
| return pinnedBlockIds; | ||
| } | ||
|
|
||
| std::optional<KVCacheBlock::IdType> BlockManager::releaseBlocks( | ||
|
|
@@ -1767,15 +1767,15 @@ void BlockManager::pinBlocks(GenerationRequest& sequence) | |
| } | ||
| } | ||
|
|
||
| void BlockManager::unpinBlocksById(KVCacheBlock::IdType blockId) | ||
| void BlockManager::unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds) | ||
| { | ||
| // Use the first window size | ||
| if (mWindowBlockManagers.empty()) | ||
| { | ||
| return; | ||
| } | ||
| auto& firstManager = mWindowBlockManagers.begin()->second; | ||
| firstManager.unpinBlocksById(blockId); | ||
| firstManager.unpinBlocksById(blockIds); | ||
| } | ||
|
|
||
| void WindowBlockManager::pinBlocks(GenerationRequest& sequence) | ||
|
|
@@ -1788,21 +1788,26 @@ void WindowBlockManager::pinBlocks(GenerationRequest& sequence) | |
| } | ||
| } | ||
|
|
||
| void WindowBlockManager::unpinBlocksById(KVCacheBlock::IdType blockId) | ||
| void WindowBlockManager::unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds) | ||
| { | ||
| if (blockId < 0 || static_cast<size_t>(blockId) >= mAllBlocksById.size()) | ||
| if (blockIds.empty()) | ||
| { | ||
| return; | ||
| } | ||
| auto block = mAllBlocksById[blockId]; | ||
| while (block && block->getBlockId() != KVCacheBlock::kCachedBlocksRootId) | ||
|
|
||
| for (auto const& blockId : blockIds) | ||
| { | ||
| block->decRefCount(); | ||
| if (!block->hasRefs()) | ||
| TLLM_CHECK_WITH_INFO(blockId >= 0 && static_cast<size_t>(blockId) < mAllBlocksById.size(), | ||
| "Block id %d is out of range", blockId); | ||
| auto block = mAllBlocksById[blockId]; | ||
| if (block && block->getBlockId() != KVCacheBlock::kCachedBlocksRootId) | ||
| { | ||
| mEvictionPolicy->releaseBlock(block); | ||
| block->decRefCount(); | ||
| if (!block->hasRefs()) | ||
| { | ||
| mEvictionPolicy->releaseBlock(block); | ||
| } | ||
| } | ||
| block = std::move(block->getPrevBlock()); | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -1870,7 +1875,7 @@ void WindowBlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef< | |
| (void) storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]); | ||
| } | ||
|
|
||
| std::optional<KVCacheBlock::IdType> WindowBlockManager::storeBlocksForReuse( | ||
| std::vector<KVCacheBlock::IdType> WindowBlockManager::storeBlocksForReuse( | ||
| GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks) | ||
| { | ||
| auto constexpr beamIdx = 0; | ||
|
|
@@ -1883,7 +1888,10 @@ std::optional<KVCacheBlock::IdType> WindowBlockManager::storeBlocksForReuse( | |
| auto const usableSize = static_cast<runtime::SizeType32>(uniqueTokens.size()) - 1; | ||
| auto blockedUniqueTokens = chopVectorIntoBlocks<UniqueToken>(uniqueTokens, usableSize, mTokensPerBlock, true); | ||
| auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest); | ||
| return storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx], pinBlocks).second; | ||
|
|
||
| auto [numStored, pinnedBlockIds] = storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx], pinBlocks); | ||
|
|
||
| return pinnedBlockIds; | ||
| } | ||
|
|
||
| std::optional<KVCacheBlock::IdType> WindowBlockManager::releaseBlocks( | ||
|
|
@@ -1922,7 +1930,7 @@ std::optional<KVCacheBlock::IdType> WindowBlockManager::releaseBlocks( | |
| std::transform(allocatedBlocks.begin(), allocatedBlocks.end(), cacheBlockIds.begin(), | ||
| [](BlockPtr const& block) { return block->getBlockId(); }); | ||
|
|
||
| auto [numBlocksStoredForReuse, lastStoredId] = storeBlocks(std::move(blockKeys), cacheBlockIds); | ||
| auto [numBlocksStoredForReuse, pinnedBlockIds] = storeBlocks(std::move(blockKeys), cacheBlockIds); | ||
| TLLM_LOG_DEBUG("%s::releaseBlocks Request %lu, %d blocks stored for reuse", mLogPrefix.c_str(), | ||
| sequence.getRequestId(), numBlocksStoredForReuse); | ||
| } | ||
|
|
@@ -2499,15 +2507,14 @@ std::optional<KVCacheBlock::IdType> KVCacheManager::removeSequence( | |
| return lastStoredId; | ||
| } | ||
|
|
||
| std::optional<KVCacheBlock::IdType> KVCacheManager::storeBlocksForReuse( | ||
| std::vector<KVCacheBlock::IdType> KVCacheManager::storeBlocksForReuse( | ||
| RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks) | ||
| { | ||
| TLLM_LOG_TRACE("[%s]::%s start", isCrossKv() ? "CROSS" : "SELF", __PRETTY_FUNCTION__); | ||
| auto& sequence = getSequence(requestId); | ||
| std::optional<KVCacheBlock::IdType> lastStoredId | ||
| = mBlockManager.storeBlocksForReuse(sequence, llmRequest, pinBlocks); | ||
| auto pinnedBlockIds = mBlockManager.storeBlocksForReuse(sequence, llmRequest, pinBlocks); | ||
| TLLM_LOG_TRACE("[%s]::%s stop", isCrossKv() ? "CROSS" : "SELF", __PRETTY_FUNCTION__); | ||
| return lastStoredId; | ||
| return pinnedBlockIds; | ||
| } | ||
|
|
||
| void KVCacheManager::schedulingRemoveSequence(RequestIdType requestId) | ||
|
|
@@ -2522,9 +2529,9 @@ void KVCacheManager::pinBlocks(RequestIdType requestId) | |
| mBlockManager.pinBlocks(sequence); | ||
| } | ||
|
|
||
| void KVCacheManager::unpinBlocksById(KVCacheBlock::IdType blockId) | ||
| void KVCacheManager::unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds) | ||
| { | ||
| mBlockManager.unpinBlocksById(blockId); | ||
| mBlockManager.unpinBlocksById(blockIds); | ||
| } | ||
|
|
||
| SizeType32 KVCacheManager::copyBlockOffsets(ITensor& output, SizeType32 outputSlotOffset, RequestIdType requestId) const | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Should add comments for return values.