
Commit 882c67a

[TRTLLM-10264][feat] Support attention DP + Helix CP
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
1 parent 4a1b2e2 commit 882c67a

9 files changed: +194 −52 lines changed

cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp

Lines changed: 7 additions & 6 deletions
@@ -151,17 +151,18 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
     int TPSizeInDPGroup
         = mCacheState->getParallelConfig().mTensorParallelism / mCacheState->getParallelConfig().mDPsize;
     int DPSize = mCacheState->getParallelConfig().mDPsize;
-    int TPRankInDPGroup = worldConfig.getTensorParallelRank() % TPSizeInDPGroup;
 
-    int DPRank = (worldConfig.getRank() - TPSizeInDPGroup * DPSize * worldConfig.getPipelineParallelRank()
-            - TPRankInDPGroup)
-        / TPSizeInDPGroup;
-    // <PP,DP,TP>
+    // DPRank is derived from the tensor parallel rank, which already accounts for CP.
+    // Layout: rank = ppRank * (TP * CP) + tpRank * CP + cpRank.
+    // getTensorParallelRank() correctly extracts tpRank regardless of CP.
+    int DPRank = worldConfig.getTensorParallelRank() / TPSizeInDPGroup;
+    // <PP,DP,TP,CP>
     mGroupDataComm = std::make_shared<CacheTransceiverComm>(mGroupComm->split(DPRank, worldConfig.getRank()));
     if (worldConfig.isTensorParallel())
     {
+        // Group ranks with same (ppRank, DPRank) accounting for CP.
         mGroupTPInDPComm = std::make_shared<CacheTransceiverComm>(
-            mGroupComm->split(worldConfig.getRank() / TPSizeInDPGroup, worldConfig.getRank()));
+            mGroupComm->split(worldConfig.getPipelineParallelRank() * DPSize + DPRank, worldConfig.getRank()));
     }
 }
 bool isMLA = attentionType == executor::kv_cache::CacheState::AttentionType::kMLA;
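For reference, a minimal sketch of the <PP, DP, TP, CP> layout described in the new comments (plain Python, not TensorRT-LLM code; the helper, the example sizes, and the DPSize=2 baked into the split color are made up for this sketch):

# Sketch only: rank = ppRank * (TP * CP) + tpRank * CP + cpRank,
# with TP split into DPSize groups of TPSizeInDPGroup ranks each.
def decompose(rank, tp, cp, tp_size_in_dp_group):
    pp_rank = rank // (tp * cp)
    tp_rank = (rank % (tp * cp)) // cp        # what getTensorParallelRank() returns
    cp_rank = rank % cp
    dp_rank = tp_rank // tp_size_in_dp_group  # CP drops out, as the comment notes
    return pp_rank, dp_rank, tp_rank, cp_rank

# Example: TP=2 (DPSize=2, TPSizeInDPGroup=1), CP=2, PP=1 -> 4 ranks.
for rank in range(4):
    pp, dp, tp_r, cp_r = decompose(rank, tp=2, cp=2, tp_size_in_dp_group=1)
    data_color = dp               # color used for the mGroupDataComm split
    tp_in_dp_color = pp * 2 + dp  # ppRank * DPSize + DPRank, as in the new split
    print(rank, (pp, dp, tp_r, cp_r), data_color, tp_in_dp_color)

With this layout, ranks {0, 1} and {2, 3} fall into the same DP group, so the new split keeps a DP group's CP ranks together instead of grouping by rank / TPSizeInDPGroup.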

cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp

Lines changed: 92 additions & 24 deletions
@@ -552,9 +552,10 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam<AsymmetricTestPara
             mCpSize = genCp;
         }
 
-        mTpRank = mRankInInstance % mTpSize;
+        // Rank formula must match targetIRanks: ppRank * (tpNum * cpNum) + tpRank * cpNum + cpRank.
+        mCpRank = mRankInInstance % mCpSize;
+        mTpRank = (mRankInInstance % (mTpSize * mCpSize)) / mCpSize;
         mPpRank = mRankInInstance / (mTpSize * mCpSize);
-        mCpRank = (mRankInInstance % (mTpSize * mCpSize)) / mTpSize;
         mContextRankSize = contextRanks;
         mGenRankSize = genRanks;
         mContextTpSize = contextTp;
@@ -887,7 +888,16 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam<AsymmetricTestPara
     auto makeLlmRequestWithDP(SizeType32 length, LlmRequest::RequestIdType requestId, int contextDpRank)
     {
         constexpr SizeType32 maxNewTokens{1};
-        texec::Request request{VecTokens(length), maxNewTokens};
+        auto const tokensPerBlock = mContextCacheState->getModelConfig().mTokensPerBlock;
+
+        std::optional<CPMetaData> cpMetaData;
+        int seqLen = length;
+        if (mCpSize > 1)
+        {
+            cpMetaData.emplace(length, tokensPerBlock, mCpRank, mCpSize);
+            seqLen = cpMetaData.value().mSeqLenOnThisCPRank;
+        }
+        texec::Request request{VecTokens(seqLen, seqLen), maxNewTokens};
 
         auto state = std::make_unique<texec::DataTransceiverState>();
         state->setCommState(texec::kv_cache::CommState{*mContextCommState});
@@ -905,7 +915,6 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam<AsymmetricTestPara
         request.setContextPhaseParams(std::move(stats));
         auto llmRequestPtr = std::make_unique<LlmRequest>(requestId, std::move(request));
 
-        std::optional<CPMetaData> cpMetaData;
         return std::make_unique<WrappedLlmRequest>(std::move(llmRequestPtr), cpMetaData);
     }
 

@@ -1428,6 +1437,27 @@ TEST_P(AsymmetricalCacheTestWithDP, TestCase)
     {
        GTEST_SKIP() << "Temporarily skipping cache transceiver tests with NIXL and MOONCAKE backend for CP.";
     }
+    // Filter request lengths based on CP requirements.
+    // Each request must have at least one block per CP rank to be valid for CP tests.
+    std::vector<int> lenList = {60, 30, 60, 10};
+    if (genCp > 1)
+    {
+        std::vector<int> updatedLenList;
+        for (auto len : lenList)
+        {
+            if (len > tokensPerBlock * (genCp - 1))
+            {
+                updatedLenList.push_back(len);
+            }
+        }
+        if (updatedLenList.empty())
+        {
+            GTEST_SKIP() << "Skipping test because not even one request has one block per genCP rank. tokensPerBlock="
+                         << tokensPerBlock << ", genCp=" << genCp;
+        }
+        lenList = updatedLenList;
+    }
+
     setUpCommunicator(contextTp, contextPp, contextCp, genTp, genPp, genCp, isMLA, contextDP, generationDP);
 
     if (mIsContext || mIsGeneration)
@@ -1438,7 +1468,7 @@ TEST_P(AsymmetricalCacheTestWithDP, TestCase)
         setUpCacheTransceiver();
         std::vector<std::shared_ptr<WrappedLlmRequest>> requests;
         int requestId = 0;
-        for (auto len : {60, 30, 60, 10})
+        for (auto len : lenList)
        {
             requests.emplace_back(makeLlmRequestWithDP(len, requestId, requestId % contextTp));
             requestId++;
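For reference, the "at least one block per CP rank" condition above can be sanity-checked with a small standalone sketch (the round-robin block-to-rank assignment here is only an assumption for illustration, not necessarily what CPMetaData implements):

def blocks_per_cp_rank(length, tokens_per_block, cp_size):
    # Number of KV-cache blocks a request of `length` tokens occupies.
    num_blocks = (length + tokens_per_block - 1) // tokens_per_block
    # Hypothetical round-robin assignment of blocks to CP ranks.
    return [len(range(r, num_blocks, cp_size)) for r in range(cp_size)]

# With tokensPerBlock=8 and genCp=2: len=10 -> 2 blocks, one per rank (kept by the
# filter since 10 > 8 * (2 - 1)); len=8 -> 1 block, so one rank would get nothing.
for length in (60, 30, 60, 10, 8):
    print(length, blocks_per_cp_rank(length, tokens_per_block=8, cp_size=2))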
@@ -1814,6 +1844,44 @@ INSTANTIATE_TEST_CASE_P(AsymmetricCaseTest1WithCPForMLA, AsymmetricalCacheTest,
     /*generationDP*/ testing::Values(false),
     /*isWindow*/ testing::Values(false), testing::Values(false), testing::Values(0), testing::Values(128)));
 
+// Tests cases where there's non-trivial TP and PP on context side while non-trivial CP & DP on gen side.
+INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithCPAndDPForMLA0, AsymmetricalCacheTestWithDP,
+    testing::Combine(/*contextTp*/ testing::Values(1, 2),
+        /*contextPp*/ testing::Values(1, 2),
+        /*contextCp*/ testing::Values(1),
+        /*genTp*/ testing::Values(2),
+        /*genPp*/ testing::Values(1),
+        /*genCp*/ testing::Values(2),
+        /*numLayers*/ testing::Values(4),
+        /*numHeads*/ testing::Values(1),
+        /*sizePerHead*/ testing::Values(4),
+        /*tokensPerBlock*/ testing::Values(8),
+        /*dataType*/ testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8),
+        /*kvFactor*/ testing::Values(1),
+        /*isMLA*/ testing::Values(true),
+        /*contextDP*/ testing::Values(false),
+        /*generationDP*/ testing::Values(true),
+        /*isWindow*/ testing::Values(false), testing::Values(false), testing::Values(0), testing::Values(128)));
+
+// Tests cases where there's non-trivial DP on context side while non-trivial CP & DP on gen side.
+INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithCPAndDPForMLA1, AsymmetricalCacheTestWithDP,
+    testing::Combine(/*contextTp*/ testing::Values(2, 4),
+        /*contextPp*/ testing::Values(1),
+        /*contextCp*/ testing::Values(1),
+        /*genTp*/ testing::Values(2),
+        /*genPp*/ testing::Values(1),
+        /*genCp*/ testing::Values(2),
+        /*numLayers*/ testing::Values(4),
+        /*numHeads*/ testing::Values(1),
+        /*sizePerHead*/ testing::Values(4),
+        /*tokensPerBlock*/ testing::Values(8),
+        /*dataType*/ testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8),
+        /*kvFactor*/ testing::Values(1),
+        /*isMLA*/ testing::Values(true),
+        /*contextDP*/ testing::Values(true),
+        /*generationDP*/ testing::Values(true),
+        /*isWindow*/ testing::Values(false), testing::Values(false), testing::Values(0), testing::Values(128)));
+
 INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForMLA1, AsymmetricalCacheTestWithDP,
     testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(1, 2),
         testing::Values(1, 2), testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4),
@@ -2226,8 +2294,8 @@ TEST(targetTest, CacheStateContextDP)
     auto const verifyContext = [&](int contextRank, int generationRank, std::vector<int> const& expectRanks,
                                    int expectPPDomain, int expectTPDomain, bool expectNeedSend)
     {
-        int contextDPRank = contextRank % contextTP;
-        int generationDPRank = generationRank % genTP;
+        int contextDPRank = (contextRank % (contextTP * contextCP)) / contextCP;
+        int generationDPRank = (generationRank % (genTP * genCP)) / genCP;
         auto attentionType = isMLA ? texec::kv_cache::CacheState::AttentionType::kMLA
                                    : texec::kv_cache::CacheState::AttentionType::kDEFAULT;
 
@@ -2239,12 +2307,12 @@ TEST(targetTest, CacheStateContextDP)
             tokensPerBlock, genTP, genPP, genCP, genAttentionLayerNumPerPP, dataType, attentionType, kvFactor,
             genEnableDP, generationDPRank, genTP};
 
-        auto const contextTragetInfo
+        auto const contextTargetInfo
             = tensorrt_llm::executor::kv_cache::TargetRanksInfoForDP(genCache, contextCache, contextRank);
 
-        EXPECT_EQ(expectRanks, contextTragetInfo.mIRanks);
-        EXPECT_EQ(expectPPDomain, contextTragetInfo.mDomainPPSize);
-        EXPECT_EQ(expectTPDomain, contextTragetInfo.mDomainTPSize);
+        EXPECT_EQ(expectRanks, contextTargetInfo.mIRanks);
+        EXPECT_EQ(expectPPDomain, contextTargetInfo.mDomainPPSize);
+        EXPECT_EQ(expectTPDomain, contextTargetInfo.mDomainTPSize);
         EXPECT_EQ(expectNeedSend, MLACacheFormatter::needSendCache(contextCache, genCache, contextRank));
     };
 
@@ -2330,11 +2398,11 @@ TEST(targetTest, CacheStateContextDP)
     contextTP = 1;
     genTP = 2;
 
-    auto const verfiyGeneration = [&](int contextRank, int generationRank, std::vector<int> const& expectRanks,
+    auto const verifyGeneration = [&](int contextRank, int generationRank, std::vector<int> const& expectRanks,
                                       int expectPPDomain, int expectTPDomain)
     {
-        int contextDPRank = contextRank % contextTP;
-        int generationDPRank = generationRank % genTP;
+        int contextDPRank = (contextRank % (contextTP * contextCP)) / contextCP;
+        int generationDPRank = (generationRank % (genTP * genCP)) / genCP;
         auto attentionType = isMLA ? texec::kv_cache::CacheState::AttentionType::kMLA
                                    : texec::kv_cache::CacheState::AttentionType::kDEFAULT;
 
@@ -2346,17 +2414,17 @@ TEST(targetTest, CacheStateContextDP)
             tokensPerBlock, genTP, genPP, genCP, genAttentionLayerNumPerPP, dataType, attentionType, kvFactor,
             genEnableDP, generationDPRank, genTP};
 
-        auto const contextTragetInfo
+        auto const contextTargetInfo
             = tensorrt_llm::executor::kv_cache::TargetRanksInfoForDP(contextCache, genCache, generationRank);
 
-        EXPECT_EQ(expectRanks, contextTragetInfo.mIRanks);
-        EXPECT_EQ(expectPPDomain, contextTragetInfo.mDomainPPSize);
-        EXPECT_EQ(expectTPDomain, contextTragetInfo.mDomainTPSize);
+        EXPECT_EQ(expectRanks, contextTargetInfo.mIRanks);
+        EXPECT_EQ(expectPPDomain, contextTargetInfo.mDomainPPSize);
+        EXPECT_EQ(expectTPDomain, contextTargetInfo.mDomainTPSize);
     };
 
-    verfiyGeneration(
+    verifyGeneration(
         /*contextRank*/ 0, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1);
-    verfiyGeneration(
+    verifyGeneration(
         /*contextRank*/ 0, /*generationRank*/ 1, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1);
 
     contextTP = 1;
@@ -2366,9 +2434,9 @@ TEST(targetTest, CacheStateContextDP)
     contextAttentionLayerNumPerPP = std::vector<SizeType32>(contextPP, numLayers / contextPP);
     genAttentionLayerNumPerPP = std::vector<SizeType32>(genPP, numLayers / genPP);
 
-    verfiyGeneration(
+    verifyGeneration(
         /*contextRank*/ 0, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1);
-    verfiyGeneration(
+    verifyGeneration(
         /*contextRank*/ 0, /*generationRank*/ 1, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1);
 
     genEnableDP = false;
@@ -2381,8 +2449,8 @@ TEST(targetTest, CacheStateContextDP)
     contextAttentionLayerNumPerPP = std::vector<SizeType32>(contextPP, numLayers / contextPP);
     genAttentionLayerNumPerPP = std::vector<SizeType32>(genPP, numLayers / genPP);
 
-    verfiyGeneration(
+    verifyGeneration(
         /*contextRank*/ 0, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1);
-    verfiyGeneration(
+    verifyGeneration(
         /*contextRank*/ 1, /*generationRank*/ 0, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1);
 }

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 13 additions & 3 deletions
@@ -1120,13 +1120,18 @@ def __init__(self,
                 reduce_output=not self.enable_attention_dp
                 and self.mapping.tp_size > 1)
         else:
+            # When enable_attention_dp is True, TP reduction is skipped since each DP rank
+            # works on different batch elements. However, with CP > 1, attention is split
+            # across CP ranks for the SAME batch element, so reduction is still needed
+            # within the CP group.
+            needs_tp_reduce = not self.enable_attention_dp and self.mapping.tp_size > 1
+            needs_cp_reduce = mapping_with_cp is not None and mapping_with_cp.has_cp_helix()
             self.self_attn = DeepseekV3Attention(
                 model_config,
                 layer_idx=layer_idx_for_attention,
                 aux_stream=aux_stream_dict[AuxStreamType.Attention],
                 mapping_with_cp=mapping_with_cp,
-                reduce_output=not self.enable_attention_dp
-                and self.mapping.tp_size > 1)
+                reduce_output=needs_tp_reduce or needs_cp_reduce)
 
         self.fusion_config = EagerFusionConfig()
         self.enable_fusion = os.environ.get(
@@ -1192,10 +1197,15 @@ def __init__(self,
                                          eps=config.rms_norm_eps,
                                          dtype=config.torch_dtype)
 
+        # When enable_attention_dp is True, we normally skip attention all-reduce since each
+        # DP rank works on different batch elements. However, with CP > 1, attention is split
+        # across CP ranks for the SAME batch element, so all-reduce is still needed.
+        has_cp = mapping_with_cp is not None and mapping_with_cp.cp_size > 1
+        can_skip_for_attention_dp = self.enable_attention_dp and not has_cp
         self.disable_attn_allreduce = (self.fusion_config.PRE_MOE_FUSION
                                        or self.fusion_config.PRE_MLP_FUSION
                                        or self.mapping.tp_size == 1
-                                       or self.enable_attention_dp)
+                                       or can_skip_for_attention_dp)
 
         self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                                 eps=config.rms_norm_eps,
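A condensed sketch of the decision the two hunks above encode (standalone Python mirroring the diff's flag names; `cp_size > 1` stands in for `mapping_with_cp.has_cp_helix()`):

def attention_reduce_flags(enable_attention_dp, tp_size, cp_size):
    # DP alone lets each rank keep its own batch elements, so no reduction is needed,
    # but Helix CP splits one sequence across ranks, so partial outputs must be reduced.
    needs_tp_reduce = (not enable_attention_dp) and tp_size > 1
    needs_cp_reduce = cp_size > 1
    reduce_output = needs_tp_reduce or needs_cp_reduce
    can_skip_attn_allreduce = enable_attention_dp and cp_size == 1
    return reduce_output, can_skip_attn_allreduce

print(attention_reduce_flags(enable_attention_dp=True, tp_size=1, cp_size=2))  # (True, False)
print(attention_reduce_flags(enable_attention_dp=True, tp_size=1, cp_size=1))  # (False, True)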

tensorrt_llm/_torch/modules/attention.py

Lines changed: 36 additions & 9 deletions
@@ -932,15 +932,42 @@ def __init__(
             requires_grad=False,
         )
 
-        mapping_o = Mapping(
-            world_size=tp_size * pp_size * cp_size,
-            tp_size=tp_size * cp_size,
-            pp_size=pp_size,
-            cp_size=1,
-            rank=self.mapping.rank,
-            gpus_per_node=self.mapping.gpus_per_node,
-            enable_attention_dp=self.mapping.enable_attention_dp,
-        )
+        # For o_proj, we fold CP into TP so all CP ranks do all-reduce together.
+        #
+        # When enable_attention_dp=True:
+        # - tp_size is forced to 1 (each DP group processes independently).
+        # - Each DP group has its own independent all-reduce among CP ranks.
+        # - We use pp_size to represent DP groups (original tp_size), so that
+        #   mapping_o.tp_group contains the correct actual world ranks for each DP group.
+        #   Example with original tp=2, cp=2:
+        #   - DP group 0 (ranks 0, 1): tp_group = [0, 1]
+        #   - DP group 1 (ranks 2, 3): tp_group = [2, 3]
+        #
+        # When enable_attention_dp=False:
+        # - All TP*CP ranks participate in a single all-reduce.
+        if self.mapping.enable_attention_dp and cp_size > 1:
+            # Get original TP size (before it was forced to 1 for attention DP).
+            original_tp_size = self.mapping.tp_size
+            mapping_o = Mapping(
+                world_size=original_tp_size * pp_size * cp_size,
+                tp_size=cp_size,  # Only CP ranks all-reduce together
+                pp_size=pp_size *
+                original_tp_size,  # DP groups as separate PP groups
+                cp_size=1,
+                rank=self.mapping.rank,  # Use actual world rank
+                gpus_per_node=self.mapping.gpus_per_node,
+                enable_attention_dp=self.mapping.enable_attention_dp,
+            )
+        else:
+            mapping_o = Mapping(
+                world_size=tp_size * pp_size * cp_size,
+                tp_size=tp_size * cp_size,
+                pp_size=pp_size,
+                cp_size=1,
+                rank=self.mapping.rank,
+                gpus_per_node=self.mapping.gpus_per_node,
+                enable_attention_dp=self.mapping.enable_attention_dp,
+            )
         self.o_proj = Linear(
             self.num_key_value_heads * self.v_head_dim,
             self.hidden_size,
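To make the regrouping concrete, here is a small standalone sketch (plain Python, not the Mapping class) of which world ranks end up in the same o_proj all-reduce group under the two branches, assuming the same rank layout as above:

def o_proj_allreduce_group(rank, orig_tp, cp, enable_attention_dp):
    # World ranks that all-reduce together for o_proj.
    pp_rank = rank // (orig_tp * cp)
    if enable_attention_dp and cp > 1:
        # Folded mapping: tp_size=cp and pp_size=pp*orig_tp, so only this DP
        # group's CP ranks reduce together.
        dp_rank = (rank % (orig_tp * cp)) // cp
        base = pp_rank * orig_tp * cp + dp_rank * cp
        return list(range(base, base + cp))
    # No attention DP: all TP*CP ranks of this pipeline stage reduce together.
    base = pp_rank * orig_tp * cp
    return list(range(base, base + orig_tp * cp))

# Original tp=2, cp=2, pp=1 (the example in the comment above):
for r in range(4):
    print(r, o_proj_allreduce_group(r, orig_tp=2, cp=2, enable_attention_dp=True))
# ranks 0 and 1 reduce over [0, 1]; ranks 2 and 3 reduce over [2, 3]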

tensorrt_llm/_torch/pyexecutor/executor_request_queue.py

Lines changed: 18 additions & 1 deletion
@@ -359,13 +359,30 @@ def _fetch_new_requests_attention_tp(
     def _fetch_new_requests_attention_dp(
             self, activate_requests: List[LlmRequest]) -> List[LlmRequest]:
         """Handle attention DP request fetching with load balancing."""
-        # Get active request counts across all ranks
+        # Get active request counts across all ranks.
         all_ranks_num_active_requests = []
         all_ranks_num_active_tokens = []
         num_active_tokens = sum(
             [req.py_orig_prompt_len for req in activate_requests])
         responses_list = self.dist.tp_allgather(
             [len(activate_requests), num_active_tokens])
+
+        if self.dist.has_cp_helix:
+            # When CP is enabled with Helix parallelism, tp_allgather returns one entry per rank,
+            # but CP ranks within the same DP group (same tp_rank) handle the same requests with
+            # different token portions (sequence is split across CP ranks).
+            aggregated_responses = []
+            for dp_group_idx in range(self.dist.tp_size):
+                # Get all entries for this DP group (cp_size entries per group)
+                group_start = dp_group_idx * self.dist.cp_size
+                group_entries = responses_list[group_start:group_start +
+                                               self.dist.cp_size]
+
+                # Sum the token counts across CP ranks (sequence is split)
+                total_tokens = sum(entry[1] for entry in group_entries)
+                aggregated_responses.append([group_entries[0][0], total_tokens])
+            responses_list = aggregated_responses
+
         for num_active_requests, num_active_tokens in responses_list:
             all_ranks_num_active_requests.append(num_active_requests)
             all_ranks_num_active_tokens.append(num_active_tokens)
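The aggregation added above can be exercised in isolation; a minimal sketch with plain lists standing in for the tp_allgather result, assuming the same group-major rank ordering the loop relies on:

def aggregate_helix_cp(responses, tp_size, cp_size):
    # Collapse per-rank [num_requests, num_tokens] entries so each DP group appears
    # once: request counts are identical across its CP ranks, token counts are summed.
    aggregated = []
    for dp_group_idx in range(tp_size):
        group = responses[dp_group_idx * cp_size:(dp_group_idx + 1) * cp_size]
        aggregated.append([group[0][0], sum(entry[1] for entry in group)])
    return aggregated

# Two DP groups (tp_size=2) with cp_size=2: ranks 0-1 serve one group, ranks 2-3 the other.
responses = [[3, 100], [3, 96], [2, 40], [2, 44]]
print(aggregate_helix_cp(responses, tp_size=2, cp_size=2))  # [[3, 196], [2, 84]]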