Skip to content

Commit 17cc8a4

Browse files
authored
fix gate prob (#10972)
* fix gate prob * remove useless code
1 parent f83311a commit 17cc8a4

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

paddlenlp/transformers/moe_gate.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -565,20 +565,31 @@ def topkgating_nodrop(self, gates: paddle.Tensor):
565565
top_gate, top_idx = self._topk_noaux_tc(
566566
gates, k=self.top_k, n_group=self.n_group, topk_group=self.topk_group
567567
)
568+
568569
# norm gate to sum 1
569-
if self.top_k > 1 and self.norm_topk_prob:
570-
denominator = top_gate.sum(axis=-1, keepdim=True) + 1e-20
571-
top_gate = top_gate / denominator
572-
top_gate = top_gate * self.routed_scaling_factor
570+
# if self.top_k > 1 and self.norm_topk_prob:
571+
# denominator = top_gate.sum(axis=-1, keepdim=True) + 1e-20
572+
# top_gate = top_gate / denominator
573+
# top_gate = top_gate * self.routed_scaling_factor
573574

574575
# get topk mask
575576
mask = paddle.zeros_like(gates).put_along_axis(top_idx, paddle.ones([], dtype="float32"), axis=1)
576577

578+
gates_masked = gates * mask
579+
# if self.training:
580+
gates_s = paddle.sum(gates_masked, axis=-1, keepdim=True)
581+
denom_s = paddle.clip(gates_s, min=paddle.finfo(gates_masked.dtype).eps)
582+
583+
if self.norm_topk_prob:
584+
gates_masked = gates_masked / denom_s
585+
586+
gates_masked *= self.routed_scaling_factor
587+
577588
if hasattr(self.config, "seq_aux") and self.config.seq_aux:
578589
l_aux = self._cal_seq_aux_loss(gates_ori, self.top_k, top_idx)
579590
else:
580591
l_aux = self._cal_aux_loss(gates, mask)
581592

582593
exp_counts = paddle.sum(mask.cast(paddle.int64), axis=0)
583-
topk_masked_gates = paddle.zeros_like(gates).put_along_axis(top_idx, top_gate, axis=1)
584-
return topk_masked_gates, mask, exp_counts, l_aux, l_zloss
594+
# topk_masked_gates = paddle.zeros_like(gates).put_along_axis(top_idx, top_gate, axis=1)
595+
return gates_masked, mask, exp_counts, l_aux, l_zloss

0 commit comments

Comments
 (0)