
Commit c9ec40a

Add post norm recompute. (#10824)
* Add post norm recompute.
* fix assert
* fix bug
1 parent ae4f27b commit c9ec40a

File tree

5 files changed: +366 -125 lines changed


paddlenlp/transformers/deepseek_v2/configuration.py

Lines changed: 2 additions & 0 deletions
@@ -181,6 +181,7 @@ def __init__(
         using_flex_token=False,
         use_dualpipev=False,
         send_mtp_embed=False,
+        using_post_norm_recompute=False,
         recompute_fwd_gate_up=False,
         is_split_group_gemm=False,
         **kwargs,
@@ -233,6 +234,7 @@ def __init__(
         self.using_flex_token = using_flex_token
         self.use_dualpipev = use_dualpipev
         self.send_mtp_embed = send_mtp_embed
+        self.using_post_norm_recompute = using_post_norm_recompute
         self.recompute_fwd_gate_up = recompute_fwd_gate_up
         self.is_split_group_gemm = is_split_group_gemm
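As a usage note, here is a minimal sketch of enabling the flag added above; the import path follows the file shown here, everything else is left at its defaults, and the flag itself defaults to False:

    from paddlenlp.transformers.deepseek_v2.configuration import DeepseekV2Config

    # Illustrative only: turn on post-norm recompute when building the model config.
    config = DeepseekV2Config(using_post_norm_recompute=True)
    print(config.using_post_norm_recompute)  # True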

paddlenlp/transformers/deepseek_v2/modeling.py

Lines changed: 139 additions & 29 deletions
@@ -74,7 +74,7 @@
 from ..model_utils import PretrainedModel, register_base_model
 from ..moe_gate import PretrainedMoEGate
 from ..moe_layer import MoELayer
-from ..utils import device_guard
+from ..utils import cast_if_needed, device_guard
 from . import fp8_linear as linear_utils
 from .configuration import DeepseekV2Config
 
@@ -737,6 +737,41 @@ def forward(self, x):
         return out
 
 
+class FusedNormGateFunc(paddle.autograd.PyLayer):
+    """recompute of postnorm and gate"""
+
+    @staticmethod
+    def forward(ctx, x, rms_norm_weight, moe_gate_weight, eps):
+        ctx.dtype = paddle.float32
+        norm_output, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps)
+        with paddle.amp.auto_cast(False):
+            gate_logits = F.linear(cast_if_needed(norm_output, ctx.dtype), cast_if_needed(moe_gate_weight, ctx.dtype))
+
+        ctx.save_for_backward(x, rms_norm_weight, moe_gate_weight, eps)
+        return gate_logits, norm_output
+
+    @staticmethod
+    def backward(ctx, d_gate_logits, d_norm_output):
+        x, rms_norm_weight, moe_gate_weight, eps = ctx.saved_tensor()
+        # recompute rmsnorm
+        norm_output, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps)
+        d_norm_output_linear, d_moe_gate_weight = paddle._C_ops.matmul_grad(
+            cast_if_needed(norm_output, ctx.dtype),
+            cast_if_needed(moe_gate_weight, ctx.dtype),
+            d_gate_logits,
+            False,
+            False,
+        )
+        d_norm_output_linear, d_moe_gate_weight = cast_if_needed(
+            d_norm_output_linear, norm_output.dtype
+        ), cast_if_needed(d_moe_gate_weight, moe_gate_weight.dtype)
+
+        d_norm_output = d_norm_output + d_norm_output_linear
+        dx, d_rms_norm_weight = fused_ln.fused_rms_norm_grad_func(x, rms_norm_weight, invar, d_norm_output, eps)
+
+        return dx, d_rms_norm_weight, d_moe_gate_weight
+
+
 class FakeGate(paddle.autograd.PyLayer):
     @staticmethod
     def forward(ctx, hidden_states, weight):
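The pattern above is recompute-in-backward: forward saves only the raw inputs (x, the norm and gate weights, eps), and backward reruns fused_rms_norm instead of keeping norm_output and invar alive between the two passes, trading a small amount of extra compute for activation memory. Below is a self-contained sketch of the same idea with a simpler op so it can run outside this file (the class name and the SiLU choice are illustrative, not part of this commit):

    import paddle
    import paddle.nn.functional as F

    class RecomputeSilu(paddle.autograd.PyLayer):
        """Save only the input; redo the cheap forward math during backward."""

        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)  # keep x, drop silu(x)
            return F.silu(x)

        @staticmethod
        def backward(ctx, dy):
            (x,) = ctx.saved_tensor()
            s = F.sigmoid(x)  # recomputed here instead of cached in forward
            return dy * (s + x * s * (1 - s))  # d/dx [x * sigmoid(x)]

    # RecomputeSilu.apply(x) matches F.silu(x) but holds less activation memory between passes.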
@@ -756,7 +791,16 @@ def backward(ctx, grad_output):
 
 
 class MoEGate(PretrainedMoEGate):
-    def __init__(self, config, num_experts, expert_hidden_size, **kwargs):
+    def __init__(
+        self,
+        config,
+        num_experts,
+        expert_hidden_size,
+        using_post_norm_recompute=False,
+        norm_weight=None,
+        norm_eps=None,
+        **kwargs
+    ):
         super().__init__(config, num_experts, expert_hidden_size, **kwargs)
         # [hidden_size, n_expert]
 
@@ -771,6 +815,8 @@ def __init__(self, config, num_experts, expert_hidden_size, **kwargs):
         )
 
         self.config = config
+        self.using_post_norm_recompute = using_post_norm_recompute
+
         if config.topk_method == "noaux_tc":
             self.e_score_correction_bias = paddle.create_parameter(
                 shape=[num_experts],
@@ -779,6 +825,10 @@ def __init__(self, config, num_experts, expert_hidden_size, **kwargs):
             )
             self.e_score_correction_bias.is_distributed = True
 
+        if self.using_post_norm_recompute:
+            assert norm_weight is not None and norm_eps is not None
+            self.norm_weight = norm_weight
+            self.norm_eps = norm_eps
         self.using_flex_token = False
 
     def forward(self, hidden_states):
@@ -789,23 +839,33 @@ def forward(self, hidden_states):
         _, _, h_dim = hidden_states.shape
 
         # compute gating score
-        with paddle.amp.auto_cast(False):
-            hidden_states = hidden_states.cast(self.weight.dtype)
-
-            if hasattr(self.config, "using_fake_gate") and self.config.using_fake_gate:
-                logits = FakeGate.apply(hidden_states, self.weight)
-            else:
-                logits = F.linear(hidden_states, self.weight, None)
+        if self.using_post_norm_recompute:
+            logits, norm_out = FusedNormGateFunc.apply(hidden_states, self.norm_weight, self.weight, self.norm_eps)
+        else:
+            with paddle.amp.auto_cast(False):
+                hidden_states = hidden_states.cast(self.weight.dtype)
+                if hasattr(self.config, "using_fake_gate") and self.config.using_fake_gate:
+                    logits = FakeGate.apply(hidden_states, self.weight)
+                else:
+                    logits = F.linear(hidden_states, self.weight, None)
 
-            scores = self.gate_score_func(logits=logits)
-            scores = scores.cast(paddle.float32)
+        scores = self.gate_score_func(logits=logits)
+        scores = scores.cast(paddle.float32)
 
+        # Compute all possible return values
         if self.using_flex_token:
-            scores, routing_map, exp_counts, l_aux, l_zloss = self.topkgating_nodrop(scores)
-            return scores, routing_map, l_aux, l_zloss
+            scores, routing_map, exp_counts, l_aux, l_zloss = self.topkgating_nodrop(
+                scores
+            )  # (scores, routing_map, exp_counts, l_aux, l_zloss)
+            ret = (scores, routing_map, l_aux, l_zloss)
+        else:
+            ret = self.topkgating(scores)  # (capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss)
 
-        capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss = self.topkgating(scores)
-        return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss
+        # Append norm_out if needed
+        if self.using_post_norm_recompute:
+            ret = (*ret, norm_out)
+
+        return ret
 
 
 class AddAuxiliaryLoss(paddle.autograd.PyLayer):
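Net effect on the gate's contract: with using_post_norm_recompute enabled, norm_out (the RMSNorm output that can later be recomputed) is appended to whatever tuple forward would otherwise return, so callers unpack one extra element. A hedged caller-side sketch for the flex-token path, mirroring how DeepseekV2MoE.forward consumes it below:

    # using_post_norm_recompute=True and using_flex_token=True (variable names illustrative)
    probs, routing_map, l_aux, l_zloss, norm_out = gate(hidden_states)
    # with the flag off, the same call returns only (probs, routing_map, l_aux, l_zloss)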
@@ -833,9 +893,13 @@ class DeepseekV2MoE(MoELayer):
     A mixed expert module containing shared experts.
     """
 
-    def __init__(self, config: DeepseekV2Config):
+    def __init__(self, config: DeepseekV2Config, norm_weight=None, norm_eps=None):
         assert config.tensor_parallel_degree <= 1, "tensor_parallel_degree should be 1"
 
+        self.using_post_norm_recompute = config.using_post_norm_recompute
+        if self.using_post_norm_recompute:
+            assert norm_weight is not None and norm_eps is not None
+
         gate = MoEGate(
             config=config,
             num_experts=config.n_routed_experts,
@@ -847,6 +911,9 @@ def __init__(self, config: DeepseekV2Config):
             norm_topk_prob=config.norm_topk_prob,
             routed_scaling_factor=config.routed_scaling_factor,
             drop_tokens=False,
+            using_post_norm_recompute=self.using_post_norm_recompute,
+            norm_weight=norm_weight,
+            norm_eps=norm_eps,
         )
         DeepseekV2MLPClass = FP8Mlp if DSV3_USE_FP8_GEMM else DeepseekV2MLP
 
@@ -862,6 +929,7 @@ def __init__(self, config: DeepseekV2Config):
             gate=gate,
             capacity=2.0,
             moe_group="expert",
+            using_post_norm_recompute=self.using_post_norm_recompute,
         )
 
         moe_grad_group = fleet.get_hybrid_communicate_group().expert_grad_comm_group
@@ -871,11 +939,44 @@ def __init__(self, config: DeepseekV2Config):
         self.alpha = config.aux_loss_alpha
         if config.n_shared_experts is not None:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
-            self.shared_experts = DeepseekV2MLPClass(config=config, intermediate_size=intermediate_size, is_moe=False)
+            if self.using_post_norm_recompute:
+                assert DeepseekV2MLPClass is FP8Mlp
+                self.shared_experts = DeepseekV2MLPClass(
+                    config=config,
+                    intermediate_size=intermediate_size,
+                    is_moe=False,
+                    using_post_norm_recompute=self.using_post_norm_recompute,
+                    norm_weight=norm_weight,
+                    norm_eps=norm_eps,
+                )
+            else:
+                self.shared_experts = DeepseekV2MLPClass(
+                    config=config, intermediate_size=intermediate_size, is_moe=False
+                )
 
     def forward(self, hidden_states):
-        final_hidden_states, l_aux, l_zloss = super().forward(hidden_states)
-        final_hidden_states = self.post_process(hidden_states, final_hidden_states, l_aux)
+        if self.using_post_norm_recompute:
+            super().update_flex_token()
+            if self.using_flex_token:
+                probs, routing_map, l_aux, l_zloss, norm_out = self.router(hidden_states)
+                final_hidden_states, l_aux, l_zloss = super().forward(
+                    norm_out, probs=probs, routing_map=routing_map, l_aux=l_aux, l_zloss=l_zloss
+                )
+            else:
+                capacity, topk_weight, topk_ids, token_priority, l_aux, l_zloss, norm_out = self.gate(hidden_states)
+                final_hidden_states, l_aux, l_zloss = super().forward(
+                    norm_out,
+                    capacity=capacity,
+                    topk_weight=topk_weight,
+                    topk_ids=topk_ids,
+                    token_priority=token_priority,
+                    l_aux=l_aux,
+                    l_zloss=l_zloss,
+                )
+            final_hidden_states = self.post_process(hidden_states, final_hidden_states, l_aux)
+        else:
+            final_hidden_states, l_aux, l_zloss = super().forward(hidden_states)
+            final_hidden_states = self.post_process(hidden_states, final_hidden_states, l_aux)
         return final_hidden_states
 
     def post_process(self, hidden_states, final_hidden_states, l_aux):
@@ -1774,25 +1875,32 @@ def __init__(self, config: DeepseekV2Config, layer_idx: int, layerwise_recompute
         self.enable_recompute = False
         self.layerwise_recompute = layerwise_recompute
         self.recompute_granularity = config.recompute_granularity
+        self.using_post_norm_recompute = config.using_post_norm_recompute
 
         self.hidden_size = config.hidden_size
 
         self.self_attn = DeepseekV2Attention(config=config, layerwise_recompute=layerwise_recompute)
 
         DeepseekV2MLPClass = FP8Mlp if DSV3_USE_FP8_GEMM else DeepseekV2MLP
 
-        self.mlp = (
-            DeepseekV2MoE(config)
-            if (
-                config.n_routed_experts is not None
-                and layer_idx >= config.first_k_dense_replace
-                and layer_idx % config.moe_layer_freq == 0
-            )
-            else DeepseekV2MLPClass(config)
-        )
         self.input_layernorm = DeepseekV2RMSNorm(config)
         self.post_attention_layernorm = DeepseekV2RMSNorm(config)
 
+        if (
+            config.n_routed_experts is not None
+            and layer_idx >= config.first_k_dense_replace
+            and layer_idx % config.moe_layer_freq == 0
+        ):
+            self.mlp = (
+                DeepseekV2MoE(
+                    config, self.post_attention_layernorm.weight, self.post_attention_layernorm.variance_epsilon
+                )
+                if config.using_post_norm_recompute
+                else DeepseekV2MoE(config)
+            )
+        else:
+            self.mlp = DeepseekV2MLPClass(config)
+
     def forward(
         self,
         hidden_states: paddle.Tensor,
@@ -1871,7 +1979,9 @@ def forward(
         # Fully Connected
         residual = hidden_states
 
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        if not (self.using_post_norm_recompute and isinstance(self.mlp, DeepseekV2MoE)):
+            hidden_states = self.post_attention_layernorm(hidden_states)
+
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
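Reading the last two hunks together: the post-attention RMSNorm module is still constructed, but for MoE layers its weight and epsilon are handed to DeepseekV2MoE, and the decoder layer skips its own norm call because FusedNormGateFunc applies (and later recomputes) that norm inside the gate. A condensed, hedged view of the resulting FFN sub-block, with comments added for orientation:

    # Sketch of the decoder layer's FFN path once the flag is on (names as in the diff above)
    residual = hidden_states
    if not (self.using_post_norm_recompute and isinstance(self.mlp, DeepseekV2MoE)):
        hidden_states = self.post_attention_layernorm(hidden_states)  # dense-MLP layers: norm as before
    hidden_states = self.mlp(hidden_states)  # MoE layers: the norm runs inside the gate recompute
    hidden_states = residual + hidden_states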

0 commit comments
