@@ -396,9 +396,9 @@ def __init__(self, forward_nodes, backward_nodes, use_fuion=True):
         for f, b in zip(forward_nodes, backward_nodes):
             self.nodes.append(schedule_node_class(f, b, f"OverlapedNode_{len(self.nodes)}"))

-    def forward_backward(self, inputs, output_grad):
+    def forward_backward(self, inputs, output_grad, event_to_wait=None):
         for n in self.nodes:
-            inputs, output_grad = n.forward_backward(inputs, output_grad)
+            inputs, output_grad, event_to_wait = n.forward_backward(inputs, output_grad, event_to_wait)
         return inputs, output_grad

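The chunk-level loop above now threads a CUDA event through every node's `forward_backward`. Below is a minimal sketch of that pattern; the class, the placeholder arithmetic, and the use of `wait_event` are illustrative assumptions, not the repo's actual node classes. Each node makes its stream wait on the event recorded by its predecessor, does its work, and records a fresh event for its successor, so cross-node ordering rides on stream events instead of full device synchronizations.

import paddle

class ChainedNode:
    # Hypothetical stand-in for an overlapped schedule node.
    def forward_backward(self, inputs, output_grad, event_to_wait=None):
        stream = paddle.device.current_stream()
        if event_to_wait is not None:
            # Resume only once the previous node's recorded work has finished.
            stream.wait_event(event_to_wait)
        inputs, output_grad = inputs * 2, output_grad * 2  # placeholder kernels
        # Hand a fresh event to the next node in the chain.
        return inputs, output_grad, stream.record_event()
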
@@ -409,7 +409,7 @@ def __init__(self, forward_node, backward_node, name=""):
         self.backward_node = backward_node
         self.name = name

-    def forward_backward(self, inputs, output_grad):
+    def forward_backward(self, inputs, output_grad, event_to_wait=None):
         paddle.base.core.nvprof_nvtx_push("forward_backward")
         output_grad = self.backward_node.post_process_node.backward(output_grad)
@@ -594,7 +594,7 @@ def post_process_forward(self, inputs):

         return inputs

-    def post_process_backward(self, output_grad):
+    def post_process_backward(self, output_grad, event_to_wait=None):
         if self.send_mtp_embed:
             (
                 inputs_embeds_mtp_grad,
@@ -610,43 +610,51 @@ def post_process_backward(self, output_grad):
                 l_aux_grad,
                 final_hidden_states_grad,
             ) = self.post_process_node.backward(output_grad)
-        output_combine_grad = self.fp8_fusion_moe_node.combine_quant_node.backward(final_hidden_states_grad)
+        output_combine_grad, quant_event = self.fp8_fusion_moe_node.combine_quant_node.backward(
+            final_hidden_states_grad, event_to_wait
+        )

         if self.send_mtp_embed:
             return (
                 inputs_embeds_mtp_grad,
                 hidden_states_grad,
                 residual_grad,
                 l_aux_grad,
                 output_combine_grad,
+                quant_event,
             )
         else:
             return (
                 hidden_states_grad,
                 residual_grad,
                 l_aux_grad,
                 output_combine_grad,
+                quant_event,
             )

-    def combine_backward(self, output_grad, async_finish=False):
+    def combine_backward(self, output_grad, async_finish=False, allocate_on_comm_stream=False):
         if self.send_mtp_embed:
             (
                 inputs_embeds_mtp_grad,
                 hidden_states_grad,
                 residual_grad,
                 l_aux_grad,
                 output_combine_grad,
+                quant_event,
             ) = output_grad
         else:
             (
                 hidden_states_grad,
                 residual_grad,
                 l_aux_grad,
                 output_combine_grad,
+                quant_event,
             ) = output_grad

         hidden_states_out_grad = self.fp8_fusion_moe_node.combine_node.backward(
             output_combine_grad,
             async_finish=async_finish,
+            previous_event=quant_event,
+            allocate_on_comm_stream=allocate_on_comm_stream,
         )

         if self.send_mtp_embed:
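The new `previous_event` / `allocate_on_comm_stream` arguments follow the usual DeepEP launch pattern: the combine collective is queued on the communication stream, told to wait on the event recorded by the producing quant kernel, and allowed to allocate its buffers on that stream. A rough sketch under those assumptions; `combine_collective` is a hypothetical callable standing in for the real combine op, and the identity assignment is a placeholder for the quant kernel.

import paddle
import deep_ep  # DeepEP bindings, as used elsewhere in this file

def quant_then_combine(tensor, combine_collective, moe_group):
    quant = tensor  # placeholder for the real FP8 quant kernel
    quant_event = paddle.device.current_stream().record_event()
    # The comm stream waits on quant_event before reading `quant`,
    # so no blocking device-wide synchronization is needed.
    out = combine_collective(
        quant,
        async_finish=True,
        previous_event=quant_event,
        allocate_on_comm_stream=True,
    )
    return out, deep_ep.get_event_from_comm_stream(moe_group.id)
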
@@ -811,35 +819,36 @@ def __init__(self, forward_node, backward_node, name=""):
         self.backward_node = backward_node
         self.name = name

-    def forward_backward(self, inputs, output_grad):
+    def forward_backward(self, inputs, output_grad, event_to_wait=None):
         paddle.base.core.nvprof_nvtx_push("forward_backward")

         paddle.base.core.nvprof_nvtx_push("post_process_backward")
-        output_grad = self.backward_node.post_process_backward(output_grad)
+        output_grad = self.backward_node.post_process_backward(output_grad, event_to_wait)
         paddle.base.core.nvprof_nvtx_pop()

         paddle.base.core.nvprof_nvtx_push("combine_backward")
-        output_grad = self.backward_node.combine_backward(output_grad, async_finish=True)
+        output_grad = self.backward_node.combine_backward(output_grad, async_finish=True, allocate_on_comm_stream=True)
         # get combine event
-        combine_backward_event = deep_ep.get_event_from_comm_stream( self.backward_node.moe_group.id)
+        combine_backward_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id)
         paddle.base.core.nvprof_nvtx_pop()

         paddle.base.core.nvprof_nvtx_push("attn_forward")
         inputs = self.forward_node.attn_forward(inputs)
         paddle.base.core.nvprof_nvtx_pop()
-        attn_compute_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)

+        attn_compute_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)

-        combine_backward_event.calc_stream_wait( self.backward_node.moe_group.id)
+        combine_backward_event.calc_stream_wait(self.backward_node.moe_group.id)
         paddle.base.core.nvprof_nvtx_push("mlp_backward_dx")
         output_grad = self.backward_node.mlp_backward(output_grad)
         paddle.base.core.nvprof_nvtx_pop()
+
         paddle.base.core.nvprof_nvtx_push("dispatch_forward")
         inputs = self.forward_node.dispatch_forward(
             inputs, previous_event=attn_compute_event, async_finish=True, allocate_on_comm_stream=True
         )
         paddle.base.core.nvprof_nvtx_pop()
-        dispatch_forward_event = deep_ep.get_event_from_comm_stream( self.forward_node.moe_group.id)
+        dispatch_forward_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id)

         paddle.base.core.nvprof_nvtx_push("dispatch_backward")
         output_grad = self.backward_node.dispatch_backward(output_grad, async_finish=True)
@@ -851,28 +860,28 @@ def forward_backward(self, inputs, output_grad):
         self.backward_node.mlp_backward_dw()
         paddle.base.core.nvprof_nvtx_pop()

-        dispatch_forward_event.calc_stream_wait( self.forward_node.moe_group.id)
+        dispatch_forward_event.calc_stream_wait(self.forward_node.moe_group.id)
         paddle.base.core.nvprof_nvtx_push("mlp_forward")
         inputs = self.forward_node.mlp_forward(inputs)
         paddle.base.core.nvprof_nvtx_pop()

         paddle.base.core.nvprof_nvtx_push("combine_forward")
         inputs = self.forward_node.combine_forward(inputs, async_finish=True)
         paddle.base.core.nvprof_nvtx_pop()
-        combine_forward_event = deep_ep.get_event_from_comm_stream( self.forward_node.moe_group.id)
-
+        combine_forward_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id)

         dispatch_backward_event.calc_stream_wait(self.backward_node.moe_group.id)
         paddle.base.core.nvprof_nvtx_push("attn_backward")
         output_grad = self.backward_node.attn_backward(output_grad)
+        event_to_wait = paddle.device.current_stream().record_event()
         paddle.base.core.nvprof_nvtx_pop()

         combine_forward_event.calc_stream_wait(self.forward_node.moe_group.id)
         paddle.base.core.nvprof_nvtx_push("post_process_forward")
         inputs = self.forward_node.post_process_forward(inputs)
         paddle.base.core.nvprof_nvtx_pop()
         paddle.base.core.nvprof_nvtx_pop()
-        return inputs, output_grad
+        return inputs, output_grad, event_to_wait


 def build_overlapped_nodes(forward_chunk, backward_chunk):
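The scheduling body above repeats one choreography: launch a communication op asynchronously, capture its event from the comm stream, run independent compute on the calc stream, then make the calc stream wait on that event before consuming the communicated tensors. Here is a condensed sketch of the pattern using the same `deep_ep` event helpers the file already relies on; `comm_launch` and `compute_fn` are hypothetical callables.

import deep_ep  # DeepEP bindings, as used elsewhere in this file

def overlap_comm_with_compute(comm_launch, compute_fn, moe_group):
    comm_out = comm_launch(async_finish=True)  # queued on the comm stream
    comm_event = deep_ep.get_event_from_comm_stream(moe_group.id)
    compute_out = compute_fn()  # independent compute runs on the calc stream meanwhile
    comm_event.calc_stream_wait(moe_group.id)  # calc stream waits for the collective
    return comm_out, compute_out
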
@@ -1579,6 +1588,7 @@ def overlapped_forward_backward(
     backward_loss_fn_node,
     backward_input_grads,
     scaler,
+    event_to_wait=None,
 ):
     if backward_loss_fn_node is not None:
         if scaler:
@@ -1595,7 +1605,9 @@ def overlapped_forward_backward(
     ) = build_overlapped_nodes(forward_chunk, backward_chunk)
     forward_inputs = forward_pre_node.forward(forward_inputs)
     backward_input_grads = backward_pre_node.backward(backward_input_grads)
-    forward_inputs, backward_input_grads = overlap_node.forward_backward(forward_inputs, backward_input_grads)
+    forward_inputs, backward_input_grads = overlap_node.forward_backward(
+        forward_inputs, backward_input_grads, event_to_wait
+    )
     forward_inputs = forward_post_node.forward(forward_inputs)
     backward_input_grads = backward_post_node.backward(backward_input_grads)