Skip to content

Commit d798d66

Browse files
authored
[TRTLLM-7731][feat] Avoid over-allocation of KV cache for transmission in disagg with CP (#8145)
Signed-off-by: Balaram Buddharaju <[email protected]>
1 parent bba2519 commit d798d66

File tree

9 files changed

+279
-117
lines changed

9 files changed

+279
-117
lines changed

cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,21 +42,22 @@
4242
namespace tensorrt_llm::batch_manager::kv_cache_manager
4343
{
4444

45-
BlockRange getBlockRangeForSending(
46-
BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest, BlockKey const& lastBlockKey, int32_t indexFromEnd)
45+
BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest,
46+
BlockKey const& lastBlockKey, int32_t indexFromEnd, bool recvSideHasCP)
4747
{
4848
auto poolNum = cacheManager->getBlockManager().getNumPools();
49-
if (poolNum > 1 || !cacheManager->isEnableBlockReuse() || lastBlockKey.uniqueTokens.size() == 0)
49+
// Note: When recv side has CP, the requested seqLen is lesser than seqLen on the sender side as seqLen is
50+
// distributed among CP ranks. So, we transfer all blocks from send side.
51+
if (poolNum > 1 || !cacheManager->isEnableBlockReuse() || lastBlockKey.uniqueTokens.size() == 0 || recvSideHasCP)
5052
{
5153
// disable reuse path, and vwsa don't support reuse.
5254
bool needSendAllForWindow = common::getEnvKVCacheTransferAllBlocksForWindow();
5355

5456
auto blockRange = BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId);
55-
// auto inputLen = llmRequest.getPromptLen();
5657

5758
auto const& windowsMetadata = cacheManager->getBlockManager().getWindowSizesMetadata();
5859

59-
if ((windowsMetadata.size() == 1 || needSendAllForWindow))
60+
if (windowsMetadata.size() == 1 || needSendAllForWindow || recvSideHasCP)
6061
{
6162
return blockRange;
6263
}
@@ -85,10 +86,11 @@ BlockRange getBlockRangeForSending(
8586
}
8687

8788
BlockRange getBlockRangeForReceiving(
88-
BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest, bool srcEnableBlockReuse)
89+
BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest, bool srcEnableBlockReuse, bool recvSideHasCP)
8990
{
9091
auto poolNum = cacheManager->getBlockManager().getNumPools();
91-
if (poolNum == 1 && srcEnableBlockReuse)
92+
// Note: When recv side has CP, we request all blocks from send side right now.
93+
if (poolNum == 1 && srcEnableBlockReuse && !recvSideHasCP)
9294
{
9395
// Build from all block ids, then slice off the reused blocks so we only transfer newly allocated ones.
9496
auto windowSize = cacheManager->getBlockManager().getWindowSizesMetadata().begin()->first;
@@ -121,9 +123,8 @@ BlockRange getBlockRangeForReceiving(
121123
}
122124

123125
auto const& windowsMetadata = cacheManager->getBlockManager().getWindowSizesMetadata();
124-
if (windowsMetadata.size() == 1 || common::getEnvKVCacheTransferAllBlocksForWindow())
126+
if (windowsMetadata.size() == 1 || common::getEnvKVCacheTransferAllBlocksForWindow() || recvSideHasCP)
125127
{
126-
127128
return BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId);
128129
}
129130
auto blockRange = BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId);

cpp/tensorrt_llm/batch_manager/cacheFormatter.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class TransferSession;
4343
namespace tensorrt_llm::batch_manager::kv_cache_manager
4444
{
4545
BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest,
46-
BlockKey const& lastBlockKey, SizeType32 indexFromEnd);
46+
BlockKey const& lastBlockKey, SizeType32 indexFromEnd, bool recvSideHasCP = false);
4747

4848
using DataContext = tensorrt_llm::executor::kv_cache::DataContext;
4949
using Connection = tensorrt_llm::executor::kv_cache::Connection;
@@ -52,8 +52,8 @@ using BaseKVCacheManager = kv_cache_manager::BaseKVCacheManager;
5252
using CacheTransBufferManager = kv_cache_manager::CacheTransBufferManager;
5353
using BlockRange = kv_cache_manager::BlockRange;
5454

55-
BlockRange getBlockRangeForReceiving(
56-
BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest, bool srcEnableBlockReuse);
55+
BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest,
56+
bool srcEnableBlockReuse, bool recvSideHasCP = false);
5757

5858
// Used to support the cache transmission with different layouts and different protocols.
5959
class BaseCacheFormatter

cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp

Lines changed: 6 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -37,31 +37,6 @@
3737
namespace tensorrt_llm::batch_manager::kv_cache_manager
3838
{
3939

40-
int getBlockNumAccountingForCP(int cpRank, int cpSize, int numTotalBlocks, bool strict)
41-
{
42-
TLLM_CHECK(cpRank >= 0 && cpRank < cpSize);
43-
if (cpSize == 1)
44-
{
45-
return numTotalBlocks;
46-
}
47-
// NOTE: Non-strict mode may over-allocate blocks when numTotalBlocks is not divisible by cpSize.
48-
// This is a known limitation and will be addressed in a future MR.
49-
if (!strict)
50-
{
51-
// Simple ceiling division.
52-
return (numTotalBlocks + cpSize - 1) / cpSize;
53-
}
54-
// In strict mode, blocks are distributed among CP ranks in a round-robin fashion as evenly as possible.
55-
// When the number of blocks is not divisible by cpSize, the remainder shall be distributed evenly among
56-
// lowest-indexed CP ranks (let's call them overflow ranks).
57-
int numBlocksCurrRank = numTotalBlocks / cpSize;
58-
if (numTotalBlocks % cpSize > cpRank)
59-
{
60-
numBlocksCurrRank++;
61-
}
62-
return numBlocksCurrRank;
63-
}
64-
6540
// some context rank in connection
6641
std::vector<size_t> MLACacheFormatter::pickRecvConnections(
6742
size_t numConnections, CacheState const& selfConfig, SizeType32 selfIdx, CacheState const& destConfig) const
@@ -145,7 +120,8 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& ses
145120
int blockNum = 0;
146121
std::vector<runtime::ITensor::SharedPtr> inputKvCacheBlocks;
147122
auto const numPools = mCacheManager->getBlockManager().getNumPools();
148-
auto blockRange = getBlockRangeForSending(mCacheManager, llmRequest, lastBlockKey, indexFromEnd);
123+
bool const recvSideHasCP = destConfig.getParallelConfig().mContextParallelism > 1;
124+
auto blockRange = getBlockRangeForSending(mCacheManager, llmRequest, lastBlockKey, indexFromEnd, recvSideHasCP);
149125
auto const& windowSizes = blockRange.getWindowSizes();
150126
TLLM_CHECK_WITH_INFO(
151127
static_cast<int>(windowSizes.size()) == numPools, "window sizes should be the same as numPools");
@@ -204,7 +180,7 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& ses
204180
auto const idx = cpDomainIdx * pPDomainSize + ppDomainIdx;
205181
// Note: contextCP is always 1. So, cpDomainSize == genCPSize and cpDomainIdx == genCPRank.
206182
auto const peerBlockNum
207-
= getBlockNumAccountingForCP(cpDomainIdx, cPDomainSize, blockNum, /*strict=*/false);
183+
= executor::kv_cache::getBlockNumAccountingForCP(cpDomainIdx, cPDomainSize, blockNum);
208184
bufferSizeForTarget[idx] = blockSizePerLayer * peerAttentionLayerNum * peerBlockNum;
209185
}
210186
}
@@ -346,7 +322,9 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& s
346322
auto const& connections = session.getConnections();
347323
auto& bufferManager = session.getBufferManager();
348324
auto pickUpConnections = pickRecvConnections(connections.size(), selfConfig, selfIdx, destConfig);
349-
auto blockRange = getBlockRangeForReceiving(mCacheManager, llmRequest, destConfig.getEnableBlockReuse());
325+
bool const recvSideHasCP = selfConfig.getParallelConfig().mContextParallelism > 1;
326+
auto blockRange
327+
= getBlockRangeForReceiving(mCacheManager, llmRequest, destConfig.getEnableBlockReuse(), recvSideHasCP);
350328
std::vector<runtime::ITensor::SharedPtr> recvBufferTmps;
351329
std::vector<runtime::ITensor::SharedPtr> outputBuffers;
352330
auto const numPools = mCacheManager->getBlockManager().getNumPools();

cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.h

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,24 +22,6 @@
2222
namespace tensorrt_llm::batch_manager::kv_cache_manager
2323
{
2424

25-
/**
26-
* @brief Calculate the number of blocks allocated to a specific Context Parallelism (CP) rank.
27-
*
28-
* This function determines how many blocks should be allocated to a given CP rank when
29-
* distributing a total number of blocks across multiple CP ranks. It supports two distribution
30-
* modes: strict and non-strict.
31-
*
32-
* @param cpRank The rank (index) of the current CP process. Must be in range [0, cpSize).
33-
* @param cpSize The total number of CP ranks/processes in the parallel group.
34-
* @param numTotalBlocks The total number of blocks to be distributed across all CP ranks.
35-
* @param strict Flag controlling the distribution strategy:
36-
* - true: Use strict round-robin distribution with exact allocation
37-
* - false: Use ceiling division which may over-allocate
38-
*
39-
* @return The number of blocks allocated to the specified CP rank.
40-
*/
41-
int getBlockNumAccountingForCP(int cpRank, int cpSize, int numTotalBlocks, bool strict);
42-
4325
// Simple cache block copy. Because it does not involve data splitting or merging, it performs best when the
4426
// parallel topology is completely identical, making it the preferred method.
4527
class MLACacheFormatter final : public BaseCacheFormatter

cpp/tensorrt_llm/common/envUtils.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,12 @@ bool getEnvUseNixlKvCache()
278278
return useNixlKvCache;
279279
}
280280

281+
bool getEnvUseRoundRobinBlockDistForCP()
282+
{
283+
static bool const useRoundRobinBlockDistForCP = getBoolEnv("TRTLLM_USE_ROUND_ROBIN_BLOCK_DIST_FOR_CP");
284+
return useRoundRobinBlockDistForCP;
285+
}
286+
281287
std::string getEnvUCXInterface()
282288
{
283289
static std::once_flag flag;

cpp/tensorrt_llm/common/envUtils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ bool getEnvUseUCXKvCache();
8282
bool getEnvUseMPIKvCache();
8383
bool getEnvUseNixlKvCache();
8484

85+
bool getEnvUseRoundRobinBlockDistForCP();
86+
8587
std::string getEnvUCXInterface();
8688

8789
std::string getEnvNixlInterface();

0 commit comments

Comments
 (0)