Skip to content

Commit a3f57dc

Browse files
committed
1 parent e91f55a commit a3f57dc

File tree

3 files changed

+118
-258
lines changed

3 files changed

+118
-258
lines changed

paddlenlp/transformers/deepseek_v2/modeling.py

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -752,18 +752,6 @@ def forward(self, x):
752752

753753
class FusedNormGateFunc(paddle.autograd.PyLayer):
754754
"""recompute of postnorm and gate"""
755-
_current_norm_output = None
756-
_current_invar = None
757-
758-
@classmethod
759-
def set_temporary_vars(cls, norm_output, invar):
760-
FusedNormGateFunc._current_norm_output = norm_output
761-
FusedNormGateFunc._current_invar = invar
762-
763-
@classmethod
764-
def clear_temporary_vars(cls):
765-
FusedNormGateFunc._current_norm_output = None
766-
FusedNormGateFunc._current_invar = None
767755

768756
@staticmethod
769757
def forward(ctx, x, rms_norm_weight, moe_gate_weight, eps):
@@ -779,10 +767,7 @@ def forward(ctx, x, rms_norm_weight, moe_gate_weight, eps):
779767
def backward(ctx, d_gate_logits, d_norm_output):
780768
x, rms_norm_weight, moe_gate_weight, eps = ctx.saved_tensor()
781769
# recompute rmsnorm
782-
norm_output = FusedNormGateFunc._current_norm_output
783-
invar = FusedNormGateFunc._current_invar
784-
if norm_output is None or invar is None:
785-
norm_output, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps)
770+
norm_output, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps)
786771
d_norm_output_linear, d_moe_gate_weight = paddle._C_ops.matmul_grad(
787772
cast_if_needed(norm_output, ctx.dtype),
788773
cast_if_needed(moe_gate_weight, ctx.dtype),
@@ -799,16 +784,6 @@ def backward(ctx, d_gate_logits, d_norm_output):
799784

800785
return dx, d_rms_norm_weight, d_moe_gate_weight
801786

802-
class TemporaryVarContext:
803-
def __init__(self, norm_output, invar):
804-
self.norm_output = norm_output
805-
self.invar = invar
806-
807-
def __enter__(self):
808-
FusedNormGateFunc.set_temporary_vars(self.norm_output, self.invar)
809-
810-
def __exit__(self, exc_type, exc_val, exc_tb):
811-
FusedNormGateFunc.clear_temporary_vars()
812787

813788
def balance_expert_assignment(n, m, k):
814789
assert k * n % m == 0

0 commit comments

Comments
 (0)