@@ -569,13 +569,18 @@ def mlp_forward(self, inputs):
         ret = (inputs_embeds_mtp, *ret) if self.send_mtp_embed else ret
         return ret

-    def combine_forward(self, inputs, async_finish=False):
+    def combine_forward(self, inputs, async_finish=False, previous_event=None, allocate_on_comm_stream=False):
         if self.send_mtp_embed:
             (inputs_embeds_mtp, hidden_states, residual, l_aux, hidden_states_out) = inputs
         else:
             (hidden_states, residual, l_aux, hidden_states_out) = inputs

-        output_combine = self.fp8_fusion_moe_node.combine_node.forward(hidden_states_out, async_finish=async_finish)
+        output_combine = self.fp8_fusion_moe_node.combine_node.forward(
+            hidden_states_out,
+            async_finish=async_finish,
+            previous_event=previous_event,
+            allocate_on_comm_stream=allocate_on_comm_stream and previous_event is not None,
+        )


         ret = (hidden_states, residual, l_aux, output_combine)
@@ -652,7 +657,7 @@ def mlp_backward(self, output_grad):
         ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret
         return ret

-    def dispatch_backward(self, output_grad, async_finish=False):
+    def dispatch_backward(self, output_grad, async_finish=False, previous_event=None, allocate_on_comm_stream=False):
         if self.send_mtp_embed:
             (
                 inputs_embeds_mtp_grad,
@@ -666,7 +671,11 @@ def dispatch_backward(self, output_grad, async_finish=False):
             hidden_states_grad, residual_grad, l_aux_grad, hs_dispatched_grad, dispatched_probs_grad = output_grad

         hs_grad, token_probs_grad = self.fp8_fusion_moe_node.dispatch_node.backward(
-            hs_dispatched_grad, dispatched_probs_grad, async_finish=async_finish
+            hs_dispatched_grad,
+            dispatched_probs_grad,
+            async_finish=async_finish,
+            previous_event=previous_event,
+            allocate_on_comm_stream=allocate_on_comm_stream and previous_event is not None,
         )

         ret = (hidden_states_grad, residual_grad, l_aux_grad, hs_grad, token_probs_grad)
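Note: `combine_forward` and `dispatch_backward` now thread the same two arguments down to the fused-MoE communication nodes. The expression `allocate_on_comm_stream and previous_event is not None` guards against an unsafe combination: allocating output buffers on the communication stream only makes sense when the comm kernels are ordered after a recorded event, so with no event to wait on the flag is silently dropped. A minimal sketch of that guard, using a hypothetical helper name that is not part of this patch:

def _comm_kwargs(async_finish=False, previous_event=None, allocate_on_comm_stream=False):
    # Hypothetical helper: normalizes the kwargs the call sites above pass to
    # the dispatch/combine nodes. allocate_on_comm_stream is forced off when
    # there is no previous_event to order the comm stream against.
    return {
        "async_finish": async_finish,
        "previous_event": previous_event,
        "allocate_on_comm_stream": allocate_on_comm_stream and previous_event is not None,
    }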
@@ -755,6 +764,8 @@ def forward_backward(self, inputs, output_grad, event_to_wait=None):
         output_grad = self.backward_node.mlp_backward(output_grad)
         paddle.base.core.nvprof_nvtx_pop()

+        output_grad_event = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id)
+
         paddle.base.core.nvprof_nvtx_push("dispatch_forward")
         inputs = self.forward_node.dispatch_forward(
             inputs, previous_event=attn_compute_event, async_finish=True, allocate_on_comm_stream=True
@@ -763,7 +774,9 @@ def forward_backward(self, inputs, output_grad, event_to_wait=None):
         dispatch_forward_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id)

         paddle.base.core.nvprof_nvtx_push("dispatch_backward")
-        output_grad = self.backward_node.dispatch_backward(output_grad, async_finish=True)
+        output_grad = self.backward_node.dispatch_backward(
+            output_grad, async_finish=True, previous_event=output_grad_event, allocate_on_comm_stream=True
+        )
         paddle.base.core.nvprof_nvtx_pop()
         # get dispatch backward event
         dispatch_backward_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id)
@@ -777,8 +790,12 @@ def forward_backward(self, inputs, output_grad, event_to_wait=None):
         inputs = self.forward_node.mlp_forward(inputs)
         paddle.base.core.nvprof_nvtx_pop()

+        inputs_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)
+
         paddle.base.core.nvprof_nvtx_push("combine_forward")
-        inputs = self.forward_node.combine_forward(inputs, async_finish=True)
+        inputs = self.forward_node.combine_forward(
+            inputs, async_finish=True, previous_event=inputs_event, allocate_on_comm_stream=True
+        )
         paddle.base.core.nvprof_nvtx_pop()
         combine_forward_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id)
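The scheduling idea behind the `get_event_from_calc_stream` / `get_event_from_comm_stream` pairs above is the standard record-then-wait pattern: record an event on the compute stream, make the communication stream wait on it GPU-side, then record on the communication stream so downstream compute can wait in turn. A minimal sketch of that pattern, assuming Paddle's generic CUDA stream/event API (`paddle.device.cuda`) rather than the `deep_ep` wrappers used in this patch:

import paddle

comm_stream = paddle.device.cuda.Stream()

# Compute runs on the default (calc) stream and records an event when done,
# analogous to deep_ep.get_event_from_calc_stream above.
calc_done = paddle.device.cuda.current_stream().record_event()

# The comm stream waits on that event GPU-side, so communication cannot start
# before its inputs are ready, while the CPU keeps launching work
# (analogous to async_finish=True).
comm_stream.wait_event(calc_done)
with paddle.device.cuda.stream_guard(comm_stream):
    pass  # communication kernels would be launched here

# Record on the comm stream and make consumers on the calc stream wait,
# analogous to deep_ep.get_event_from_comm_stream.
comm_done = comm_stream.record_event()
paddle.device.cuda.current_stream().wait_event(comm_done)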