
Commit 3bf41e7

[TRTLLM-9465][fix] Swap TP-CP grouping order
Signed-off-by: Balaram Buddharaju <[email protected]>
1 parent 4a1b742 commit 3bf41e7

10 files changed: +76 / -58 lines changed

cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp

Lines changed: 2 additions & 1 deletion

```diff
@@ -154,7 +154,8 @@ bool CacheFormatter::needSendCache(
         return true;
     }
 
-    int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism;
+    int selfCpSize = selfConfig.getParallelConfig().mContextParallelism;
+    int selfTpRank = (selfIdx % (selfConfig.getParallelConfig().mTensorParallelism * selfCpSize)) / selfCpSize;
     int selfTpRankInDpGroup = selfTpRank;
     if (selfConfig.getParallelConfig().mEnableAttentionDP)
     {
```
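
For intuition, a minimal Python sketch (not part of the commit) of the new decomposition: with cp now the innermost, fastest-varying dimension, the TP rank is recovered by stripping the PP component and dividing out `cp_size`, rather than taking `rank % tp_size`:

```python
def tp_rank_from_flat(rank: int, tp_size: int, cp_size: int) -> int:
    # Layout: rank = pp * (tp_size * cp_size) + tp * cp_size + cp,
    # so strip the pp component, then divide out the innermost cp dimension.
    return (rank % (tp_size * cp_size)) // cp_size

# tp_size=2, cp_size=2: cp peers {0,1} and {2,3} map to tp ranks 0 and 1,
# and the pattern repeats for the next pp stage.
assert [tp_rank_from_flat(r, 2, 2) for r in range(8)] == [0, 0, 1, 1, 0, 0, 1, 1]
```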

cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp

Lines changed: 2 additions & 1 deletion

```diff
@@ -60,7 +60,8 @@ std::vector<size_t> MLACacheFormatter::pickRecvConnections(
 bool MLACacheFormatter::needSendCache(
     CacheState const& selfConfig, CacheState const& destConfig, runtime::SizeType32 selfIdx)
 {
-    int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism;
+    int selfCpSize = selfConfig.getParallelConfig().mContextParallelism;
+    int selfTpRank = (selfIdx % (selfConfig.getParallelConfig().mTensorParallelism * selfCpSize)) / selfCpSize;
 
     int destTPNumInDPGroup = destConfig.getParallelConfig().mEnableAttentionDP
         ? destConfig.getParallelConfig().mTensorParallelism / destConfig.getParallelConfig().mDPsize
```

cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.cu

Lines changed: 5 additions & 3 deletions

```diff
@@ -107,9 +107,9 @@ TargetRanksInfo TargetRanksInfoForDP(
     auto const peerCPNum = peerParConfig.mContextParallelism;
     auto const selfCPNum = selfParConfig.mContextParallelism;
 
-    auto const selfTPRank = selfRank % selfTPNum;
+    auto const selfCPRank = selfRank % selfCPNum;
+    auto const selfTPRank = (selfRank % (selfTPNum * selfCPNum)) / selfCPNum;
     auto const selfPPRank = selfRank / (selfTPNum * selfCPNum);
-    auto const selfCPRank = (selfRank % (selfTPNum * selfCPNum)) / selfTPNum;
 
     int peerPPRankStart = 0;
     int mDomainPPSize = 1;
@@ -211,7 +211,9 @@ TargetRanksInfo TargetRanksInfoForDP(
         {
             for (int k = peerPPRankStart; k < peerPPRankEnd; k++)
             {
-                int irank = (k * peerTPNum * peerCPNum) + (j * peerTPNum) + i;
+                // Rank formula: ppRank * (tpNum * cpNum) + tpRank * cpNum + cpRank
+                // where i=tpRank, j=cpRank, k=ppRank
+                int irank = (k * peerTPNum * peerCPNum) + (i * peerCPNum) + j;
                 retRanks.push_back(irank);
             }
         }
```
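
As a sanity check (a sketch, not from the commit), the rank formula and the per-dimension decomposition above are inverses of each other under the new layout:

```python
def flat_rank(pp: int, tp: int, cp: int, tp_num: int, cp_num: int) -> int:
    # Rank formula from the diff: ppRank * (tpNum * cpNum) + tpRank * cpNum + cpRank
    return pp * (tp_num * cp_num) + tp * cp_num + cp

def decompose(rank: int, tp_num: int, cp_num: int):
    # Inverse: cp is innermost, pp is outermost
    cp = rank % cp_num
    tp = (rank % (tp_num * cp_num)) // cp_num
    pp = rank // (tp_num * cp_num)
    return pp, tp, cp

# Round-trip over a pp=2, tp=2, cp=2 layout (8 ranks)
for r in range(8):
    assert flat_rank(*decompose(r, 2, 2), 2, 2) == r
```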

tensorrt_llm/_torch/device_mesh.py

Lines changed: 6 additions & 4 deletions

```diff
@@ -118,15 +118,17 @@ def build_mesh(self):
                 "DeviceMesh creation requested but torch.distributed process group "
                 "has not been initialised.")
 
-        dims = ["cp", "pp"]
-        shape = [self.cp_size, self.pp_size]
+        # Dimensions go from slowest-varying (outermost) to fastest-varying (innermost)
+        # Layout: pp is outermost, then tp, then cp is innermost (consecutive)
+        dims = ["pp", "tp"]
+        shape = [self.pp_size, self.tp_size]
 
         if self.moe_ep_size > 1:
             dims += ["moe_tp", "moe_ep"]
             shape += [self.moe_tp_size, self.moe_ep_size]
         else:
-            dims += ["tp"]
-            shape += [self.tp_size]
+            dims += ["cp"]
+            shape += [self.cp_size]
 
         cls.device_mesh = init_device_mesh(
             "cuda",
```

tensorrt_llm/mapping.py

Lines changed: 34 additions & 36 deletions

```diff
@@ -292,18 +292,16 @@ def has_cp(self):
         return self.cp_size > 1
 
     def prev_cp_rank(self):
-        p = self.rank - self.tp_size
-        if p // (self.tp_size * self.cp_size) < self.rank // (self.tp_size *
-                                                              self.cp_size):
-            return p + self.tp_size * self.cp_size
-        return p
+        # cp ranks are consecutive, so prev is rank - 1 with wraparound within cp group
+        if self.cp_rank == 0:
+            return self.rank + self.cp_size - 1
+        return self.rank - 1
 
     def next_cp_rank(self):
-        p = self.rank + self.tp_size
-        if p // (self.tp_size * self.cp_size) > self.rank // (self.tp_size *
-                                                              self.cp_size):
-            return p - self.tp_size * self.cp_size
-        return p
+        # cp ranks are consecutive, so next is rank + 1 with wraparound within cp group
+        if self.cp_rank == self.cp_size - 1:
+            return self.rank - self.cp_size + 1
+        return self.rank + 1
 
     def has_moe_cluster(self):
         return self.moe_cluster_size > 1
@@ -378,17 +376,17 @@ class Mapping(MappingBase):
 
     A node with 8 GPUs, tp_size = 4, cp_size = 2, pp_size = 1
 
-    2 tp groups:
+    4 cp groups:
 
-    - [0, 1, 2, 3]
-    - [4, 5, 6, 7]
+    - [0, 1]
+    - [2, 3]
+    - [4, 5]
+    - [6, 7]
 
-    4 cp groups:
+    2 tp groups:
 
-    - [0, 4]
-    - [1, 5]
-    - [2, 6]
-    - [3, 7]
+    - [0, 2, 4, 6]
+    - [1, 3, 5, 7]
 
     A node with 8 GPUs, moe_tp_size = 2, moe_ep_size = 4
 
@@ -437,23 +435,23 @@ class Mapping(MappingBase):
 
     2 nodes with 8 GPUs, tp_size 2, pp_size 2, cp_size 2
 
-    4 tp groups:
+    4 cp groups:
     - [0, 1]
     - [2, 3]
     - [4, 5]
     - [6, 7]
 
+    4 tp groups:
+    - [0, 2]
+    - [1, 3]
+    - [4, 6]
+    - [5, 7]
+
     4 pp groups:
     - [0, 4]
     - [1, 5]
     - [2, 6]
     - [3, 7]
-
-    4 cp groups:
-    - [0, 2]
-    - [1, 3]
-    - [4, 6]
-    - [5, 7]
     """
 
     def __new__(cls, *args, **kwargs):
@@ -551,23 +549,23 @@ def __init__(self, *args, **kwargs):
 
     @property
     def tp_rank(self) -> int:
-        return self.rank % self.tp_size
+        return self.rank % (self.tp_size * self.cp_size) // self.cp_size
 
     @property
     def pp_rank(self) -> int:
         return self.rank // (self.tp_size * self.cp_size)
 
     @property
    def cp_rank(self) -> int:
-        return self.rank % (self.tp_size * self.cp_size) // self.tp_size
+        return self.rank % self.cp_size
 
     @property
     def tp_group(self) -> List[int]:
         return self.tp_groups[self.pp_rank * self.cp_size + self.cp_rank]
 
     @property
     def pp_group(self) -> List[int]:
-        return self.pp_groups[self.cp_rank * self.tp_size + self.tp_rank]
+        return self.pp_groups[self.tp_rank * self.cp_size + self.cp_rank]
 
     @property
     def cp_group(self) -> List[int]:
@@ -598,20 +596,20 @@ def _init_parallel_groups(self):
             ranks = range(i, self.world_size, self.tp_size * self.cp_size)
             self.pp_groups.append(list(ranks))
 
-        # init cp group
+        # init cp group (consecutive ranks within each tp slice)
         for i in range(self.pp_size):
             for j in range(self.tp_size):
-                ranks = range(i * self.tp_size * self.cp_size + j,
-                              (i + 1) * self.tp_size * self.cp_size + j,
-                              self.tp_size)
+                ranks = range(
+                    i * self.tp_size * self.cp_size + j * self.cp_size,
+                    i * self.tp_size * self.cp_size + (j + 1) * self.cp_size)
                 self.cp_groups.append(list(ranks))
 
-        # init tp group
+        # init tp group (interleaved ranks with stride of cp_size)
         for i in range(self.pp_size):
             for j in range(self.cp_size):
-                ranks = range(
-                    i * self.tp_size * self.cp_size + j * self.tp_size,
-                    i * self.tp_size * self.cp_size + (j + 1) * self.tp_size)
+                ranks = range(i * self.tp_size * self.cp_size + j,
+                              (i + 1) * self.tp_size * self.cp_size + j,
+                              self.cp_size)
                 self.tp_groups.append(list(ranks))
 
         # init moe tp group
```
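
The docstring examples follow directly from this construction. A standalone sketch (hypothetical `build_groups` helper, not the actual Mapping API) reproduces them for tp_size=2, pp_size=2, cp_size=2:

```python
def build_groups(tp_size: int, pp_size: int, cp_size: int):
    cp_groups, tp_groups = [], []
    for i in range(pp_size):
        for j in range(tp_size):  # cp groups: consecutive ranks
            base = i * tp_size * cp_size + j * cp_size
            cp_groups.append(list(range(base, base + cp_size)))
        for j in range(cp_size):  # tp groups: stride of cp_size
            start = i * tp_size * cp_size + j
            tp_groups.append(
                list(range(start, (i + 1) * tp_size * cp_size, cp_size)))
    return cp_groups, tp_groups

cp_groups, tp_groups = build_groups(2, 2, 2)
assert cp_groups == [[0, 1], [2, 3], [4, 5], [6, 7]]
assert tp_groups == [[0, 2], [1, 3], [4, 6], [5, 7]]

# Neighbor with wraparound inside the cp group (cp ranks are consecutive):
def next_cp_rank(rank: int, cp_size: int) -> int:
    return rank - cp_size + 1 if rank % cp_size == cp_size - 1 else rank + 1

assert next_cp_rank(3, 2) == 2  # rank 3 wraps back to 2 within cp group [2, 3]
```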

tensorrt_llm/models/modeling_utils.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -740,10 +740,11 @@ def from_checkpoint(
 
         rank = config.mapping.rank
         if config.mapping.cp_size > 1:
-            # tp_cp_pp rank -> tp_pp rank: because different cp ranks share the same ckpt
+            # cp_tp_pp rank -> tp_pp rank: because different cp ranks share the same ckpt
             tp_size = config.mapping.tp_size
             cp_size = config.mapping.cp_size
-            rank = rank % tp_size + rank // (tp_size * cp_size) * tp_size
+            rank = (rank % (tp_size * cp_size)) // cp_size + rank // (
+                tp_size * cp_size) * tp_size
         weights_path = os.path.join(ckpt_dir, f'rank{rank}.safetensors')
 
         assert os.path.isfile(weights_path)
```
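
A sketch of this remapping (hypothetical `ckpt_rank` helper, not from the commit): cp peers within a pp stage collapse onto the same checkpoint shard, since they load identical weights:

```python
def ckpt_rank(rank: int, tp_size: int, cp_size: int) -> int:
    # cp_tp_pp rank -> tp_pp rank: cp peers share the same checkpoint shard
    tp_rank = (rank % (tp_size * cp_size)) // cp_size
    pp_rank = rank // (tp_size * cp_size)
    return tp_rank + pp_rank * tp_size

# tp_size=2, cp_size=2: ranks 0,1 -> ckpt 0; 2,3 -> ckpt 1; 4,5 -> ckpt 2; 6,7 -> ckpt 3
assert [ckpt_rank(r, 2, 2) for r in range(8)] == [0, 0, 1, 1, 2, 2, 3, 3]
```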

tests/integration/defs/accuracy/test_disaggregated_serving.py

Lines changed: 5 additions & 2 deletions

```diff
@@ -872,8 +872,9 @@ def test_auto_dtype(self, overlap_scheduler, mtp_nextn):
         task.evaluate(llm)
 
     @pytest.mark.skip_less_device(8)
-    @pytest.mark.parametrize("gen_pp,gen_tp,gen_cp", [(1, 2, 2), (2, 1, 2)],
-                             ids=["pp1tp2cp2", "pp2tp1cp2"])
+    @pytest.mark.parametrize("gen_pp,gen_tp,gen_cp", [(1, 1, 4), (1, 2, 2),
+                                                      (2, 1, 2)],
+                             ids=["pp1tp1cp4", "pp1tp2cp2", "pp2tp1cp2"])
     @pytest.mark.parametrize("cuda_graph_config", [
         None,
         {
@@ -912,6 +913,7 @@ def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config,
                 "backend": "UCX",
                 "max_tokens_in_buffer": 8192,
             },
+            # "print_iter_log": True,
         }
         gen_server_config = {
             "tensor_parallel_size": gen_tp,
@@ -931,6 +933,7 @@ def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config,
                 "backend": "UCX",
                 "max_tokens_in_buffer": 8192,
             },
+            # "print_iter_log": True,
         }
         disaggregated_server_config = {
             "hostname": "localhost",
```

tests/integration/test_lists/qa/llm_function_core.txt

Lines changed: 6 additions & 0 deletions

```diff
@@ -540,6 +540,12 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=2]
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=0]
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none-pp1tp2cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding-pp1tp2cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp2cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none-pp1tp2cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding-pp1tp2cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp2cp2]
 accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]
 accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none-pp2tp1cp2]
```

tests/integration/test_lists/test-db/l0_dgx_b200.yml

Lines changed: 4 additions & 0 deletions

```diff
@@ -67,6 +67,8 @@ l0_dgx_b200:
       orchestrator: mpi
   tests:
   - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60)
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp1cp4] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_mtp] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_bs8_mtp] TIMEOUT (60)
@@ -94,6 +96,8 @@ l0_dgx_b200:
       orchestrator: mpi
   tests:
   - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60)
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp1cp4] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline_fp8kv] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[latency] TIMEOUT (60)
```

tests/unittest/others/test_mapping.py

Lines changed: 9 additions & 9 deletions

```diff
@@ -57,27 +57,27 @@ def test_mapping(self):
         self.assertEqual(len(m.tp_groups), 4)
         self.assertEqual(len(m.pp_groups), 4)
         self.assertEqual(len(m.cp_groups), 4)
-        self.assertEqual(m.tp_group, [2, 3])
+        self.assertEqual(m.tp_group, [1, 3])
         self.assertEqual(m.pp_group, [3, 7])
-        self.assertEqual(m.cp_group, [1, 3])
+        self.assertEqual(m.cp_group, [2, 3])
         self.assertTrue(m.is_first_pp_rank())
         self.assertFalse(m.is_last_pp_rank())
         self.assertFalse(m.is_first_cp_rank())
         self.assertTrue(m.is_last_cp_rank())
         self.assertEqual(m.prev_pp_rank(), 7)
         self.assertEqual(m.next_pp_rank(), 7)
-        self.assertEqual(m.prev_cp_rank(), 1)
-        self.assertEqual(m.next_cp_rank(), 1)
+        self.assertEqual(m.prev_cp_rank(), 2)
+        self.assertEqual(m.next_cp_rank(), 2)
 
         m = Mapping(world_size=16, rank=9, tp_size=2, pp_size=2, cp_size=4)
-        self.assertEqual(m.tp_group, [8, 9])
+        self.assertEqual(m.tp_group, [9, 13])
         self.assertEqual(m.pp_group, [1, 9])
-        self.assertEqual(m.cp_group, [9, 11, 13, 15])
+        self.assertEqual(m.cp_group, [8, 9, 10, 11])
         self.assertFalse(m.is_first_pp_rank())
         self.assertTrue(m.is_last_pp_rank())
-        self.assertTrue(m.is_first_cp_rank())
+        self.assertFalse(m.is_first_cp_rank())
         self.assertFalse(m.is_last_cp_rank())
         self.assertEqual(m.prev_pp_rank(), 1)
         self.assertEqual(m.next_pp_rank(), 1)
-        self.assertEqual(m.prev_cp_rank(), 15)
-        self.assertEqual(m.next_cp_rank(), 11)
+        self.assertEqual(m.prev_cp_rank(), 8)
+        self.assertEqual(m.next_cp_rank(), 10)
```
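
A quick cross-check of the updated expectations for rank 9 (plain Python mirroring the new formulas; not part of the test):

```python
rank, tp_size, pp_size, cp_size = 9, 2, 2, 4
cp_rank = rank % cp_size                         # 1
tp_rank = rank % (tp_size * cp_size) // cp_size  # 0
pp_rank = rank // (tp_size * cp_size)            # 1
cp_group = [pp_rank * tp_size * cp_size + tp_rank * cp_size + c
            for c in range(cp_size)]
tp_group = [pp_rank * tp_size * cp_size + t * cp_size + cp_rank
            for t in range(tp_size)]
assert cp_group == [8, 9, 10, 11] and tp_group == [9, 13]
```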
