
Commit 5fda479

save initial changes
1 parent 74832a1 commit 5fda479

4 files changed: +52 -51 lines changed


tensorrt_llm/_torch/device_mesh.py

Lines changed: 6 additions & 4 deletions
@@ -118,15 +118,17 @@ def build_mesh(self):
                 "DeviceMesh creation requested but torch.distributed process group "
                 "has not been initialised.")

-        dims = ["cp", "pp"]
-        shape = [self.cp_size, self.pp_size]
+        # Dimensions go from slowest-varying (outermost) to fastest-varying (innermost)
+        # Layout: pp is outermost, then tp, then cp is innermost (consecutive)
+        dims = ["pp", "tp"]
+        shape = [self.pp_size, self.tp_size]

         if self.moe_ep_size > 1:
             dims += ["moe_tp", "moe_ep"]
             shape += [self.moe_tp_size, self.moe_ep_size]
         else:
-            dims += ["tp"]
-            shape += [self.tp_size]
+            dims += ["cp"]
+            shape += [self.cp_size]

         cls.device_mesh = init_device_mesh(
             "cuda",

tensorrt_llm/mapping.py

Lines changed: 34 additions & 36 deletions
@@ -292,18 +292,16 @@ def has_cp(self):
         return self.cp_size > 1

     def prev_cp_rank(self):
-        p = self.rank - self.tp_size
-        if p // (self.tp_size * self.cp_size) < self.rank // (self.tp_size *
-                                                              self.cp_size):
-            return p + self.tp_size * self.cp_size
-        return p
+        # cp ranks are consecutive, so prev is rank - 1 with wraparound within cp group
+        if self.cp_rank == 0:
+            return self.rank + self.cp_size - 1
+        return self.rank - 1

     def next_cp_rank(self):
-        p = self.rank + self.tp_size
-        if p // (self.tp_size * self.cp_size) > self.rank // (self.tp_size *
-                                                              self.cp_size):
-            return p - self.tp_size * self.cp_size
-        return p
+        # cp ranks are consecutive, so next is rank + 1 with wraparound within cp group
+        if self.cp_rank == self.cp_size - 1:
+            return self.rank - self.cp_size + 1
+        return self.rank + 1

     def has_moe_cluster(self):
         return self.moe_cluster_size > 1
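
Because cp peers now occupy consecutive ranks, the neighbour lookup reduces to plus or minus one with wraparound inside a rank's cp block. A small standalone sketch of the same arithmetic (hypothetical free functions, not part of the Mapping API):

# Hypothetical helpers mirroring the new prev/next logic; with cp peers contiguous,
# a rank's cp block is [rank - cp_rank, rank - cp_rank + cp_size).
def prev_cp_rank(rank: int, cp_size: int) -> int:
    cp_rank = rank % cp_size
    return rank + cp_size - 1 if cp_rank == 0 else rank - 1

def next_cp_rank(rank: int, cp_size: int) -> int:
    cp_rank = rank % cp_size
    return rank - cp_size + 1 if cp_rank == cp_size - 1 else rank + 1

assert prev_cp_rank(8, 4) == 11 and next_cp_rank(11, 4) == 8  # wraps within [8, 12)
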
@@ -378,17 +376,17 @@ class Mapping(MappingBase):

     A node with 8 GPUs, tp_size = 4, cp_size = 2, pp_size = 1

-        2 tp groups:
+        4 cp groups:

-        - [0, 1, 2, 3]
-        - [4, 5, 6, 7]
+        - [0, 1]
+        - [2, 3]
+        - [4, 5]
+        - [6, 7]

-        4 cp groups:
+        2 tp groups:

-        - [0, 4]
-        - [1, 5]
-        - [2, 6]
-        - [3, 7]
+        - [0, 2, 4, 6]
+        - [1, 3, 5, 7]

     A node with 8 GPUs, moe_tp_size = 2, moe_ep_size = 4
@@ -437,23 +435,23 @@ class Mapping(MappingBase):

     2 nodes with 8 GPUs, tp_size 2, pp_size 2, cp_size 2

-        4 tp groups:
+        4 cp groups:
         - [0, 1]
         - [2, 3]
         - [4, 5]
         - [6, 7]

+        4 tp groups:
+        - [0, 2]
+        - [1, 3]
+        - [4, 6]
+        - [5, 7]
+
         4 pp groups:
         - [0, 4]
         - [1, 5]
         - [2, 6]
         - [3, 7]
-
-        4 cp groups:
-        - [0, 2]
-        - [1, 3]
-        - [4, 6]
-        - [5, 7]
     """

     def __new__(cls, *args, **kwargs):
@@ -551,23 +549,23 @@ def __init__(self, *args, **kwargs):

     @property
     def tp_rank(self) -> int:
-        return self.rank % self.tp_size
+        return self.rank % (self.tp_size * self.cp_size) // self.cp_size

     @property
     def pp_rank(self) -> int:
         return self.rank // (self.tp_size * self.cp_size)

     @property
     def cp_rank(self) -> int:
-        return self.rank % (self.tp_size * self.cp_size) // self.tp_size
+        return self.rank % self.cp_size

     @property
     def tp_group(self) -> List[int]:
         return self.tp_groups[self.pp_rank * self.cp_size + self.cp_rank]

     @property
     def pp_group(self) -> List[int]:
-        return self.pp_groups[self.cp_rank * self.tp_size + self.tp_rank]
+        return self.pp_groups[self.tp_rank * self.cp_size + self.cp_rank]

     @property
     def cp_group(self) -> List[int]:
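
With cp innermost, a rank's offset inside its pp stage is tp_rank * cp_size + cp_rank, and (as the docstring examples above show) there is one pp group per such offset, which is why the pp_group lookup index changes from cp_rank * tp_size + tp_rank to tp_rank * cp_size + cp_rank. A quick check in plain Python for rank 9 with tp_size=2, pp_size=2, cp_size=4 (the case exercised in the unit test below), not using the Mapping class:

# Check the new property arithmetic for rank 9 with tp_size=2, pp_size=2, cp_size=4.
tp_size, pp_size, cp_size = 2, 2, 4
rank = 9

pp_rank = rank // (tp_size * cp_size)               # 9 // 8 = 1
tp_rank = rank % (tp_size * cp_size) // cp_size     # 1 // 4 = 0
cp_rank = rank % cp_size                            # 9 % 4 = 1

# One pp group exists per offset inside a pp stage, so the lookup index is the
# rank's offset, i.e. tp_rank * cp_size + cp_rank.
assert tp_rank * cp_size + cp_rank == rank % (tp_size * cp_size)  # index 1 -> [1, 9]
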
@@ -598,20 +596,20 @@ def _init_parallel_groups(self):
             ranks = range(i, self.world_size, self.tp_size * self.cp_size)
             self.pp_groups.append(list(ranks))

-        # init cp group
+        # init cp group (consecutive ranks within each tp slice)
         for i in range(self.pp_size):
             for j in range(self.tp_size):
-                ranks = range(i * self.tp_size * self.cp_size + j,
-                              (i + 1) * self.tp_size * self.cp_size + j,
-                              self.tp_size)
+                ranks = range(
+                    i * self.tp_size * self.cp_size + j * self.cp_size,
+                    i * self.tp_size * self.cp_size + (j + 1) * self.cp_size)
                 self.cp_groups.append(list(ranks))

-        # init tp group
+        # init tp group (interleaved ranks with stride of cp_size)
         for i in range(self.pp_size):
             for j in range(self.cp_size):
-                ranks = range(
-                    i * self.tp_size * self.cp_size + j * self.tp_size,
-                    i * self.tp_size * self.cp_size + (j + 1) * self.tp_size)
+                ranks = range(i * self.tp_size * self.cp_size + j,
+                              (i + 1) * self.tp_size * self.cp_size + j,
+                              self.cp_size)
                 self.tp_groups.append(list(ranks))

         # init moe tp group
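
The updated docstring example (8 GPUs, tp_size=4, cp_size=2, pp_size=1) follows directly from these loops; a self-contained sketch that reproduces it without the Mapping class:

# Rebuild the cp/tp groups with the new loop structure for 8 GPUs,
# tp_size=4, cp_size=2, pp_size=1 (values from the docstring example).
tp_size, cp_size, pp_size = 4, 2, 1

cp_groups, tp_groups = [], []
for i in range(pp_size):
    base = i * tp_size * cp_size
    for j in range(tp_size):   # cp peers are consecutive within each tp slice
        cp_groups.append(list(range(base + j * cp_size, base + (j + 1) * cp_size)))
    for j in range(cp_size):   # tp peers are strided by cp_size
        tp_groups.append(list(range(base + j, base + tp_size * cp_size + j, cp_size)))

print(cp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(tp_groups)  # [[0, 2, 4, 6], [1, 3, 5, 7]]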

tensorrt_llm/models/modeling_utils.py

Lines changed: 3 additions & 2 deletions
@@ -740,10 +740,11 @@ def from_checkpoint(

         rank = config.mapping.rank
         if config.mapping.cp_size > 1:
-            # tp_cp_pp rank -> tp_pp rank: because different cp ranks share the same ckpt
+            # cp_tp_pp rank -> tp_pp rank: because different cp ranks share the same ckpt
             tp_size = config.mapping.tp_size
             cp_size = config.mapping.cp_size
-            rank = rank % tp_size + rank // (tp_size * cp_size) * tp_size
+            rank = (rank % (tp_size * cp_size)) // cp_size + rank // (
+                tp_size * cp_size) * tp_size
         weights_path = os.path.join(ckpt_dir, f'rank{rank}.safetensors')

         assert os.path.isfile(weights_path)
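
All ranks in a cp group read the same weight shard, so the remapping keeps tp_rank and pp_rank and drops cp_rank when forming the checkpoint index. A standalone sketch of that collapse (the helper name and sizes are illustrative only):

# Map a global (pp, tp, cp) rank to the tp_pp checkpoint index it should load.
# With cp innermost, cp peers differ only in cp_rank and land on the same file.
def ckpt_rank(rank: int, tp_size: int, cp_size: int) -> int:
    tp_rank = rank % (tp_size * cp_size) // cp_size
    pp_rank = rank // (tp_size * cp_size)
    return tp_rank + pp_rank * tp_size

# tp_size=2, cp_size=2: ranks 0 and 1 (same cp group) both read rank0.safetensors.
assert ckpt_rank(0, 2, 2) == ckpt_rank(1, 2, 2) == 0
assert ckpt_rank(2, 2, 2) == ckpt_rank(3, 2, 2) == 1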

tests/unittest/others/test_mapping.py

Lines changed: 9 additions & 9 deletions
@@ -57,27 +57,27 @@ def test_mapping(self):
         self.assertEqual(len(m.tp_groups), 4)
         self.assertEqual(len(m.pp_groups), 4)
         self.assertEqual(len(m.cp_groups), 4)
-        self.assertEqual(m.tp_group, [2, 3])
+        self.assertEqual(m.tp_group, [1, 3])
         self.assertEqual(m.pp_group, [3, 7])
-        self.assertEqual(m.cp_group, [1, 3])
+        self.assertEqual(m.cp_group, [2, 3])
         self.assertTrue(m.is_first_pp_rank())
         self.assertFalse(m.is_last_pp_rank())
         self.assertFalse(m.is_first_cp_rank())
         self.assertTrue(m.is_last_cp_rank())
         self.assertEqual(m.prev_pp_rank(), 7)
         self.assertEqual(m.next_pp_rank(), 7)
-        self.assertEqual(m.prev_cp_rank(), 1)
-        self.assertEqual(m.next_cp_rank(), 1)
+        self.assertEqual(m.prev_cp_rank(), 2)
+        self.assertEqual(m.next_cp_rank(), 2)

         m = Mapping(world_size=16, rank=9, tp_size=2, pp_size=2, cp_size=4)
-        self.assertEqual(m.tp_group, [8, 9])
+        self.assertEqual(m.tp_group, [9, 13])
         self.assertEqual(m.pp_group, [1, 9])
-        self.assertEqual(m.cp_group, [9, 11, 13, 15])
+        self.assertEqual(m.cp_group, [8, 9, 10, 11])
         self.assertFalse(m.is_first_pp_rank())
         self.assertTrue(m.is_last_pp_rank())
-        self.assertTrue(m.is_first_cp_rank())
+        self.assertFalse(m.is_first_cp_rank())
         self.assertFalse(m.is_last_cp_rank())
         self.assertEqual(m.prev_pp_rank(), 1)
         self.assertEqual(m.next_pp_rank(), 1)
-        self.assertEqual(m.prev_cp_rank(), 15)
-        self.assertEqual(m.next_cp_rank(), 11)
+        self.assertEqual(m.prev_cp_rank(), 8)
+        self.assertEqual(m.next_cp_rank(), 10)
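
The first case above (world_size=8, rank=3, tp_size=2, pp_size=2, cp_size=2) can also be checked by hand from the layout alone; a short plain-Python derivation of the expected groups, without the Mapping class:

# Derive the groups rank 3 belongs to under the new layout (tp=2, pp=2, cp=2).
tp_size, pp_size, cp_size = 2, 2, 2
rank, block = 3, tp_size * cp_size          # block = ranks per pp stage = 4

pp_rank, off = divmod(rank, block)          # pp_rank=0, offset 3 within the stage
tp_rank, cp_rank = divmod(off, cp_size)     # tp_rank=1, cp_rank=1

cp_group = [pp_rank * block + tp_rank * cp_size + c for c in range(cp_size)]
tp_group = [pp_rank * block + t * cp_size + cp_rank for t in range(tp_size)]
pp_group = [p * block + off for p in range(pp_size)]
print(cp_group, tp_group, pp_group)         # [2, 3] [1, 3] [3, 7]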
