Commit 23e092f

Fix: don't enter branch if mtp_num_layers == 0 (#2581)

Authored-by: rj42yaox12
Co-authored-by: Xin Yao <xiny@nvidia.com>
Parent: 1d462bd

File tree: 1 file changed (+2, −1 lines)


megatron/core/models/gpt/gpt_model.py

Lines changed: 2 additions & 1 deletion
@@ -562,7 +562,8 @@ def _postprocess(
         if not self.post_process:
             return hidden_states
 
-        if self.config.mtp_num_layers is not None:
+        # Skip when mtp_num_layers is None or 0
+        if self.config.mtp_num_layers:
             mtp_labels = labels.clone()
             hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
             hidden_states = hidden_states_list[0]
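The one-line guard change can be illustrated in isolation. This is a minimal sketch using hypothetical stand-in functions (not Megatron-LM's actual classes) that shows why the truthiness test is needed: the old `is not None` check still enters the MTP branch when `mtp_num_layers == 0`, while the new check skips both `None` and `0`.

```python
def enters_mtp_branch_old(mtp_num_layers):
    # Pre-fix guard: only None is excluded, so a configured value of 0
    # still enters the MTP post-processing branch.
    return mtp_num_layers is not None


def enters_mtp_branch_new(mtp_num_layers):
    # Post-fix guard: Python truthiness treats both None and 0 as False,
    # so the branch is skipped in either case.
    return bool(mtp_num_layers)


print(enters_mtp_branch_old(0))   # True  (the bug: 0 entered the branch)
print(enters_mtp_branch_new(0))   # False (fixed: 0 is skipped)
print(enters_mtp_branch_new(None))  # False
print(enters_mtp_branch_new(2))   # True
```

With `mtp_num_layers == 0` the old guard would have reached `torch.chunk(hidden_states, 1, dim=0)` and the rest of the MTP code path for no layers at all, which the truthiness check avoids entirely.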
