
Commit e9d168e: add bwd
Parent: 72dc153

3 files changed: +256, -102 lines


paddlenlp/transformers/deepseek_v2/modeling.py

Lines changed: 28 additions & 1 deletion
@@ -747,6 +747,18 @@ def forward(self, x):
 
 class FusedNormGateFunc(paddle.autograd.PyLayer):
     """recompute of postnorm and gate"""
+    _current_norm_output = None
+    _current_invar = None
+
+    @classmethod
+    def set_temporary_vars(cls, norm_output, invar):
+        cls._current_norm_output = norm_output
+        cls._current_invar = invar
+
+    @classmethod
+    def clear_temporary_vars(cls):
+        cls._current_norm_output = None
+        cls._current_invar = None
 
     @staticmethod
     def forward(ctx, x, rms_norm_weight, moe_gate_weight, eps):
@@ -762,7 +774,12 @@ def forward(ctx, x, rms_norm_weight, moe_gate_weight, eps):
     def backward(ctx, d_gate_logits, d_norm_output):
         x, rms_norm_weight, moe_gate_weight, eps = ctx.saved_tensor()
         # recompute rmsnorm
-        norm_output, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps)
+        norm_output = FusedNormGateFunc._current_norm_output
+        invar = FusedNormGateFunc._current_invar
+        if norm_output is None or invar is None:
+            raise RuntimeError("norm_output and invar must be set before backward!")
+
+        # norm_output, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps)
         d_norm_output_linear, d_moe_gate_weight = paddle._C_ops.matmul_grad(
             cast_if_needed(norm_output, ctx.dtype),
             cast_if_needed(moe_gate_weight, ctx.dtype),
@@ -779,6 +796,16 @@ def backward(ctx, d_gate_logits, d_norm_output):
 
         return dx, d_rms_norm_weight, d_moe_gate_weight
 
+class TemporaryVarContext:
+    def __init__(self, norm_output, invar):
+        self.norm_output = norm_output
+        self.invar = invar
+
+    def __enter__(self):
+        FusedNormGateFunc.set_temporary_vars(self.norm_output, self.invar)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        FusedNormGateFunc.clear_temporary_vars()
 
 def balance_expert_assignment(n, m, k):
     assert k * n % m == 0
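
For context, a minimal sketch of how the new pieces presumably fit together: the caller runs the fused RMSNorm once, stashes its outputs on FusedNormGateFunc via TemporaryVarContext for the duration of the backward pass, and the with-block guarantees the class-level stash is cleared afterwards, even if backward raises. The surrounding names (x, rms_norm_weight, eps, loss, fused_ln) are assumptions for illustration, not part of this diff.

    # Hypothetical usage (assumed, not shown in this commit): reuse the
    # already computed RMSNorm activations in backward() instead of
    # recomputing them there.
    norm_output, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps)
    with TemporaryVarContext(norm_output, invar):
        loss.backward()  # FusedNormGateFunc.backward reads the stashed tensors
    # On exit, clear_temporary_vars() resets the class attributes so stale
    # tensors are not kept alive or picked up by an unrelated backward pass.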
