Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,7 @@ class WindowBlockManager

void replaceSharedBlock(GenerationRequest& sequence, SizeType32 blockIdx);

[[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
[[nodiscard]] std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false);

void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);
Expand Down Expand Up @@ -853,8 +853,8 @@ class WindowBlockManager
//! \param blockKeys Key of each block.
//! \param blockIds Id of each block.
//! \param pinBlocks If true, increment ref count for blocks while storing (pin on store).
//! \return Pair of (num blocks stored for reuse, id of the last block stored if any).
[[nodiscard]] std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> storeBlocks(
//! \return Pair of (num blocks stored for reuse, vector of pinned block IDs).
[[nodiscard]] std::pair<SizeType32, std::vector<KVCacheBlock::IdType>> storeBlocks(
std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds,
bool pinBlocks = false);

Expand Down Expand Up @@ -886,8 +886,8 @@ class WindowBlockManager

[[nodiscard]] std::shared_ptr<KVCacheBlock> findBlocksInReuseTreeByBlockKey(BlockKey const& blockKey);

//! \brief Unpin blocks by starting from a block id and walking prev pointers.
void unpinBlocksById(KVCacheBlock::IdType blockId);
//! \brief Unpin blocks by block ids directly
void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds);

void initializeSequenceStorageValidity(LlmRequest::RequestIdType requestId)
{
Expand Down Expand Up @@ -1103,7 +1103,7 @@ class BlockManager
std::optional<KVCacheBlock::IdType> releaseBlocks(
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinBlocks = false);

[[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
[[nodiscard]] std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinBlocks = false);

void schedulingReleaseBlocks(LlmRequest::RequestIdType requestId);
Expand All @@ -1112,7 +1112,7 @@ class BlockManager
/// @param sequence The generation request whose blocks should be pinned.
void pinBlocks(GenerationRequest& sequence);

void unpinBlocksById(KVCacheBlock::IdType blockId);
void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds);

void releaseLastBlock(GenerationRequest& sequence, SizeType32 windowSize);

Expand All @@ -1133,7 +1133,7 @@ class BlockManager
void offloadBlock(BlockPtr const& block, SizeType32 windowSize,
executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = "");

[[nodiscard]] std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> storeBlocks(
[[nodiscard]] std::pair<SizeType32, std::vector<KVCacheBlock::IdType>> storeBlocks(
std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds,
SizeType32 windowSize, bool pinBlocks = false)
{
Expand Down Expand Up @@ -1584,7 +1584,7 @@ class BaseKVCacheManager
virtual void storeNewBlock(LlmRequest const& llmRequest) = 0;

/// \brief Store blocks for reuse for a given request id
[[nodiscard]] virtual std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
[[nodiscard]] virtual std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Should add comments for return values.

LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false)
= 0;

Expand Down Expand Up @@ -1678,7 +1678,7 @@ class BaseKVCacheManager
BlockKey const& blockKey, SizeType32 windowSize)
= 0;

virtual void unpinBlocksById(KVCacheBlock::IdType blockId) = 0;
virtual void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds) = 0;
};

class KVCacheManager : public BaseKVCacheManager
Expand Down Expand Up @@ -1939,7 +1939,7 @@ class KVCacheManager : public BaseKVCacheManager
//! \brief Store newest blocks for reuse
void storeNewBlock(LlmRequest const& llmRequest) override;

[[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
[[nodiscard]] std::vector<KVCacheBlock::IdType> storeBlocksForReuse(
LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false) override;

[[nodiscard]] static SizeType32 getSinkBubbleLength(SizeType32 sinkTokenLen, SizeType32 tokensPerBlock);
Expand All @@ -1960,7 +1960,7 @@ class KVCacheManager : public BaseKVCacheManager

void pinBlocks(LlmRequest::RequestIdType requestId) override;

void unpinBlocksById(KVCacheBlock::IdType blockId) override;
void unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds) override;

std::optional<KVCacheBlock::IdType> getLastBlockId(LlmRequest::RequestIdType requestId) const override;

Expand Down
6 changes: 6 additions & 0 deletions cpp/include/tensorrt_llm/batch_manager/llmRequest.h
Original file line number Diff line number Diff line change
Expand Up @@ -1667,6 +1667,12 @@ class GenericLlmRequest
[](auto reason) { return reason == executor::FinishReason::kLENGTH; });
}

/// \brief Returns true when every recorded finish reason is kCANCELLED,
///        i.e. all beams finished because the request was cancelled.
/// \note std::all_of is vacuously true on an empty range — NOTE(review): confirm
///       callers never query this before any finish reason has been recorded
///       (the sibling isFinishedDueToLength has the same property).
[[nodiscard]] bool isFinishedDueToCancellation() const noexcept
{
return std::all_of(mFinishReasons.begin(), mFinishReasons.end(),
[](auto reason) { return reason == executor::FinishReason::kCANCELLED; });
}

[[nodiscard]] bool isTimedOut() const
{
if (!mAllottedTimeMs.has_value())
Expand Down
61 changes: 34 additions & 27 deletions cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1556,7 +1556,7 @@ void WindowBlockManager::allocateBlock(GenerationRequest& sequence, bool shareAm
}
}

std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> WindowBlockManager::storeBlocks(
std::pair<SizeType32, std::vector<KVCacheBlock::IdType>> WindowBlockManager::storeBlocks(
std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds, bool pinBlocks)
{
SizeType32 numBlocksStoredForReuse = 0;
Expand All @@ -1569,7 +1569,7 @@ std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> WindowBlockManager::s

auto numBlocks = blockKeys.size();
std::vector<BlockPtr> storedBlocks;
std::optional<KVCacheBlock::IdType> lastStoredId = std::nullopt;
std::vector<KVCacheBlock::IdType> pinnedBlockIds;
for (std::size_t blockCnt = 0; blockCnt < numBlocks; ++blockCnt)
{
auto const bid = blockIds[blockCnt];
Expand Down Expand Up @@ -1620,14 +1620,14 @@ std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> WindowBlockManager::s
if (pinBlocks)
{
searchRoot->incRefCount();
pinnedBlockIds.push_back(searchRoot->getBlockId());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we have a block already matched and in the search tree, do we need to pin it again?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pin block is only used in disaggregated serving. The goal is to make sure the blocks that are needed for the generation server are not evicted. The cycle is:

  1. The context server stores the blocks and increments their ref count by one if block reuse is enabled and it is a context-only request.
  2. The decode server starts fetching those blocks from the reuse tree.
  3. The context server decreases the ref count once the transmission has been completed.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Thank you for explaining the mechanism.

Since you mentioned the transmission between the ctx and gen servers, do the two servers each hold their own copy of the search tree? Could you point me to the code implementing that logic? I imagine that a block that has been transmitted from ctx to gen may be reused by a new sequence on the ctx server, invalidating the block for reuse there, but the copy held by the gen server should still be usable.

}
lastStoredId = searchRoot->getBlockId();
}
if (mEventManager)
{
mEventManager->enqueueStoredEvent(storedBlocks, mWindowSize);
}
return {numBlocksStoredForReuse, lastStoredId};
return {numBlocksStoredForReuse, pinnedBlockIds};
}

void BlockManager::replaceSharedBlock(GenerationRequest& sequence, SizeType32 windowSize, SizeType32 blockIdx)
Expand Down Expand Up @@ -1715,15 +1715,15 @@ std::deque<tle::KVCacheEvent> BlockManager::getLatestEvents(std::optional<std::c
return mEventManager ? mEventManager->getEvents(timeout) : std::deque<tle::KVCacheEvent>{};
}

std::optional<KVCacheBlock::IdType> BlockManager::storeBlocksForReuse(
std::vector<KVCacheBlock::IdType> BlockManager::storeBlocksForReuse(
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks)
{
std::optional<KVCacheBlock::IdType> lastStoredId = std::nullopt;
std::vector<KVCacheBlock::IdType> pinnedBlockIds;
for (auto& [_, manager] : mWindowBlockManagers)
{
lastStoredId = manager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
pinnedBlockIds = manager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
}
return lastStoredId;
return pinnedBlockIds;
}

std::optional<KVCacheBlock::IdType> BlockManager::releaseBlocks(
Expand Down Expand Up @@ -1767,15 +1767,15 @@ void BlockManager::pinBlocks(GenerationRequest& sequence)
}
}

void BlockManager::unpinBlocksById(KVCacheBlock::IdType blockId)
void BlockManager::unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds)
{
// Use the first window size
if (mWindowBlockManagers.empty())
{
return;
}
auto& firstManager = mWindowBlockManagers.begin()->second;
firstManager.unpinBlocksById(blockId);
firstManager.unpinBlocksById(blockIds);
}

void WindowBlockManager::pinBlocks(GenerationRequest& sequence)
Expand All @@ -1788,21 +1788,26 @@ void WindowBlockManager::pinBlocks(GenerationRequest& sequence)
}
}

void WindowBlockManager::unpinBlocksById(KVCacheBlock::IdType blockId)
void WindowBlockManager::unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds)
{
if (blockId < 0 || static_cast<size_t>(blockId) >= mAllBlocksById.size())
if (blockIds.empty())
{
return;
}
auto block = mAllBlocksById[blockId];
while (block && block->getBlockId() != KVCacheBlock::kCachedBlocksRootId)

for (auto const& blockId : blockIds)
{
block->decRefCount();
if (!block->hasRefs())
TLLM_CHECK_WITH_INFO(blockId >= 0 && static_cast<size_t>(blockId) < mAllBlocksById.size(),
"Block id %d is out of range", blockId);
auto block = mAllBlocksById[blockId];
if (block && block->getBlockId() != KVCacheBlock::kCachedBlocksRootId)
{
mEvictionPolicy->releaseBlock(block);
block->decRefCount();
if (!block->hasRefs())
{
mEvictionPolicy->releaseBlock(block);
}
}
block = std::move(block->getPrevBlock());
}
}

Expand Down Expand Up @@ -1870,7 +1875,7 @@ void WindowBlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef<
(void) storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]);
}

std::optional<KVCacheBlock::IdType> WindowBlockManager::storeBlocksForReuse(
std::vector<KVCacheBlock::IdType> WindowBlockManager::storeBlocksForReuse(
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks)
{
auto constexpr beamIdx = 0;
Expand All @@ -1883,7 +1888,10 @@ std::optional<KVCacheBlock::IdType> WindowBlockManager::storeBlocksForReuse(
auto const usableSize = static_cast<runtime::SizeType32>(uniqueTokens.size()) - 1;
auto blockedUniqueTokens = chopVectorIntoBlocks<UniqueToken>(uniqueTokens, usableSize, mTokensPerBlock, true);
auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest);
return storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx], pinBlocks).second;

auto [numStored, pinnedBlockIds] = storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx], pinBlocks);

return pinnedBlockIds;
}

std::optional<KVCacheBlock::IdType> WindowBlockManager::releaseBlocks(
Expand Down Expand Up @@ -1922,7 +1930,7 @@ std::optional<KVCacheBlock::IdType> WindowBlockManager::releaseBlocks(
std::transform(allocatedBlocks.begin(), allocatedBlocks.end(), cacheBlockIds.begin(),
[](BlockPtr const& block) { return block->getBlockId(); });

auto [numBlocksStoredForReuse, lastStoredId] = storeBlocks(std::move(blockKeys), cacheBlockIds);
auto [numBlocksStoredForReuse, pinnedBlockIds] = storeBlocks(std::move(blockKeys), cacheBlockIds);
TLLM_LOG_DEBUG("%s::releaseBlocks Request %lu, %d blocks stored for reuse", mLogPrefix.c_str(),
sequence.getRequestId(), numBlocksStoredForReuse);
}
Expand Down Expand Up @@ -2499,15 +2507,14 @@ std::optional<KVCacheBlock::IdType> KVCacheManager::removeSequence(
return lastStoredId;
}

std::optional<KVCacheBlock::IdType> KVCacheManager::storeBlocksForReuse(
std::vector<KVCacheBlock::IdType> KVCacheManager::storeBlocksForReuse(
RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks)
{
TLLM_LOG_TRACE("[%s]::%s start", isCrossKv() ? "CROSS" : "SELF", __PRETTY_FUNCTION__);
auto& sequence = getSequence(requestId);
std::optional<KVCacheBlock::IdType> lastStoredId
= mBlockManager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
auto pinnedBlockIds = mBlockManager.storeBlocksForReuse(sequence, llmRequest, pinBlocks);
TLLM_LOG_TRACE("[%s]::%s stop", isCrossKv() ? "CROSS" : "SELF", __PRETTY_FUNCTION__);
return lastStoredId;
return pinnedBlockIds;
}

void KVCacheManager::schedulingRemoveSequence(RequestIdType requestId)
Expand All @@ -2522,9 +2529,9 @@ void KVCacheManager::pinBlocks(RequestIdType requestId)
mBlockManager.pinBlocks(sequence);
}

void KVCacheManager::unpinBlocksById(KVCacheBlock::IdType blockId)
void KVCacheManager::unpinBlocksById(std::vector<KVCacheBlock::IdType> const& blockIds)
{
mBlockManager.unpinBlocksById(blockId);
mBlockManager.unpinBlocksById(blockIds);
}

SizeType32 KVCacheManager::copyBlockOffsets(ITensor& output, SizeType32 outputSlotOffset, RequestIdType requestId) const
Expand Down
12 changes: 6 additions & 6 deletions cpp/tensorrt_llm/executor/executorImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2179,11 +2179,11 @@ void Executor::Impl::terminateContextFinishedRequests(InTransList& inTransmissio
auto req = item.request;
if (req->isDisaggContextCompleteState())
{
// If lastBlockId was tracked, unpin it. Otherwise, just terminate.
// If pinnedBlockIds were tracked, unpin them. Otherwise, just terminate.
auto kvMgr = mModel->getKVCacheManager();
if (kvMgr && item.lastBlockId.has_value())
if (kvMgr && !item.pinnedBlockIds.empty())
{
kvMgr->unpinBlocksById(item.lastBlockId.value());
kvMgr->unpinBlocksById(item.pinnedBlockIds);
}
else
{
Expand Down Expand Up @@ -2234,14 +2234,14 @@ Executor::Impl::RequestList Executor::Impl::populateNewResponses(
// move the in transmission requests to another tracker
if (llmReq->isDisaggContextTransmissionState())
{
std::optional<SizeType32> lastBlockId{};
std::vector<SizeType32> pinnedBlockIds{};
auto kvMgr = mModel->getKVCacheManager();
if (kvMgr && kvMgr->isEnableBlockReuse() && !kvMgr->getBlockManager().isVariableWindow())
{
lastBlockId = kvMgr->storeBlocksForReuse(llmReq->mRequestId, llmReq, /*pinBlocks=*/true);
pinnedBlockIds = kvMgr->storeBlocksForReuse(llmReq->mRequestId, llmReq, /*pinBlocks=*/true);
mModel->terminateRequest(llmReq);
}
inTransmissionRequests.push_back(InTransmissionItem{*it, lastBlockId});
inTransmissionRequests.push_back(InTransmissionItem{*it, pinnedBlockIds});
}
finishedRequests.push_back(*it);
it = activeRequests.erase(it);
Expand Down
4 changes: 2 additions & 2 deletions cpp/tensorrt_llm/executor/executorImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,12 @@ class Executor::Impl
using RequestList = std::list<LlmRequestPtr>;

// When block reuse is enabled for context worker for disaggregated serving,
// we need to store the last block id so that we can unpin the block when
// we need to store the pinned block ids so that we can unpin them when
// the request is finished.
struct InTransmissionItem
{
    // The context-phase request whose KV blocks are still being transmitted.
    LlmRequestPtr request;
    // Ids of the blocks pinned when the context blocks were stored for reuse;
    // they are unpinned once the request reaches the disagg-context-complete state.
    std::vector<SizeType32> pinnedBlockIds;
};

using InTransList = std::list<InTransmissionItem>;
Expand Down
1 change: 1 addition & 0 deletions cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ void initBindings(nb::module_& m)
.def("set_finished_reason", &GenLlmReq::setFinishedReason, nb::arg("finish_reason"), nb::arg("beam"))
.def_prop_ro("is_finished", &GenLlmReq::isFinished)
.def_prop_ro("is_finished_due_to_length", &GenLlmReq::isFinishedDueToLength)
.def_prop_ro("is_finished_due_to_cancellation", &GenLlmReq::isFinishedDueToCancellation)
.def_prop_rw(
"context_current_position", &GenLlmReq::getContextCurrentPosition, &GenLlmReq::setContextCurrentPosition)
.def_prop_ro("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager
NB_OVERRIDE_PURE(removeSequence, requestId, llmRequest, pinOnRelease);
}

std::optional<tbk::KVCacheBlock::IdType> storeBlocksForReuse(tb::LlmRequest::RequestIdType requestId,
std::vector<tbk::KVCacheBlock::IdType> storeBlocksForReuse(tb::LlmRequest::RequestIdType requestId,
tensorrt_llm::common::OptionalRef<tb::LlmRequest const> llmRequest, bool pinBlocks) override
{
NB_OVERRIDE_PURE(storeBlocksForReuse, requestId, llmRequest, pinBlocks);
Expand Down
1 change: 1 addition & 0 deletions cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ void initBindings(pybind11::module_& m)
.def("set_finished_reason", &GenLlmReq::setFinishedReason, py::arg("finish_reason"), py::arg("beam"))
.def_property_readonly("is_finished", &GenLlmReq::isFinished)
.def_property_readonly("is_finished_due_to_length", &GenLlmReq::isFinishedDueToLength)
.def_property_readonly("is_finished_due_to_cancellation", &GenLlmReq::isFinishedDueToCancellation)
.def_property(
"context_current_position", &GenLlmReq::getContextCurrentPosition, &GenLlmReq::setContextCurrentPosition)
.def_property_readonly("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen)
Expand Down
4 changes: 2 additions & 2 deletions cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,10 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager
requestId, llmRequest, pinOnRelease);
}

std::optional<tbk::KVCacheBlock::IdType> storeBlocksForReuse(tb::LlmRequest::RequestIdType requestId,
std::vector<tbk::KVCacheBlock::IdType> storeBlocksForReuse(tb::LlmRequest::RequestIdType requestId,
tensorrt_llm::common::OptionalRef<tb::LlmRequest const> llmRequest, bool pinBlocks) override
{
PYBIND11_OVERLOAD_PURE(std::optional<tbk::KVCacheBlock::IdType>, tbk::BaseKVCacheManager, storeBlocksForReuse,
PYBIND11_OVERLOAD_PURE(std::vector<tbk::KVCacheBlock::IdType>, tbk::BaseKVCacheManager, storeBlocksForReuse,
requestId, llmRequest, pinBlocks);
}

Expand Down
4 changes: 3 additions & 1 deletion cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4066,11 +4066,13 @@ TEST_F(KVCacheManagerTest, PinAndUnpinBlocksById)
kvCacheManager.pinBlocks(requestId);
auto lastBlockIdOpt = kvCacheManager.getLastBlockId(requestId);
ASSERT_TRUE(lastBlockIdOpt.has_value());
auto const& allBlockIds = kvCacheManager.getCacheBlockIds(requestId, maxAttentionWindow)[0];
std::vector<SizeType32> pinnedBlockIds(allBlockIds.begin(), allBlockIds.end());
(void) kvCacheManager.removeSequence(requestId, llmRequest);
auto const freeAfterRemovePinned = kvCacheManager.getNumFreeBlocks();
EXPECT_LT(freeAfterRemovePinned, totalBlocks);

kvCacheManager.unpinBlocksById(lastBlockIdOpt.value());
kvCacheManager.unpinBlocksById(pinnedBlockIds);
auto const freeAfterUnpin = kvCacheManager.getNumFreeBlocks();
EXPECT_EQ(freeAfterUnpin, totalBlocks);
}
Expand Down
Loading