Commit 059776d

Fix DogeCDMoE router_gate logits reshape to (2, batch*seq_len, num_keys) before top-k selection
1 parent 84b985e commit 059776d

1 file changed: +1 -1 lines changed


examples/modeling/modeling_doge.py

Lines changed: 1 addition & 1 deletion
@@ -285,7 +285,7 @@ def forward(
         bsz, seq_len, _ = hidden_states.shape
 
         # get routing logits with router gate
-        router_logits = self.router_gate(hidden_states).view(2, bsz * seq_len, -1)
+        router_logits = self.router_gate(hidden_states).view(bsz * seq_len, 2, -1).transpose(0, 1)
 
         # get experts with the highest routing logits
         (scores_x, scores_y), (indices_x, indices_y) = router_logits.topk(self.num_keys, dim=-1)
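
The sketch below shows why the two `view` calls differ even though both produce a tensor of shape `(2, bsz * seq_len, -1)`. It is a minimal illustration, not code from the repository. The tensor sizes are hypothetical, and `gate_out` is a hand-built stand-in for `self.router_gate(hidden_states)`. It assumes that each token's gate output holds its x logits followed by its y logits, which is the layout the new line implies.

```python
import torch

# Hypothetical sizes chosen for illustration; not taken from the commit.
bsz, seq_len, keys_per_axis = 2, 3, 4
n_tokens = bsz * seq_len

# Stand-in for self.router_gate(hidden_states). ASSUMPTION: for every token,
# the last dimension is [x logits | y logits].
x_logits = torch.arange(n_tokens * keys_per_axis, dtype=torch.float32)
x_logits = x_logits.view(bsz, seq_len, keys_per_axis)
y_logits = -x_logits - 1.0  # distinct values so a mix-up is easy to detect
gate_out = torch.cat([x_logits, y_logits], dim=-1)  # (bsz, seq_len, 2 * keys_per_axis)

# Old line: cuts the flat buffer in half, so slice 0 holds the first half of
# the tokens (x and y interleaved) rather than the x logits of every token.
old = gate_out.view(2, n_tokens, -1)

# New line: split each token's output into two halves, then move that axis first.
new = gate_out.view(n_tokens, 2, -1).transpose(0, 1)

expected_x = x_logits.reshape(n_tokens, -1)
expected_y = y_logits.reshape(n_tokens, -1)

old_matches = torch.equal(old[0], expected_x) and torch.equal(old[1], expected_y)
new_matches = torch.equal(new[0], expected_x) and torch.equal(new[1], expected_y)
print("old view keeps per-token logits together:", old_matches)
print("new view keeps per-token logits together:", new_matches)

# The unpacking used after the changed line works on the fixed tensor:
# topk returns (values, indices), each of shape (2, n_tokens, k).
(scores_x, scores_y), (indices_x, indices_y) = new.topk(2, dim=-1)
```

Both results have the same shape, so the old line raises no error; only the assignment of logits to tokens changes, which is why the bug is easy to miss.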
