     ScheduleNode,
     SharedLayerDesc,
 )
-from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import (
-    WeightGradStore
-)
+from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import WeightGradStore
 
 try:
     from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import EventStore
@@ -721,17 +719,12 @@ def combine_backward(self, output_grad, previous_event=None, async_finish=False,
         ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret
         return ret
 
-    def mlp_backward_dw(self):
-        self.fp8_fusion_moe_node.mlp_node.backward_dw()
-
     def mlp_backward(self, output_grad):
         if self.send_mtp_embed:
             inputs_embeds_mtp_grad, hidden_states_grad, residual_grad, l_aux_grad, hidden_states_out_grad = output_grad
         else:
             hidden_states_grad, residual_grad, l_aux_grad, hidden_states_out_grad = output_grad
-        hs_dispatched_grad, dispatched_probs_grad = self.fp8_fusion_moe_node.mlp_node.backward(
-            hidden_states_out_grad, with_dw=False
-        )
+        hs_dispatched_grad, dispatched_probs_grad = self.fp8_fusion_moe_node.mlp_node.backward(hidden_states_out_grad)
 
         ret = (hidden_states_grad, residual_grad, l_aux_grad, hs_dispatched_grad, dispatched_probs_grad)
         ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret
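
Background for the hunk above (illustration only, not code from this patch): the deleted `with_dw=False` / `backward_dw()` pair is the usual dX/dW split used by zero-bubble schedules, where the activation-gradient pass runs immediately and the weight-gradient pass is deferred so it can fill pipeline bubbles. The sketch below shows the idea on a toy linear layer; `ToyLinear`, `backward_dx`, and `backward_dw` are hypothetical names, not the Paddle `mlp_node` API.

```python
import paddle

class ToyLinear:
    """Hypothetical layer whose backward is split into a dX pass and a deferred dW pass."""

    def __init__(self, in_dim, out_dim):
        self.weight = paddle.randn([in_dim, out_dim])
        self._saved_input = None
        self._saved_grad_out = None

    def forward(self, x):
        self._saved_input = x                      # keep the activation for the later dW pass
        return paddle.matmul(x, self.weight)

    def backward_dx(self, grad_out):
        # Activation gradient: needed immediately so upstream layers can keep going.
        self._saved_grad_out = grad_out
        return paddle.matmul(grad_out, self.weight, transpose_y=True)

    def backward_dw(self):
        # Weight gradient: can be postponed to overlap with communication.
        return paddle.matmul(self._saved_input, self._saved_grad_out, transpose_x=True)

layer = ToyLinear(4, 8)
out = layer.forward(paddle.randn([2, 4]))
dx = layer.backward_dx(paddle.ones_like(out))      # runs right away
dw = layer.backward_dw()                           # runs whenever the schedule decides
```

With this patch, `mlp_node.backward` no longer takes the `with_dw` flag; the deferred dW work instead appears to be captured by `WeightGradStore`, as set up in the later hunks.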
@@ -809,7 +802,6 @@ def backward(self, output_grad=None, scaler=None):
         output_grad = self.mlp_backward(output_grad)
         # todo(phlrain): overlap here
         output_grad = self.dispatch_backward(output_grad)
-        self.mlp_backward_dw()
         output_grad = self.attn_backward(output_grad)
         return output_grad
 
@@ -853,7 +845,11 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
 
         combine_backward_event.calc_stream_wait(self.backward_node.moe_group.id)
         paddle.base.core.nvprof_nvtx_push("mlp_backward_dx")
+        assert WeightGradStore.funcs_queue.empty()
+        WeightGradStore.enabled = True
         output_grad = self.backward_node.mlp_backward(output_grad)
+        WeightGradStore.enabled = False
+        WeightGradStore.flush()
         paddle.base.core.nvprof_nvtx_pop()
 
         output_grad_event = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id)
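
To make the enable/flush/pop calls in this and the following hunks easier to follow, here is a deliberately simplified stand-in for `WeightGradStore` (the real class lives in `paddle.distributed.fleet.meta_parallel.zero_bubble_utils`; its internals, including whether `pop` drains one flushed batch or all of them, may differ from this sketch):

```python
import queue

class MiniWeightGradStore:
    """Simplified, illustrative stand-in for WeightGradStore."""

    enabled = False
    cache = []                       # dW callables collected during one backward region
    funcs_queue = queue.Queue()      # flushed batches waiting to be executed

    @classmethod
    def put(cls, func):
        cls.cache.append(func)       # called from inside backward while `enabled` is True

    @classmethod
    def flush(cls):
        cls.funcs_queue.put(cls.cache)
        cls.cache = []

    @classmethod
    def pop(cls):
        # This toy drains everything; the real pop() semantics may differ.
        while not cls.funcs_queue.empty():
            for func in cls.funcs_queue.get():
                func()

# Usage pattern mirroring the hunk above: dX runs now, dW is queued and run later.
MiniWeightGradStore.enabled = True
MiniWeightGradStore.put(lambda: print("deferred mlp dW"))   # stands in for mlp_backward
MiniWeightGradStore.enabled = False
MiniWeightGradStore.flush()
# ... dispatch backward / attention backward would overlap here ...
MiniWeightGradStore.pop()
assert MiniWeightGradStore.funcs_queue.empty()
```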
@@ -889,7 +885,6 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
             final_out = self.forward_node.post_process_node.forward_without_residual(inputs)
             paddle.base.core.nvprof_nvtx_pop()
 
-
         final_out_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)
 
         paddle.base.core.nvprof_nvtx_push("combine_forward")
@@ -922,13 +917,13 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
         paddle.base.core.nvprof_nvtx_pop()
 
         dispatch_backward_event.calc_stream_wait(self.backward_node.moe_group.id)
-
+
         paddle.base.core.nvprof_nvtx_push("attn_backward")
         assert WeightGradStore.funcs_queue.empty()
         WeightGradStore.enabled = True
         output_grad = self.backward_node.attn_backward(output_grad)
         event_to_wait = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id)
-
+
         if EventStore is not None:
             EventStore.set(event_to_wait)
 
@@ -937,16 +932,12 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
         WeightGradStore.pop()
         assert WeightGradStore.funcs_queue.empty()
 
-        WeightGradStore.enabled = False
-        WeightGradStore.flush()
-        WeightGradStore.pop()
-        assert WeightGradStore.funcs_queue.empty()
         paddle.base.core.nvprof_nvtx_pop()
 
         # residual add
         if pp_stream is None:
             combine_forward_event.calc_stream_wait(self.forward_node.moe_group.id)
-
+
             final_out = self.forward_node.post_process_node.forward_without_residual(inputs)
             if final_out.shape[-1] != combine_fwd_out.shape[-1]:
                 final_out[:, :, : combine_fwd_out.shape[-1]] += combine_fwd_out  # broadcast directly and add
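
The last context line above adds `combine_fwd_out` into only the leading channels of `final_out` when the widths differ (the translated comment reads "broadcast directly and add"). A minimal repro of that slice-wise in-place add, with made-up shapes; the assumption that the extra trailing features come from the MTP embedding is an illustration, not something shown in this hunk:

```python
import paddle

# Hypothetical shapes: final_out carries 2 extra trailing features.
final_out = paddle.zeros([2, 3, 8])
combine_fwd_out = paddle.ones([2, 3, 6])

if final_out.shape[-1] != combine_fwd_out.shape[-1]:
    # Only the first combine_fwd_out.shape[-1] channels are updated in place.
    final_out[:, :, : combine_fwd_out.shape[-1]] += combine_fwd_out

print(final_out[0, 0])  # first 6 entries are 1.0, the last 2 stay 0.0
```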