@@ -752,18 +752,6 @@ def forward(self, x):
752
752
753
753
class FusedNormGateFunc (paddle .autograd .PyLayer ):
754
754
"""recompute of postnorm and gate"""
755
- _current_norm_output = None
756
- _current_invar = None
757
-
758
- @classmethod
759
- def set_temporary_vars (cls , norm_output , invar ):
760
- FusedNormGateFunc ._current_norm_output = norm_output
761
- FusedNormGateFunc ._current_invar = invar
762
-
763
- @classmethod
764
- def clear_temporary_vars (cls ):
765
- FusedNormGateFunc ._current_norm_output = None
766
- FusedNormGateFunc ._current_invar = None
767
755
768
756
@staticmethod
769
757
def forward (ctx , x , rms_norm_weight , moe_gate_weight , eps ):
@@ -779,10 +767,7 @@ def forward(ctx, x, rms_norm_weight, moe_gate_weight, eps):
779
767
def backward (ctx , d_gate_logits , d_norm_output ):
780
768
x , rms_norm_weight , moe_gate_weight , eps = ctx .saved_tensor ()
781
769
# recompute rmsnorm
782
- norm_output = FusedNormGateFunc ._current_norm_output
783
- invar = FusedNormGateFunc ._current_invar
784
- if norm_output is None or invar is None :
785
- norm_output , invar = fused_ln .fused_rms_norm (x , rms_norm_weight , eps )
770
+ norm_output , invar = fused_ln .fused_rms_norm (x , rms_norm_weight , eps )
786
771
d_norm_output_linear , d_moe_gate_weight = paddle ._C_ops .matmul_grad (
787
772
cast_if_needed (norm_output , ctx .dtype ),
788
773
cast_if_needed (moe_gate_weight , ctx .dtype ),
@@ -799,16 +784,6 @@ def backward(ctx, d_gate_logits, d_norm_output):
799
784
800
785
return dx , d_rms_norm_weight , d_moe_gate_weight
801
786
802
- class TemporaryVarContext :
803
- def __init__ (self , norm_output , invar ):
804
- self .norm_output = norm_output
805
- self .invar = invar
806
-
807
- def __enter__ (self ):
808
- FusedNormGateFunc .set_temporary_vars (self .norm_output , self .invar )
809
-
810
- def __exit__ (self , exc_type , exc_val , exc_tb ):
811
- FusedNormGateFunc .clear_temporary_vars ()
812
787
813
788
def balance_expert_assignment (n , m , k ):
814
789
assert k * n % m == 0
0 commit comments