@@ -752,6 +752,18 @@ def forward(self, x):
752
752
753
753
class FusedNormGateFunc (paddle .autograd .PyLayer ):
754
754
"""recompute of postnorm and gate"""
755
+ _current_norm_output = None
756
+ _current_invar = None
757
+
758
+ @classmethod
759
+ def set_temporary_vars (cls , norm_output , invar ):
760
+ FusedNormGateFunc ._current_norm_output = norm_output
761
+ FusedNormGateFunc ._current_invar = invar
762
+
763
+ @classmethod
764
+ def clear_temporary_vars (cls ):
765
+ FusedNormGateFunc ._current_norm_output = None
766
+ FusedNormGateFunc ._current_invar = None
755
767
756
768
@staticmethod
757
769
def forward (ctx , x , rms_norm_weight , moe_gate_weight , eps ):
@@ -767,7 +779,10 @@ def forward(ctx, x, rms_norm_weight, moe_gate_weight, eps):
767
779
def backward (ctx , d_gate_logits , d_norm_output ):
768
780
x , rms_norm_weight , moe_gate_weight , eps = ctx .saved_tensor ()
769
781
# recompute rmsnorm
770
- norm_output , invar = fused_ln .fused_rms_norm (x , rms_norm_weight , eps )
782
+ norm_output = FusedNormGateFunc ._current_norm_output
783
+ invar = FusedNormGateFunc ._current_invar
784
+ if norm_output is None or invar is None :
785
+ norm_output , invar = fused_ln .fused_rms_norm (x , rms_norm_weight , eps )
771
786
d_norm_output_linear , d_moe_gate_weight = paddle ._C_ops .matmul_grad (
772
787
cast_if_needed (norm_output , ctx .dtype ),
773
788
cast_if_needed (moe_gate_weight , ctx .dtype ),
@@ -784,6 +799,16 @@ def backward(ctx, d_gate_logits, d_norm_output):
784
799
785
800
return dx , d_rms_norm_weight , d_moe_gate_weight
786
801
802
+ class TemporaryVarContext :
803
+ def __init__ (self , norm_output , invar ):
804
+ self .norm_output = norm_output
805
+ self .invar = invar
806
+
807
+ def __enter__ (self ):
808
+ FusedNormGateFunc .set_temporary_vars (self .norm_output , self .invar )
809
+
810
+ def __exit__ (self , exc_type , exc_val , exc_tb ):
811
+ FusedNormGateFunc .clear_temporary_vars ()
787
812
788
813
def balance_expert_assignment (n , m , k ):
789
814
assert k * n % m == 0
0 commit comments