Skip to content

Commit fe70dd1

Browse files
committed
address comments from Yuxian
Signed-off-by: Balaram Buddharaju <[email protected]>
1 parent 872167b commit fe70dd1

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

tensorrt_llm/_torch/device_mesh.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,13 +120,17 @@ def build_mesh(self):
120120

121121
# Dimensions go from slowest-varying (outermost) to fastest-varying (innermost).
122122
# Layout: pp is outermost, then tp, then cp is innermost (consecutive).
123-
dims = ["pp", "tp"]
124-
shape = [self.pp_size, self.tp_size]
123+
dims = ["pp"]
124+
shape = [self.pp_size]
125125

126126
if self.moe_ep_size > 1:
127127
dims += ["moe_tp", "moe_ep"]
128128
shape += [self.moe_tp_size, self.moe_ep_size]
129129
else:
130+
dims += ["tp"]
131+
shape += [self.tp_size]
132+
133+
if self.cp_size > 1:
130134
dims += ["cp"]
131135
shape += [self.cp_size]
132136

tensorrt_llm/models/modeling_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -741,10 +741,10 @@ def from_checkpoint(
741741
rank = config.mapping.rank
742742
if config.mapping.cp_size > 1:
743743
# cp_tp_pp rank -> tp_pp rank: because different cp ranks share the same ckpt.
744-
tp_size = config.mapping.tp_size
745744
cp_size = config.mapping.cp_size
746-
rank = (rank % (tp_size * cp_size)) // cp_size + rank // (
747-
tp_size * cp_size) * tp_size
745+
# rank = pp_rank × tp_size × cp_size + tp_rank × cp_size + cp_rank.
746+
# rank // cp_size is equivalent to pp_rank × tp_size + tp_rank.
747+
rank = rank // cp_size
748748
weights_path = os.path.join(ckpt_dir, f'rank{rank}.safetensors')
749749

750750
assert os.path.isfile(weights_path)

0 commit comments

Comments
 (0)