Commit 4e50cce

fix the mistral model

Parent: a8408b4

File tree

1 file changed: +1 −6 lines

colossalai/shardformer/modeling/mistral.py

Lines changed: 1 addition & 6 deletions
@@ -683,12 +683,7 @@ def forward(
         )

         hidden_states = outputs[0]
-        if self.config.pretraining_tp > 1:
-            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
-            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
-            logits = torch.cat(logits, dim=-1)
-        else:
-            logits = self.lm_head(hidden_states)
+        logits = self.lm_head(hidden_states)
         logits = logits.float()

         loss = None
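The removed branch computed the output projection by splitting `lm_head.weight` into `pretraining_tp` row slices, projecting with each slice, and concatenating the results along the vocab dimension; the commit collapses this to a single `lm_head` call. A minimal sketch (shapes and tensor names are assumed for illustration, not taken from the model) showing that the sliced path is numerically equivalent to the single projection:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, hidden, tp = 8, 4, 2                 # assumed toy sizes; tp plays the role of pretraining_tp
hidden_states = torch.randn(2, 3, hidden)        # (batch, seq, hidden)
weight = torch.randn(vocab_size, hidden)         # stand-in for lm_head.weight (no bias)

# Removed path: split the weight into row slices, project per slice, concat over vocab dim
slices = weight.split(vocab_size // tp, dim=0)
logits_sliced = torch.cat([F.linear(hidden_states, s) for s in slices], dim=-1)

# Kept path: one full-width projection
logits_full = F.linear(hidden_states, weight)

print(torch.allclose(logits_sliced, logits_full, atol=1e-6))  # → True
```

Since each vocab row of the weight contributes to exactly one output logit, slicing along `dim=0` and concatenating reproduces the full matmul, which is why the branch could be dropped without changing results.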

0 commit comments
