commit 4e50ccee (1 parent: a8408b4)
colossalai/shardformer/modeling/mistral.py
@@ -683,12 +683,7 @@ def forward(
         )
 
         hidden_states = outputs[0]
-        if self.config.pretraining_tp > 1:
-            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
-            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
-            logits = torch.cat(logits, dim=-1)
-        else:
-            logits = self.lm_head(hidden_states)
+        logits = self.lm_head(hidden_states)
         logits = logits.float()
 
         loss = None
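For context: the deleted branch is the pretraining_tp compatibility path inherited from upstream transformers, which emulates tensor parallelism by splitting the lm_head weight along the vocabulary dimension, projecting the hidden states against each slice, and concatenating the results. The two paths are numerically equivalent (up to floating-point rounding), so the replacement is presumably a pure simplification. A minimal self-contained sketch of that equivalence, with illustrative placeholder sizes rather than values from the actual Mistral config:

import torch
import torch.nn.functional as F

# Illustrative sizes; in the real model these come from the config.
hidden_size, vocab_size, pretraining_tp = 16, 32, 4

lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
hidden_states = torch.randn(2, 5, hidden_size)  # (batch, seq_len, hidden)

# Removed path: split the weight into `pretraining_tp` slices along the
# vocab dimension, project against each slice, then concatenate.
lm_head_slices = lm_head.weight.split(vocab_size // pretraining_tp, dim=0)
logits_tp = torch.cat([F.linear(hidden_states, s) for s in lm_head_slices], dim=-1)

# Replacement path: a single full-width projection.
logits = lm_head(hidden_states)

# Both paths reconstruct the same linear map over the full vocabulary.
assert torch.allclose(logits_tp, logits, atol=1e-6)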