Skip to content

Commit 6d58c7c

Browse files
authored
fix aadiff (#10874)
1 parent b8faf66 commit 6d58c7c

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

paddlenlp/transformers/moe_gate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -302,9 +302,9 @@ def _topk_noaux_tc(
302302

303303
assert self.e_score_correction_bias is not None, "e_score_correction_bias is None"
304304
scores_for_choice = scores.reshape([bsz_seq_len, -1]) + self.e_score_correction_bias.unsqueeze(0)
305-
group_scores = (
306-
scores_for_choice.reshape([bsz_seq_len, self.n_group, -1]).topk(2, axis=-1)[0].sum(axis=-1)
307-
) # fmt:skip [n, n_group]
305+
reshape_tmp_rst = scores_for_choice.reshape([bsz_seq_len, self.n_group, -1])
306+
top_k = min(reshape_tmp_rst.shape[2], 2)
307+
group_scores = reshape_tmp_rst.topk(top_k, axis=-1)[0].sum(axis=-1) # fmt:skip [n, n_group]
308308
group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=True)[1] # [n, top_k_group]
309309
group_mask = paddle.zeros_like(group_scores).put_along_axis(group_idx, paddle.ones([], dtype="float32"), axis=-1) # fmt:skip
310310
score_mask = (

0 commit comments

Comments
 (0)