@@ -396,9 +396,9 @@ def __init__(self, forward_nodes, backward_nodes, use_fuion=True):
         for f, b in zip(forward_nodes, backward_nodes):
             self.nodes.append(schedule_node_class(f, b, f"OverlapedNode_{len(self.nodes)}"))
 
-    def forward_backward(self, inputs, output_grad):
+    def forward_backward(self, inputs, output_grad, event_to_wait=None):
         for n in self.nodes:
-            inputs, output_grad = n.forward_backward(inputs, output_grad)
+            inputs, output_grad, event_to_wait = n.forward_backward(inputs, output_grad, event_to_wait)
         return inputs, output_grad
 
 
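This hunk threads a CUDA event through the node chain: each node returns the event it recorded at the end of its backward phase, and the chunk loop hands that event to the next node instead of forcing a device-wide synchronize. A minimal sketch of the pattern, with a hypothetical stub standing in for the real schedule nodes:

```python
import paddle

class StubNode:
    """Hypothetical stand-in for one overlapped schedule node."""

    def forward_backward(self, inputs, output_grad, event_to_wait=None):
        if event_to_wait is not None:
            # GPU-side dependency: the calc stream waits on the event
            # instead of the host blocking on a device synchronize.
            paddle.device.current_stream().wait_event(event_to_wait)
        inputs, output_grad = inputs * 2, output_grad * 2  # placeholder compute
        # Record a new event after this node's last kernel so the next
        # node in the chain can depend on it.
        event = paddle.device.current_stream().record_event()
        return inputs, output_grad, event

nodes = [StubNode(), StubNode()]
x, g, event = paddle.randn([8]), paddle.randn([8]), None
for n in nodes:
    x, g, event = n.forward_backward(x, g, event)
```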
@@ -409,7 +409,7 @@ def __init__(self, forward_node, backward_node, name=""):
         self.backward_node = backward_node
         self.name = name
 
-    def forward_backward(self, inputs, output_grad):
+    def forward_backward(self, inputs, output_grad, event_to_wait=None):
         paddle.base.core.nvprof_nvtx_push("forward_backward")
         output_grad = self.backward_node.post_process_node.backward(output_grad)
 
@@ -594,7 +594,7 @@ def post_process_forward(self, inputs):
 
         return inputs
 
-    def post_process_backward(self, output_grad):
+    def post_process_backward(self, output_grad, event_to_wait=None):
         if self.send_mtp_embed:
             (
                 inputs_embeds_mtp_grad,
@@ -610,43 +610,51 @@ def post_process_backward(self, output_grad):
             l_aux_grad,
             final_hidden_states_grad,
         ) = self.post_process_node.backward(output_grad)
-        output_combine_grad = self.fp8_fusion_moe_node.combine_quant_node.backward(final_hidden_states_grad)
+        output_combine_grad, quant_event = self.fp8_fusion_moe_node.combine_quant_node.backward(
+            final_hidden_states_grad, event_to_wait
+        )
         if self.send_mtp_embed:
             return (
                 inputs_embeds_mtp_grad,
                 hidden_states_grad,
                 residual_grad,
                 l_aux_grad,
                 output_combine_grad,
+                quant_event,
             )
         else:
             return (
                 hidden_states_grad,
                 residual_grad,
                 l_aux_grad,
                 output_combine_grad,
+                quant_event,
             )
 
-    def combine_backward(self, output_grad, async_finish=False):
+    def combine_backward(self, output_grad, async_finish=False, allocate_on_comm_stream=False):
         if self.send_mtp_embed:
             (
                 inputs_embeds_mtp_grad,
                 hidden_states_grad,
                 residual_grad,
                 l_aux_grad,
                 output_combine_grad,
+                quant_event,
             ) = output_grad
         else:
             (
                 hidden_states_grad,
                 residual_grad,
                 l_aux_grad,
                 output_combine_grad,
+                quant_event,
             ) = output_grad
 
         hidden_states_out_grad = self.fp8_fusion_moe_node.combine_node.backward(
             output_combine_grad,
             async_finish=async_finish,
+            previous_event=quant_event,
+            allocate_on_comm_stream=allocate_on_comm_stream,
         )
 
         if self.send_mtp_embed:
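The two new kwargs are the DeepEP-style convention for cross-stream chaining: `previous_event` makes the combine all-to-all wait on-device for the quant kernel recorded as `quant_event`, and `allocate_on_comm_stream` moves output allocation onto the comm stream so the calc stream never blocks. A sketch of the same chaining written directly against the upstream `deep_ep.Buffer.combine` convention; the buffer, tensors, and handle are placeholders, and the wrapper node here may call it differently:

```python
import deep_ep

def chained_combine_backward(buffer: deep_ep.Buffer, output_combine_grad, handle, quant_event):
    # The comm kernel is gated on quant_event on the GPU; the host never syncs.
    combined_grad, _, combine_event = buffer.combine(
        output_combine_grad,
        handle,
        previous_event=quant_event,    # start only after the quant kernel
        async_finish=True,             # hand back an event instead of blocking
        allocate_on_comm_stream=True,  # allocate outputs on the comm stream
    )
    # A consumer later waits on combine_event from its own stream
    # before reading combined_grad.
    return combined_grad, combine_event
```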
@@ -811,35 +819,36 @@ def __init__(self, forward_node, backward_node, name=""):
         self.backward_node = backward_node
         self.name = name
 
-    def forward_backward(self, inputs, output_grad):
+    def forward_backward(self, inputs, output_grad, event_to_wait=None):
         paddle.base.core.nvprof_nvtx_push("forward_backward")
 
         paddle.base.core.nvprof_nvtx_push("post_process_backward")
-        output_grad = self.backward_node.post_process_backward(output_grad)
+        output_grad = self.backward_node.post_process_backward(output_grad, event_to_wait)
         paddle.base.core.nvprof_nvtx_pop()
 
         paddle.base.core.nvprof_nvtx_push("combine_backward")
-        output_grad = self.backward_node.combine_backward(output_grad, async_finish=True)
+        output_grad = self.backward_node.combine_backward(output_grad, async_finish=True, allocate_on_comm_stream=True)
         # get combine event
-        combine_backward_event = deep_ep.get_event_from_comm_stream( self.backward_node.moe_group.id )
+        combine_backward_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id)
         paddle.base.core.nvprof_nvtx_pop()
 
         paddle.base.core.nvprof_nvtx_push("attn_forward")
         inputs = self.forward_node.attn_forward(inputs)
         paddle.base.core.nvprof_nvtx_pop()
-        attn_compute_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)
 
+        attn_compute_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)
 
-        combine_backward_event.calc_stream_wait( self.backward_node.moe_group.id )
+        combine_backward_event.calc_stream_wait(self.backward_node.moe_group.id)
         paddle.base.core.nvprof_nvtx_push("mlp_backward_dx")
         output_grad = self.backward_node.mlp_backward(output_grad)
         paddle.base.core.nvprof_nvtx_pop()
+
         paddle.base.core.nvprof_nvtx_push("dispatch_forward")
         inputs = self.forward_node.dispatch_forward(
             inputs, previous_event=attn_compute_event, async_finish=True, allocate_on_comm_stream=True
         )
         paddle.base.core.nvprof_nvtx_pop()
-        dispatch_forward_event = deep_ep.get_event_from_comm_stream( self.forward_node.moe_group.id )
+        dispatch_forward_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id)
 
         paddle.base.core.nvprof_nvtx_push("dispatch_backward")
         output_grad = self.backward_node.dispatch_backward(output_grad, async_finish=True)
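The reordering in this hunk follows one rule: every comm launch is gated by an event recorded on the stream that produced its input, and the calc stream is kept busy in the gap. Reduced to raw stream mechanics (a sketch with placeholder kernels, not the node API):

```python
import paddle

calc = paddle.device.current_stream()
comm = paddle.device.Stream()  # stands in for deep_ep's comm stream
x = paddle.randn([1024, 1024])

# "attn_forward" on the calc stream, then record its completion.
attn_out = x @ x                           # placeholder compute
attn_compute_event = calc.record_event()

# "dispatch_forward" on the comm stream, gated only on the attn event,
# so it overlaps with backward work still running on the calc stream.
comm.wait_event(attn_compute_event)
with paddle.device.stream_guard(comm):
    dispatched = attn_out * 2              # placeholder for the all-to-all
dispatch_forward_event = comm.record_event()

# Fill the gap on the calc stream ("mlp_backward_dx" in the real node).
gap = x @ x.T                              # placeholder compute

# Only when the dispatch output is needed does the calc stream wait.
calc.wait_event(dispatch_forward_event)
y = dispatched + gap
```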
@@ -851,28 +860,28 @@ def forward_backward(self, inputs, output_grad):
         self.backward_node.mlp_backward_dw()
         paddle.base.core.nvprof_nvtx_pop()
 
-        dispatch_forward_event.calc_stream_wait( self.forward_node.moe_group.id )
+        dispatch_forward_event.calc_stream_wait(self.forward_node.moe_group.id)
         paddle.base.core.nvprof_nvtx_push("mlp_forward")
         inputs = self.forward_node.mlp_forward(inputs)
         paddle.base.core.nvprof_nvtx_pop()
 
         paddle.base.core.nvprof_nvtx_push("combine_forward")
         inputs = self.forward_node.combine_forward(inputs, async_finish=True)
         paddle.base.core.nvprof_nvtx_pop()
-        combine_forward_event = deep_ep.get_event_from_comm_stream( self.forward_node.moe_group.id )
-
+        combine_forward_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id)
 
         dispatch_backward_event.calc_stream_wait(self.backward_node.moe_group.id)
         paddle.base.core.nvprof_nvtx_push("attn_backward")
         output_grad = self.backward_node.attn_backward(output_grad)
+        event_to_wait = paddle.device.current_stream().record_event()
         paddle.base.core.nvprof_nvtx_pop()
 
         combine_forward_event.calc_stream_wait(self.forward_node.moe_group.id)
         paddle.base.core.nvprof_nvtx_push("post_process_forward")
         inputs = self.forward_node.post_process_forward(inputs)
         paddle.base.core.nvprof_nvtx_pop()
         paddle.base.core.nvprof_nvtx_pop()
-        return inputs, output_grad
+        return inputs, output_grad, event_to_wait
 
 
 def build_overlapped_nodes(forward_chunk, backward_chunk):
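The `record_event()` after `attn_backward` is what feeds the new third return value: it marks the exact point where the attention gradient is ready, so the quant kernel at the start of the next scheduled backward can wait on that point rather than on the whole device. A compact sketch of the handoff across two micro-steps (the kernels are placeholders; in the real schedule, comm-stream work is interleaved between the record and the wait, which is why the event matters):

```python
import paddle

def attn_backward_step(grad):
    grad = grad * 0.5                  # placeholder for attn_backward
    # Mark the point where grad is ready; the scheduler carries this event.
    return grad, paddle.device.current_stream().record_event()

def combine_quant_step(grad, event_to_wait=None):
    if event_to_wait is not None:
        # The quant kernel must not read grad before the recorded point.
        paddle.device.current_stream().wait_event(event_to_wait)
    return grad.astype("bfloat16")     # placeholder for the FP8 quant

grad, event = paddle.randn([128, 128]), None
for _ in range(2):                     # two overlapped micro-steps
    quantized = combine_quant_step(grad, event)
    grad, event = attn_backward_step(grad)
```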
@@ -1579,6 +1588,7 @@ def overlapped_forward_backward(
     backward_loss_fn_node,
     backward_input_grads,
     scaler,
+    event_to_wait=None,
 ):
     if backward_loss_fn_node is not None:
         if scaler:
@@ -1595,7 +1605,9 @@ def overlapped_forward_backward(
     ) = build_overlapped_nodes(forward_chunk, backward_chunk)
     forward_inputs = forward_pre_node.forward(forward_inputs)
     backward_input_grads = backward_pre_node.backward(backward_input_grads)
-    forward_inputs, backward_input_grads = overlap_node.forward_backward(forward_inputs, backward_input_grads)
+    forward_inputs, backward_input_grads = overlap_node.forward_backward(
+        forward_inputs, backward_input_grads, event_to_wait
+    )
     forward_inputs = forward_post_node.forward(forward_inputs)
     backward_input_grads = backward_post_node.backward(backward_input_grads)
 
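On the caller side the extra keyword is backward compatible: existing call sites run unchanged, and a scheduler that has an event in hand can thread it through. A call-shape sketch, not runnable standalone; the leading arguments partly mirror the parameter list visible above, and `forward_loss_fn_node` plus the scheduler variable names are assumptions:

```python
# Hypothetical scheduler step; only the trailing keyword is new.
prev_event = paddle.device.current_stream().record_event()
result = overlapped_forward_backward(
    forward_chunk,
    forward_inputs,
    forward_loss_fn_node,
    backward_chunk,
    backward_loss_fn_node,
    backward_input_grads,
    scaler,
    event_to_wait=prev_event,  # omit (or pass None) to keep the old behavior
)
```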