@@ -565,20 +565,31 @@ def topkgating_nodrop(self, gates: paddle.Tensor):
         top_gate, top_idx = self._topk_noaux_tc(
             gates, k=self.top_k, n_group=self.n_group, topk_group=self.topk_group
         )
+
         # norm gate to sum 1
-        if self.top_k > 1 and self.norm_topk_prob:
-            denominator = top_gate.sum(axis=-1, keepdim=True) + 1e-20
-            top_gate = top_gate / denominator
-        top_gate = top_gate * self.routed_scaling_factor
+        # if self.top_k > 1 and self.norm_topk_prob:
+        #     denominator = top_gate.sum(axis=-1, keepdim=True) + 1e-20
+        #     top_gate = top_gate / denominator
+        # top_gate = top_gate * self.routed_scaling_factor
 
         # get topk mask
         mask = paddle.zeros_like(gates).put_along_axis(top_idx, paddle.ones([], dtype="float32"), axis=1)
 
+        gates_masked = gates * mask
+        # if self.training:
+        gates_s = paddle.sum(gates_masked, axis=-1, keepdim=True)
+        denom_s = paddle.clip(gates_s, min=paddle.finfo(gates_masked.dtype).eps)
+
+        if self.norm_topk_prob:
+            gates_masked = gates_masked / denom_s
+
+        gates_masked *= self.routed_scaling_factor
+
         if hasattr(self.config, "seq_aux") and self.config.seq_aux:
             l_aux = self._cal_seq_aux_loss(gates_ori, self.top_k, top_idx)
         else:
             l_aux = self._cal_aux_loss(gates, mask)
 
         exp_counts = paddle.sum(mask.cast(paddle.int64), axis=0)
-        topk_masked_gates = paddle.zeros_like(gates).put_along_axis(top_idx, top_gate, axis=1)
-        return topk_masked_gates, mask, exp_counts, l_aux, l_zloss
+        # topk_masked_gates = paddle.zeros_like(gates).put_along_axis(top_idx, top_gate, axis=1)
+        return gates_masked, mask, exp_counts, l_aux, l_zloss
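For reference, the patch swaps the old flow (normalize the k selected weights, then scatter them back into a dense tensor with put_along_axis) for a mask-then-normalize one: keep the dense gates, zero out everything outside the top-k, divide by the clipped row sum (paddle.finfo(...).eps replaces the old +1e-20 guard), then apply routed_scaling_factor. Below is a minimal, runnable sketch of the new path; paddle.topk stands in for the private _topk_noaux_tc helper, and top_k=2 plus the gate logits are made-up illustration values:

import paddle

# Toy routing probabilities for one token over four experts (made-up values).
gates = paddle.nn.functional.softmax(
    paddle.to_tensor([[0.1, 2.0, 0.3, 1.5]]), axis=-1
)
_, top_idx = paddle.topk(gates, k=2, axis=-1)  # stand-in for _topk_noaux_tc

# One-hot mask over the selected experts, built the same way as the patch.
mask = paddle.zeros_like(gates).put_along_axis(
    top_idx, paddle.ones([], dtype="float32"), axis=1
)

# New path: mask first, then renormalize so the kept weights sum to 1.
gates_masked = gates * mask
denom_s = paddle.clip(
    paddle.sum(gates_masked, axis=-1, keepdim=True),
    min=paddle.finfo(gates_masked.dtype).eps,
)
gates_masked = gates_masked / denom_s
print(gates_masked)  # zeros outside the top-2; selected entries sum to 1

Because gates_masked is already dense, the final put_along_axis scatter of the old return path is no longer needed.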