@@ -820,13 +820,17 @@ def forward_backward(self, inputs, output_grad):
820
820
821
821
paddle .base .core .nvprof_nvtx_push ("combine_backward" )
822
822
output_grad = self .backward_node .combine_backward (output_grad , async_finish = True )
823
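+ # Capture the comm-stream event right after launching the async combine_backward, so the calc stream can wait on it before mlp_backward.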
824
+ combine_backward_event = deep_ep .get_event_from_comm_stream ( self .backward_node .moe_group .id )
823
825
paddle .base .core .nvprof_nvtx_pop ()
826
+
824
827
paddle .base .core .nvprof_nvtx_push ("attn_forward" )
825
828
inputs = self .forward_node .attn_forward (inputs )
826
829
paddle .base .core .nvprof_nvtx_pop ()
827
-
828
- calc_stream_wait (self .backward_node .moe_group .id )
829
830
attn_compute_event = deep_ep .get_event_from_calc_stream (self .forward_node .moe_group .id )
831
+
832
+
833
+ combine_backward_event .calc_stream_wait ( self .backward_node .moe_group .id )
830
834
paddle .base .core .nvprof_nvtx_push ("mlp_backward_dx" )
831
835
output_grad = self .backward_node .mlp_backward (output_grad )
832
836
paddle .base .core .nvprof_nvtx_pop ()
@@ -835,28 +839,35 @@ def forward_backward(self, inputs, output_grad):
835
839
inputs , previous_event = attn_compute_event , async_finish = True , allocate_on_comm_stream = True
836
840
)
837
841
paddle .base .core .nvprof_nvtx_pop ()
842
+ dispatch_forward_event = deep_ep .get_event_from_comm_stream ( self .forward_node .moe_group .id )
843
+
838
844
paddle .base .core .nvprof_nvtx_push ("dispatch_backward" )
839
845
output_grad = self .backward_node .dispatch_backward (output_grad , async_finish = True )
840
846
paddle .base .core .nvprof_nvtx_pop ()
847
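+ # Capture the comm-stream event right after launching the async dispatch_backward, so the calc stream can wait on it before attn_backward.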
848
+ dispatch_backward_event = deep_ep .get_event_from_comm_stream (self .backward_node .moe_group .id )
841
849
842
850
paddle .base .core .nvprof_nvtx_push ("dispatch_backward_dw" )
843
851
self .backward_node .mlp_backward_dw ()
844
852
paddle .base .core .nvprof_nvtx_pop ()
845
853
846
- calc_stream_wait (self .forward_node .moe_group .id )
854
+ dispatch_forward_event . calc_stream_wait ( self .forward_node .moe_group .id )
847
855
paddle .base .core .nvprof_nvtx_push ("mlp_forward" )
848
856
inputs = self .forward_node .mlp_forward (inputs )
849
857
paddle .base .core .nvprof_nvtx_pop ()
850
858
851
- calc_stream_wait (self .backward_node .moe_group .id )
852
859
paddle .base .core .nvprof_nvtx_push ("combine_forward" )
853
860
inputs = self .forward_node .combine_forward (inputs , async_finish = True )
854
861
paddle .base .core .nvprof_nvtx_pop ()
862
+ combine_forward_event = deep_ep .get_event_from_comm_stream ( self .forward_node .moe_group .id )
863
+
864
+
865
+ dispatch_backward_event .calc_stream_wait (self .backward_node .moe_group .id )
855
866
paddle .base .core .nvprof_nvtx_push ("attn_backward" )
856
867
output_grad = self .backward_node .attn_backward (output_grad )
857
868
paddle .base .core .nvprof_nvtx_pop ()
858
869
859
- calc_stream_wait (self .forward_node .moe_group .id )
870
+ combine_forward_event . calc_stream_wait (self .forward_node .moe_group .id )
860
871
paddle .base .core .nvprof_nvtx_push ("post_process_forward" )
861
872
inputs = self .forward_node .post_process_forward (inputs )
862
873
paddle .base .core .nvprof_nvtx_pop ()
0 commit comments