from ..model_utils import PretrainedModel, register_base_model
from ..moe_gate import PretrainedMoEGate
from ..moe_layer import MoELayer
-from ..utils import device_guard
+from ..utils import cast_if_needed, device_guard
from . import fp8_linear as linear_utils
from .configuration import DeepseekV2Config
@@ -737,6 +737,41 @@ def forward(self, x):
        return out


+class FusedNormGateFunc(paddle.autograd.PyLayer):
+    """Fused post-attention RMSNorm + MoE gate projection; the norm output is recomputed in backward."""
+
+    @staticmethod
+    def forward(ctx, x, rms_norm_weight, moe_gate_weight, eps):
+        ctx.dtype = paddle.float32
+        norm_output, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps)
+        with paddle.amp.auto_cast(False):
+            gate_logits = F.linear(cast_if_needed(norm_output, ctx.dtype), cast_if_needed(moe_gate_weight, ctx.dtype))
+
+        ctx.save_for_backward(x, rms_norm_weight, moe_gate_weight, eps)
+        return gate_logits, norm_output
+
+    @staticmethod
+    def backward(ctx, d_gate_logits, d_norm_output):
+        x, rms_norm_weight, moe_gate_weight, eps = ctx.saved_tensor()
+        # recompute the RMSNorm output, which was intentionally not saved in forward
+        norm_output, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps)
+        d_norm_output_linear, d_moe_gate_weight = paddle._C_ops.matmul_grad(
+            cast_if_needed(norm_output, ctx.dtype),
+            cast_if_needed(moe_gate_weight, ctx.dtype),
+            d_gate_logits,
+            False,
+            False,
+        )
+        d_norm_output_linear, d_moe_gate_weight = cast_if_needed(
+            d_norm_output_linear, norm_output.dtype
+        ), cast_if_needed(d_moe_gate_weight, moe_gate_weight.dtype)
+
+        d_norm_output = d_norm_output + d_norm_output_linear
+        dx, d_rms_norm_weight = fused_ln.fused_rms_norm_grad_func(x, rms_norm_weight, invar, d_norm_output, eps)
+
+        return dx, d_rms_norm_weight, d_moe_gate_weight
+
+
class FakeGate(paddle.autograd.PyLayer):
    @staticmethod
    def forward(ctx, hidden_states, weight):
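# Aside (illustrative sketch, not part of the patch): FusedNormGateFunc above saves only its
# inputs and recomputes the RMSNorm output inside backward(), trading a little extra compute
# for the memory of one large activation. Below is a minimal, self-contained Paddle example of
# the same save-inputs / recompute-in-backward idiom, with a toy op (y = exp(x * w)) standing
# in for the real fused_rms_norm + gate projection; the class and variable names are made up.

import paddle


class RecomputedExpScale(paddle.autograd.PyLayer):
    """Toy PyLayer: y = exp(x * w); y is not saved in forward and is recomputed in backward."""

    @staticmethod
    def forward(ctx, x, w):
        y = paddle.exp(x * w)
        ctx.save_for_backward(x, w)    # save only the inputs, as FusedNormGateFunc does
        return y

    @staticmethod
    def backward(ctx, dy):
        x, w = ctx.saved_tensor()
        y = paddle.exp(x * w)          # recompute the activation that forward chose not to keep
        dx = dy * y * w                # dL/dx (w broadcasts over the batch dimension)
        dw = (dy * y * x).sum(axis=0)  # dL/dw (reduced over the batch dimension)
        return dx, dw


x = paddle.randn([4, 8])
w = paddle.ones([8])
x.stop_gradient = False
w.stop_gradient = False
RecomputedExpScale.apply(x, w).sum().backward()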
@@ -756,7 +791,16 @@ def backward(ctx, grad_output):


class MoEGate(PretrainedMoEGate):
-    def __init__(self, config, num_experts, expert_hidden_size, **kwargs):
+    def __init__(
+        self,
+        config,
+        num_experts,
+        expert_hidden_size,
+        using_post_norm_recompute=False,
+        norm_weight=None,
+        norm_eps=None,
+        **kwargs
+    ):
        super().__init__(config, num_experts, expert_hidden_size, **kwargs)
        # [hidden_size, n_expert]
@@ -771,6 +815,8 @@ def __init__(self, config, num_experts, expert_hidden_size, **kwargs):
        )

        self.config = config
+        self.using_post_norm_recompute = using_post_norm_recompute
+
        if config.topk_method == "noaux_tc":
            self.e_score_correction_bias = paddle.create_parameter(
                shape=[num_experts],
@@ -779,6 +825,10 @@ def __init__(self, config, num_experts, expert_hidden_size, **kwargs):
            )
            self.e_score_correction_bias.is_distributed = True

+        if self.using_post_norm_recompute:
+            assert norm_weight is not None and norm_eps is not None
+            self.norm_weight = norm_weight
+            self.norm_eps = norm_eps
        self.using_flex_token = False

    def forward(self, hidden_states):
@@ -789,23 +839,33 @@ def forward(self, hidden_states):
        _, _, h_dim = hidden_states.shape

        # compute gating score
-        with paddle.amp.auto_cast(False):
-            hidden_states = hidden_states.cast(self.weight.dtype)
-
-            if hasattr(self.config, "using_fake_gate") and self.config.using_fake_gate:
-                logits = FakeGate.apply(hidden_states, self.weight)
-            else:
-                logits = F.linear(hidden_states, self.weight, None)
+        if self.using_post_norm_recompute:
+            logits, norm_out = FusedNormGateFunc.apply(hidden_states, self.norm_weight, self.weight, self.norm_eps)
+        else:
+            with paddle.amp.auto_cast(False):
+                hidden_states = hidden_states.cast(self.weight.dtype)
+                if hasattr(self.config, "using_fake_gate") and self.config.using_fake_gate:
+                    logits = FakeGate.apply(hidden_states, self.weight)
+                else:
+                    logits = F.linear(hidden_states, self.weight, None)

-            scores = self.gate_score_func(logits=logits)
-            scores = scores.cast(paddle.float32)
+        scores = self.gate_score_func(logits=logits)
+        scores = scores.cast(paddle.float32)

+        # Compute all possible return values
        if self.using_flex_token:
-            scores, routing_map, exp_counts, l_aux, l_zloss = self.topkgating_nodrop(scores)
-            return scores, routing_map, l_aux, l_zloss
+            scores, routing_map, exp_counts, l_aux, l_zloss = self.topkgating_nodrop(
+                scores
+            )  # (scores, routing_map, exp_counts, l_aux, l_zloss)
+            ret = (scores, routing_map, l_aux, l_zloss)
+        else:
+            ret = self.topkgating(scores)  # (capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss)

-        capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss = self.topkgating(scores)
-        return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss
+        # Append norm_out if needed
+        if self.using_post_norm_recompute:
+            ret = (*ret, norm_out)
+
+        return ret


class AddAuxiliaryLoss(paddle.autograd.PyLayer):
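# Aside (descriptive note, not part of the patch): after this change MoEGate.forward has a
# variable-length return. With using_flex_token it returns
#     (scores, routing_map, l_aux, l_zloss)
# otherwise it returns the full topkgating tuple
#     (capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss)
# and in either case, when using_post_norm_recompute is enabled, norm_out is appended as the
# last element. Callers such as DeepseekV2MoE.forward below rely on that trailing norm_out.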
@@ -833,9 +893,13 @@ class DeepseekV2MoE(MoELayer):
    A mixed expert module containing shared experts.
    """

-    def __init__(self, config: DeepseekV2Config):
+    def __init__(self, config: DeepseekV2Config, norm_weight=None, norm_eps=None):
        assert config.tensor_parallel_degree <= 1, "tensor_parallel_degree should be 1"

+        self.using_post_norm_recompute = config.using_post_norm_recompute
+        if self.using_post_norm_recompute:
+            assert norm_weight is not None and norm_eps is not None
+
        gate = MoEGate(
            config=config,
            num_experts=config.n_routed_experts,
@@ -847,6 +911,9 @@ def __init__(self, config: DeepseekV2Config):
            norm_topk_prob=config.norm_topk_prob,
            routed_scaling_factor=config.routed_scaling_factor,
            drop_tokens=False,
+            using_post_norm_recompute=self.using_post_norm_recompute,
+            norm_weight=norm_weight,
+            norm_eps=norm_eps,
        )
        DeepseekV2MLPClass = FP8Mlp if DSV3_USE_FP8_GEMM else DeepseekV2MLP
@@ -862,6 +929,7 @@ def __init__(self, config: DeepseekV2Config):
            gate=gate,
            capacity=2.0,
            moe_group="expert",
+            using_post_norm_recompute=self.using_post_norm_recompute,
        )

        moe_grad_group = fleet.get_hybrid_communicate_group().expert_grad_comm_group
@@ -871,11 +939,44 @@ def __init__(self, config: DeepseekV2Config):
        self.alpha = config.aux_loss_alpha
        if config.n_shared_experts is not None:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
-            self.shared_experts = DeepseekV2MLPClass(config=config, intermediate_size=intermediate_size, is_moe=False)
+            if self.using_post_norm_recompute:
+                assert DeepseekV2MLPClass is FP8Mlp
+                self.shared_experts = DeepseekV2MLPClass(
+                    config=config,
+                    intermediate_size=intermediate_size,
+                    is_moe=False,
+                    using_post_norm_recompute=self.using_post_norm_recompute,
+                    norm_weight=norm_weight,
+                    norm_eps=norm_eps,
+                )
+            else:
+                self.shared_experts = DeepseekV2MLPClass(
+                    config=config, intermediate_size=intermediate_size, is_moe=False
+                )

    def forward(self, hidden_states):
-        final_hidden_states, l_aux, l_zloss = super().forward(hidden_states)
-        final_hidden_states = self.post_process(hidden_states, final_hidden_states, l_aux)
+        if self.using_post_norm_recompute:
+            super().update_flex_token()
+            if self.using_flex_token:
+                probs, routing_map, l_aux, l_zloss, norm_out = self.router(hidden_states)
+                final_hidden_states, l_aux, l_zloss = super().forward(
+                    norm_out, probs=probs, routing_map=routing_map, l_aux=l_aux, l_zloss=l_zloss
+                )
+            else:
+                capacity, topk_weight, topk_ids, token_priority, l_aux, l_zloss, norm_out = self.gate(hidden_states)
+                final_hidden_states, l_aux, l_zloss = super().forward(
+                    norm_out,
+                    capacity=capacity,
+                    topk_weight=topk_weight,
+                    topk_ids=topk_ids,
+                    token_priority=token_priority,
+                    l_aux=l_aux,
+                    l_zloss=l_zloss,
+                )
+            final_hidden_states = self.post_process(hidden_states, final_hidden_states, l_aux)
+        else:
+            final_hidden_states, l_aux, l_zloss = super().forward(hidden_states)
+            final_hidden_states = self.post_process(hidden_states, final_hidden_states, l_aux)
        return final_hidden_states

    def post_process(self, hidden_states, final_hidden_states, l_aux):
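# Aside (descriptive note, not part of the patch): in the post-norm-recompute path above,
# DeepseekV2MoE.forward receives the *pre-norm* residual-stream tensor. The gate/router applies
# FusedNormGateFunc, which yields both the routing outputs and the normalized activations
# (norm_out); the routed experts then run on norm_out, while post_process is still handed the
# original hidden_states. The shared experts in this mode are constructed with norm_weight and
# norm_eps, presumably so they can apply the post-attention norm to their own input.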
@@ -1774,25 +1875,32 @@ def __init__(self, config: DeepseekV2Config, layer_idx: int, layerwise_recompute
        self.enable_recompute = False
        self.layerwise_recompute = layerwise_recompute
        self.recompute_granularity = config.recompute_granularity
+        self.using_post_norm_recompute = config.using_post_norm_recompute

        self.hidden_size = config.hidden_size

        self.self_attn = DeepseekV2Attention(config=config, layerwise_recompute=layerwise_recompute)

        DeepseekV2MLPClass = FP8Mlp if DSV3_USE_FP8_GEMM else DeepseekV2MLP

-        self.mlp = (
-            DeepseekV2MoE(config)
-            if (
-                config.n_routed_experts is not None
-                and layer_idx >= config.first_k_dense_replace
-                and layer_idx % config.moe_layer_freq == 0
-            )
-            else DeepseekV2MLPClass(config)
-        )
        self.input_layernorm = DeepseekV2RMSNorm(config)
        self.post_attention_layernorm = DeepseekV2RMSNorm(config)

+        if (
+            config.n_routed_experts is not None
+            and layer_idx >= config.first_k_dense_replace
+            and layer_idx % config.moe_layer_freq == 0
+        ):
+            self.mlp = (
+                DeepseekV2MoE(
+                    config, self.post_attention_layernorm.weight, self.post_attention_layernorm.variance_epsilon
+                )
+                if config.using_post_norm_recompute
+                else DeepseekV2MoE(config)
+            )
+        else:
+            self.mlp = DeepseekV2MLPClass(config)
+
    def forward(
        self,
        hidden_states: paddle.Tensor,
@@ -1871,7 +1979,9 @@ def forward(
        # Fully Connected
        residual = hidden_states

-        hidden_states = self.post_attention_layernorm(hidden_states)
+        if not (self.using_post_norm_recompute and isinstance(self.mlp, DeepseekV2MoE)):
+            hidden_states = self.post_attention_layernorm(hidden_states)
+
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
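# Aside (descriptive note, not part of the patch): net effect in DeepseekV2DecoderLayer.forward
# when the flag is enabled:
#
#     residual = hidden_states
#     if not (self.using_post_norm_recompute and isinstance(self.mlp, DeepseekV2MoE)):
#         hidden_states = self.post_attention_layernorm(hidden_states)   # standard path
#     hidden_states = self.mlp(hidden_states)   # MoE path applies the norm via FusedNormGateFunc
#     hidden_states = residual + hidden_states
#
# i.e. the layer-level post-attention RMSNorm is skipped only when the MLP is a DeepseekV2MoE
# built with this layer's post_attention_layernorm weight and epsilon, in which case the norm is
# fused into the gate computation and recomputed during backward.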