We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 1d462bd commit 23e092fCopy full SHA for 23e092f
megatron/core/models/gpt/gpt_model.py
@@ -562,7 +562,8 @@ def _postprocess(
562
if not self.post_process:
563
return hidden_states
564
565
- if self.config.mtp_num_layers is not None:
+ # Skip when mtp_num_layers is None or 0
566
+ if self.config.mtp_num_layers:
567
mtp_labels = labels.clone()
568
hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
569
hidden_states = hidden_states_list[0]
0 commit comments