File tree Expand file tree Collapse file tree 2 files changed +9
-5
lines changed
Expand file tree Collapse file tree 2 files changed +9
-5
lines changed Original file line number Diff line number Diff line change @@ -120,13 +120,17 @@ def build_mesh(self):
120120
121121 # Dimensions go from slowest-varying (outermost) to fastest-varying (innermost).
122122 # Layout: pp is outermost, then tp, then cp is innermost (consecutive).
123- dims = ["pp" , "tp" ]
124- shape = [self .pp_size , self . tp_size ]
123+ dims = ["pp" ]
124+ shape = [self .pp_size ]
125125
126126 if self .moe_ep_size > 1 :
127127 dims += ["moe_tp" , "moe_ep" ]
128128 shape += [self .moe_tp_size , self .moe_ep_size ]
129129 else :
130+ dims += ["tp" ]
131+ shape += [self .tp_size ]
132+
133+ if self .cp_size > 1 :
130134 dims += ["cp" ]
131135 shape += [self .cp_size ]
132136
Original file line number Diff line number Diff line change @@ -741,10 +741,10 @@ def from_checkpoint(
741741 rank = config .mapping .rank
742742 if config .mapping .cp_size > 1 :
743743 # cp_tp_pp rank -> tp_pp rank: because different cp ranks share the same ckpt.
744- tp_size = config .mapping .tp_size
745744 cp_size = config .mapping .cp_size
746- rank = (rank % (tp_size * cp_size )) // cp_size + rank // (
747- tp_size * cp_size ) * tp_size
745+ # rank = pp_rank × tp_size × cp_size + tp_rank × cp_size + cp_rank.
746+ # rank // cp_size is equivalent to pp_rank × tp_size + tp_rank.
747+ rank = rank // cp_size
748748 weights_path = os .path .join (ckpt_dir , f'rank{ rank } .safetensors' )
749749
750750 assert os .path .isfile (weights_path )
You can’t perform that action at this time.
0 commit comments