
Commit 8cb1712

save args for rms_norm output (#10930)
* add bwd
* fix bug
* Update modeling.py
* Update modeling_pp.py
* Update modeling_pp.py
* Update fp8_utils.py

---------

Co-authored-by: zhangbo9674 <[email protected]>
1 parent 114bee1 commit 8cb1712

3 files changed: +258 −118 lines changed

paddlenlp/transformers/deepseek_v2/modeling.py

Lines changed: 26 additions & 1 deletion
@@ -752,6 +752,18 @@ def forward(self, x):
 
 class FusedNormGateFunc(paddle.autograd.PyLayer):
     """recompute of postnorm and gate"""
+    _current_norm_output = None
+    _current_invar = None
+
+    @classmethod
+    def set_temporary_vars(cls, norm_output, invar):
+        FusedNormGateFunc._current_norm_output = norm_output
+        FusedNormGateFunc._current_invar = invar
+
+    @classmethod
+    def clear_temporary_vars(cls):
+        FusedNormGateFunc._current_norm_output = None
+        FusedNormGateFunc._current_invar = None
 
     @staticmethod
     def forward(ctx, x, rms_norm_weight, moe_gate_weight, eps):
@@ -767,7 +779,10 @@ def forward(ctx, x, rms_norm_weight, moe_gate_weight, eps):
     def backward(ctx, d_gate_logits, d_norm_output):
         x, rms_norm_weight, moe_gate_weight, eps = ctx.saved_tensor()
         # recompute rmsnorm
-        norm_output, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps)
+        norm_output = FusedNormGateFunc._current_norm_output
+        invar = FusedNormGateFunc._current_invar
+        if norm_output is None or invar is None:
+            norm_output, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps)
         d_norm_output_linear, d_moe_gate_weight = paddle._C_ops.matmul_grad(
             cast_if_needed(norm_output, ctx.dtype),
             cast_if_needed(moe_gate_weight, ctx.dtype),
@@ -784,6 +799,16 @@ def backward(ctx, d_gate_logits, d_norm_output):
 
         return dx, d_rms_norm_weight, d_moe_gate_weight
 
+class TemporaryVarContext:
+    def __init__(self, norm_output, invar):
+        self.norm_output = norm_output
+        self.invar = invar
+
+    def __enter__(self):
+        FusedNormGateFunc.set_temporary_vars(self.norm_output, self.invar)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        FusedNormGateFunc.clear_temporary_vars()
 
 def balance_expert_assignment(n, m, k):
     assert k * n % m == 0
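
The two other changed files (modeling_pp.py and fp8_utils.py) are not shown on this page; presumably that is where a caller which already holds the fused_rms_norm results wraps its backward pass in TemporaryVarContext, so that FusedNormGateFunc.backward reuses them instead of re-running the kernel. The following is a self-contained, Paddle-free sketch of that stash-and-clear mechanic under that assumption: the two classes mirror the diff, while backward_step and the string stand-ins for the tensors are hypothetical.

class FusedNormGateFunc:
    """Stand-in for the real PyLayer; only the stash hooks are modeled."""

    _current_norm_output = None
    _current_invar = None

    @classmethod
    def set_temporary_vars(cls, norm_output, invar):
        FusedNormGateFunc._current_norm_output = norm_output
        FusedNormGateFunc._current_invar = invar

    @classmethod
    def clear_temporary_vars(cls):
        FusedNormGateFunc._current_norm_output = None
        FusedNormGateFunc._current_invar = None


class TemporaryVarContext:
    """Same shape as the context manager added in the diff."""

    def __init__(self, norm_output, invar):
        self.norm_output = norm_output
        self.invar = invar

    def __enter__(self):
        FusedNormGateFunc.set_temporary_vars(self.norm_output, self.invar)

    def __exit__(self, exc_type, exc_val, exc_tb):
        FusedNormGateFunc.clear_temporary_vars()


def backward_step():
    # Mirrors the new backward logic: reuse the stashed tensors when both
    # are present, otherwise fall back to recomputing fused_rms_norm.
    norm_output = FusedNormGateFunc._current_norm_output
    invar = FusedNormGateFunc._current_invar
    if norm_output is None or invar is None:
        return "recomputed"
    return "reused"


assert backward_step() == "recomputed"            # no stash -> recompute path
with TemporaryVarContext("norm_out", "invar"):    # stand-in tensors
    assert backward_step() == "reused"            # stash present -> reuse
assert backward_step() == "recomputed"            # cleared on __exit__

Note that the stash is a pair of class attributes, i.e. one global slot per process: every FusedNormGateFunc instance inside the context sees the same pair, and the None check in backward keeps the fallback correct whenever the stash is absent or only half set.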
