@@ -747,6 +747,18 @@ def forward(self, x):
 
 class FusedNormGateFunc(paddle.autograd.PyLayer):
     """recompute of postnorm and gate"""
+    _current_norm_output = None
+    _current_invar = None
+
+    @classmethod
+    def set_temporary_vars(cls, norm_output, invar):
+        cls._current_norm_output = norm_output
+        cls._current_invar = invar
+
+    @classmethod
+    def clear_temporary_vars(cls):
+        cls._current_norm_output = None
+        cls._current_invar = None
 
     @staticmethod
     def forward(ctx, x, rms_norm_weight, moe_gate_weight, eps):
@@ -762,7 +774,12 @@ def forward(ctx, x, rms_norm_weight, moe_gate_weight, eps):
     def backward(ctx, d_gate_logits, d_norm_output):
         x, rms_norm_weight, moe_gate_weight, eps = ctx.saved_tensor()
         # recompute rmsnorm
-        norm_output, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps)
+        norm_output = FusedNormGateFunc._current_norm_output
+        invar = FusedNormGateFunc._current_invar
+        if norm_output is None or invar is None:
+            raise RuntimeError("norm_output and invar must be set before backward!")
+
+        # norm_output, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps)
         d_norm_output_linear, d_moe_gate_weight = paddle._C_ops.matmul_grad(
             cast_if_needed(norm_output, ctx.dtype),
             cast_if_needed(moe_gate_weight, ctx.dtype),
@@ -779,6 +796,16 @@ def backward(ctx, d_gate_logits, d_norm_output):
 
         return dx, d_rms_norm_weight, d_moe_gate_weight
 
+class TemporaryVarContext:
+    def __init__(self, norm_output, invar):
+        self.norm_output = norm_output
+        self.invar = invar
+
+    def __enter__(self):
+        FusedNormGateFunc.set_temporary_vars(self.norm_output, self.invar)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        FusedNormGateFunc.clear_temporary_vars()
 
 def balance_expert_assignment(n, m, k):
     assert k * n % m == 0
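
A minimal usage sketch (an assumption for illustration, not part of the diff): the caller computes the RMSNorm output once, keeps TemporaryVarContext open around the backward pass so FusedNormGateFunc.backward reads the stashed tensors instead of re-running fused_ln.fused_rms_norm, and the context clears them afterwards. Variable names such as x, rms_norm_weight, moe_gate_weight, eps, and compute_loss are placeholders.

# Forward: PyLayer returns gate logits and the (recomputable) norm output.
gate_logits, norm_output = FusedNormGateFunc.apply(x, rms_norm_weight, moe_gate_weight, eps)
loss = compute_loss(gate_logits, norm_output)  # placeholder loss

# Before backward, produce the norm output once and stash it for reuse.
norm_output_rc, invar_rc = fused_ln.fused_rms_norm(x, rms_norm_weight, eps)
with TemporaryVarContext(norm_output_rc, invar_rc):
    loss.backward()  # backward consumes the cached norm_output/invar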