 from ..model_utils import PretrainedModel, register_base_model
 from ..moe_gate import PretrainedMoEGate
 from ..moe_layer import MoELayer
-from ..utils import device_guard
+from ..utils import cast_if_needed, device_guard
 from . import fp8_linear as linear_utils
 from .configuration import DeepseekV2Config

@@ -737,6 +737,41 @@ def forward(self, x):
         return out


+class FusedNormGateFunc(paddle.autograd.PyLayer):
+    """Recompute the post-attention RMSNorm and MoE gate in backward to save activation memory."""
+
+    @staticmethod
+    def forward(ctx, x, rms_norm_weight, moe_gate_weight, eps):
+        ctx.dtype = paddle.float32
+        norm_output, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps)
+        with paddle.amp.auto_cast(False):
+            gate_logits = F.linear(cast_if_needed(norm_output, ctx.dtype), cast_if_needed(moe_gate_weight, ctx.dtype))
+
+        ctx.save_for_backward(x, rms_norm_weight, moe_gate_weight, eps)
+        return gate_logits, norm_output
+
+    @staticmethod
+    def backward(ctx, d_gate_logits, d_norm_output):
+        x, rms_norm_weight, moe_gate_weight, eps = ctx.saved_tensor()
+        # Recompute the RMSNorm output instead of having saved it in forward.
+        norm_output, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps)
+        d_norm_output_linear, d_moe_gate_weight = paddle._C_ops.matmul_grad(
+            cast_if_needed(norm_output, ctx.dtype),
+            cast_if_needed(moe_gate_weight, ctx.dtype),
+            d_gate_logits,
+            False,
+            False,
+        )
+        d_norm_output_linear, d_moe_gate_weight = cast_if_needed(
+            d_norm_output_linear, norm_output.dtype
+        ), cast_if_needed(d_moe_gate_weight, moe_gate_weight.dtype)
+
+        # Gradient w.r.t. norm_output flows in from both the gate and the experts.
+        d_norm_output = d_norm_output + d_norm_output_linear
+        dx, d_rms_norm_weight = fused_ln.fused_rms_norm_grad_func(x, rms_norm_weight, invar, d_norm_output, eps)
+
+        return dx, d_rms_norm_weight, d_moe_gate_weight
+
+
 class FakeGate(paddle.autograd.PyLayer):
     @staticmethod
     def forward(ctx, hidden_states, weight):
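The new `FusedNormGateFunc` is a classic recompute trade: `forward` saves only the un-normalized input and the two weight tensors, and `backward` re-runs `fused_rms_norm` rather than keeping `norm_output` alive between passes, while the gate matmul runs in float32 under `auto_cast(False)` for numerically stable routing. A minimal usage sketch, assuming `FusedNormGateFunc` and `fused_ln` are in scope as in this module; shapes, dtypes, and the epsilon are illustrative assumptions:

```python
import paddle

x = paddle.randn([2, 128, 4096], dtype="float32")  # [batch, seq, hidden]
x.stop_gradient = False
norm_weight = paddle.ones([4096], dtype="float32")
gate_weight = paddle.randn([4096, 64], dtype="float32")  # 64 routed experts

# One call returns both the fp32 gate logits and the normalized activations,
# so the caller can route tokens and feed the experts without a second norm.
gate_logits, norm_out = FusedNormGateFunc.apply(x, norm_weight, gate_weight, 1e-6)
```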
@@ -756,7 +791,16 @@ def backward(ctx, grad_output):


 class MoEGate(PretrainedMoEGate):
-    def __init__(self, config, num_experts, expert_hidden_size, **kwargs):
+    def __init__(
+        self,
+        config,
+        num_experts,
+        expert_hidden_size,
+        using_post_norm_recompute=False,
+        norm_weight=None,
+        norm_eps=None,
+        **kwargs
+    ):
         super().__init__(config, num_experts, expert_hidden_size, **kwargs)
         # [hidden_size, n_expert]

@@ -771,6 +815,8 @@ def __init__(self, config, num_experts, expert_hidden_size, **kwargs):
         )

         self.config = config
+        self.using_post_norm_recompute = using_post_norm_recompute
+
         if config.topk_method == "noaux_tc":
             self.e_score_correction_bias = paddle.create_parameter(
                 shape=[num_experts],
@@ -779,6 +825,10 @@ def __init__(self, config, num_experts, expert_hidden_size, **kwargs):
             )
             self.e_score_correction_bias.is_distributed = True

+        if self.using_post_norm_recompute:
+            assert norm_weight is not None and norm_eps is not None
+            self.norm_weight = norm_weight
+            self.norm_eps = norm_eps
         self.using_flex_token = False

     def forward(self, hidden_states):
@@ -789,23 +839,33 @@ def forward(self, hidden_states):
         _, _, h_dim = hidden_states.shape

         # compute gating score
-        with paddle.amp.auto_cast(False):
-            hidden_states = hidden_states.cast(self.weight.dtype)
-
-            if hasattr(self.config, "using_fake_gate") and self.config.using_fake_gate:
-                logits = FakeGate.apply(hidden_states, self.weight)
-            else:
-                logits = F.linear(hidden_states, self.weight, None)
+        if self.using_post_norm_recompute:
+            logits, norm_out = FusedNormGateFunc.apply(hidden_states, self.norm_weight, self.weight, self.norm_eps)
+        else:
+            with paddle.amp.auto_cast(False):
+                hidden_states = hidden_states.cast(self.weight.dtype)
+                if hasattr(self.config, "using_fake_gate") and self.config.using_fake_gate:
+                    logits = FakeGate.apply(hidden_states, self.weight)
+                else:
+                    logits = F.linear(hidden_states, self.weight, None)

-        scores = self.gate_score_func(logits=logits)
-        scores = scores.cast(paddle.float32)
+        scores = self.gate_score_func(logits=logits)
+        scores = scores.cast(paddle.float32)

+        # Compute all possible return values
         if self.using_flex_token:
-            scores, routing_map, exp_counts, l_aux, l_zloss = self.topkgating_nodrop(scores)
-            return scores, routing_map, l_aux, l_zloss
+            scores, routing_map, exp_counts, l_aux, l_zloss = self.topkgating_nodrop(
+                scores
+            )  # (scores, routing_map, exp_counts, l_aux, l_zloss)
+            ret = (scores, routing_map, l_aux, l_zloss)
+        else:
+            ret = self.topkgating(scores)  # (capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss)

-        capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss = self.topkgating(scores)
-        return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss
+        # Append norm_out if needed
+        if self.using_post_norm_recompute:
+            ret = (*ret, norm_out)
+
+        return ret


 class AddAuxiliaryLoss(paddle.autograd.PyLayer):
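Note the mode-dependent arity this gives `MoEGate.forward`: four values under flex-token routing, six under capacity-based `topkgating`, with `norm_out` appended whenever `using_post_norm_recompute` is set. An illustrative unpacking for callers, where `gate` and `h` are placeholder names:

```python
if gate.using_flex_token:
    if gate.using_post_norm_recompute:
        scores, routing_map, l_aux, l_zloss, norm_out = gate(h)
    else:
        scores, routing_map, l_aux, l_zloss = gate(h)
else:
    if gate.using_post_norm_recompute:
        capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss, norm_out = gate(h)
    else:
        capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss = gate(h)
```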
@@ -833,9 +893,13 @@ class DeepseekV2MoE(MoELayer):
     A mixed expert module containing shared experts.
     """

-    def __init__(self, config: DeepseekV2Config):
+    def __init__(self, config: DeepseekV2Config, norm_weight=None, norm_eps=None):
         assert config.tensor_parallel_degree <= 1, "tensor_parallel_degree should be 1"

+        self.using_post_norm_recompute = config.using_post_norm_recompute
+        if self.using_post_norm_recompute:
+            assert norm_weight is not None and norm_eps is not None
+
         gate = MoEGate(
             config=config,
             num_experts=config.n_routed_experts,
@@ -847,6 +911,9 @@ def __init__(self, config: DeepseekV2Config):
             norm_topk_prob=config.norm_topk_prob,
             routed_scaling_factor=config.routed_scaling_factor,
             drop_tokens=False,
+            using_post_norm_recompute=self.using_post_norm_recompute,
+            norm_weight=norm_weight,
+            norm_eps=norm_eps,
         )
         DeepseekV2MLPClass = FP8Mlp if DSV3_USE_FP8_GEMM else DeepseekV2MLP

@@ -862,6 +929,7 @@ def __init__(self, config: DeepseekV2Config):
             gate=gate,
             capacity=2.0,
             moe_group="expert",
+            using_post_norm_recompute=self.using_post_norm_recompute,
         )

         moe_grad_group = fleet.get_hybrid_communicate_group().expert_grad_comm_group
@@ -871,11 +939,44 @@ def __init__(self, config: DeepseekV2Config):
         self.alpha = config.aux_loss_alpha
         if config.n_shared_experts is not None:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
-            self.shared_experts = DeepseekV2MLPClass(config=config, intermediate_size=intermediate_size, is_moe=False)
+            if self.using_post_norm_recompute:
+                assert DeepseekV2MLPClass is FP8Mlp
+                self.shared_experts = DeepseekV2MLPClass(
+                    config=config,
+                    intermediate_size=intermediate_size,
+                    is_moe=False,
+                    using_post_norm_recompute=self.using_post_norm_recompute,
+                    norm_weight=norm_weight,
+                    norm_eps=norm_eps,
+                )
+            else:
+                self.shared_experts = DeepseekV2MLPClass(
+                    config=config, intermediate_size=intermediate_size, is_moe=False
+                )

     def forward(self, hidden_states):
-        final_hidden_states, l_aux, l_zloss = super().forward(hidden_states)
-        final_hidden_states = self.post_process(hidden_states, final_hidden_states, l_aux)
+        if self.using_post_norm_recompute:
+            super().update_flex_token()
+            if self.using_flex_token:
+                probs, routing_map, l_aux, l_zloss, norm_out = self.router(hidden_states)
+                final_hidden_states, l_aux, l_zloss = super().forward(
+                    norm_out, probs=probs, routing_map=routing_map, l_aux=l_aux, l_zloss=l_zloss
+                )
+            else:
+                capacity, topk_weight, topk_ids, token_priority, l_aux, l_zloss, norm_out = self.gate(hidden_states)
+                final_hidden_states, l_aux, l_zloss = super().forward(
+                    norm_out,
+                    capacity=capacity,
+                    topk_weight=topk_weight,
+                    topk_ids=topk_ids,
+                    token_priority=token_priority,
+                    l_aux=l_aux,
+                    l_zloss=l_zloss,
+                )
+            # The residual path in post_process still uses the pre-norm hidden_states.
+            final_hidden_states = self.post_process(hidden_states, final_hidden_states, l_aux)
+        else:
+            final_hidden_states, l_aux, l_zloss = super().forward(hidden_states)
+            final_hidden_states = self.post_process(hidden_states, final_hidden_states, l_aux)
         return final_hidden_states

     def post_process(self, hidden_states, final_hidden_states, l_aux):
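One subtlety in this branch: the router and the experts consume `norm_out`, but `post_process` still receives the original pre-norm `hidden_states`, so the residual connection is built from un-normalized activations exactly as in the baseline path. A hypothetical restatement of the flex-token branch; the import path is an assumption:

```python
from paddlenlp.transformers.moe_layer import MoELayer  # path assumed from "..moe_layer"

def moe_forward_recompute_path(moe, hidden_states):
    """Hypothetical restatement of DeepseekV2MoE.forward's flex-token branch."""
    # The router runs the fused norm + gate, returning routing info AND norm_out.
    probs, routing_map, l_aux, l_zloss, norm_out = moe.router(hidden_states)
    # Experts consume the normalized activations ...
    out, l_aux, l_zloss = MoELayer.forward(
        moe, norm_out, probs=probs, routing_map=routing_map, l_aux=l_aux, l_zloss=l_zloss
    )
    # ... while post_process still receives the pre-norm input for the residual.
    return moe.post_process(hidden_states, out, l_aux)
```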
@@ -1774,25 +1875,32 @@ def __init__(self, config: DeepseekV2Config, layer_idx: int, layerwise_recompute
         self.enable_recompute = False
         self.layerwise_recompute = layerwise_recompute
         self.recompute_granularity = config.recompute_granularity
+        self.using_post_norm_recompute = config.using_post_norm_recompute

         self.hidden_size = config.hidden_size

         self.self_attn = DeepseekV2Attention(config=config, layerwise_recompute=layerwise_recompute)

         DeepseekV2MLPClass = FP8Mlp if DSV3_USE_FP8_GEMM else DeepseekV2MLP

-        self.mlp = (
-            DeepseekV2MoE(config)
-            if (
-                config.n_routed_experts is not None
-                and layer_idx >= config.first_k_dense_replace
-                and layer_idx % config.moe_layer_freq == 0
-            )
-            else DeepseekV2MLPClass(config)
-        )
         self.input_layernorm = DeepseekV2RMSNorm(config)
         self.post_attention_layernorm = DeepseekV2RMSNorm(config)

+        if (
+            config.n_routed_experts is not None
+            and layer_idx >= config.first_k_dense_replace
+            and layer_idx % config.moe_layer_freq == 0
+        ):
+            self.mlp = (
+                DeepseekV2MoE(
+                    config, self.post_attention_layernorm.weight, self.post_attention_layernorm.variance_epsilon
+                )
+                if config.using_post_norm_recompute
+                else DeepseekV2MoE(config)
+            )
+        else:
+            self.mlp = DeepseekV2MLPClass(config)
+
     def forward(
         self,
         hidden_states: paddle.Tensor,
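The decoder layer ties the pieces together: under `using_post_norm_recompute` it hands its own `post_attention_layernorm` parameters to `DeepseekV2MoE`, so the fused gate normalizes with the very same parameter tensor and gradients from the MoE path accumulate onto the layer's norm weight. Condensed, the wiring amounts to this sketch (attribute names as in the diff, shown as an excerpt from `__init__`):

```python
# Inside DeepseekV2DecoderLayer.__init__, for MoE layers only:
if config.using_post_norm_recompute:
    self.mlp = DeepseekV2MoE(
        config,
        self.post_attention_layernorm.weight,            # shared parameter, not a copy
        self.post_attention_layernorm.variance_epsilon,  # same eps the layer norm uses
    )
```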
@@ -1871,7 +1979,9 @@ def forward(
         # Fully Connected
         residual = hidden_states

-        hidden_states = self.post_attention_layernorm(hidden_states)
+        if not (self.using_post_norm_recompute and isinstance(self.mlp, DeepseekV2MoE)):
+            hidden_states = self.post_attention_layernorm(hidden_states)
+
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states

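Finally, the forward pass must not normalize twice. The guard only skips `post_attention_layernorm` when the MLP is a `DeepseekV2MoE` that owns the recompute; dense layers before `first_k_dense_replace` still take the ordinary norm. A self-contained sketch of the resulting control flow (hypothetical helper, not part of the diff):

```python
def mlp_block(layer, hidden_states):
    """Hypothetical helper mirroring the decoder-layer forward above."""
    residual = hidden_states
    # Dense-MLP layers, and MoE layers without the recompute flag, normalize here;
    # otherwise DeepseekV2MoE applies the very same RMSNorm inside its fused gate.
    if not (layer.using_post_norm_recompute and isinstance(layer.mlp, DeepseekV2MoE)):
        hidden_states = layer.post_attention_layernorm(hidden_states)
    hidden_states = layer.mlp(hidden_states)
    return residual + hidden_states
```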