@@ -1086,7 +1086,6 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
1086
1086
paddle .base .core .nvprof_nvtx_push ("mlp_forward" )
1087
1087
inputs = self .forward_node .mlp_forward (inputs )
1088
1088
paddle .base .core .nvprof_nvtx_pop ()
1089
- mlp_fwd_event = deep_ep .get_event_from_calc_stream (self .forward_node .moe_group .id )
1090
1089
1091
1090
if pp_stream is not None :
1092
1091
paddle .base .core .nvprof_nvtx_push ("post_process_forward" )
@@ -1098,7 +1097,7 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
1098
1097
1099
1098
paddle .base .core .nvprof_nvtx_push ("combine_forward" )
1100
1099
inputs = self .forward_node .combine_forward (
1101
- inputs , previous_event = mlp_fwd_event , async_finish = True , allocate_on_comm_stream = True
1100
+ inputs , previous_event = final_out_event , async_finish = True , allocate_on_comm_stream = True
1102
1101
)
1103
1102
paddle .base .core .nvprof_nvtx_pop ()
1104
1103
@@ -1109,9 +1108,6 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
1109
1108
if pp_stream is not None :
1110
1109
send_recv_stream = paddle .device .Stream (stream_base = pp_stream )
1111
1110
1112
- # combine_forward_event.custom_stream_wait( pp_stream)
1113
- # final_out_event.custom_stream_wait(pp_stream)
1114
-
1115
1111
paddle .base .core .nvprof_nvtx_push ("pp stream add" )
1116
1112
1117
1113
with paddle .device .stream_guard (send_recv_stream ):
@@ -1228,6 +1224,8 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
1228
1224
paddle .base .core .nvprof_nvtx_pop () # moe_attn
1229
1225
1230
1226
paddle .base .core .nvprof_nvtx_push ("dense_mlp_moe_dispatch" )
1227
+ if combine_bw_event_to_wait is not None :
1228
+ combine_bw_event_to_wait .calc_stream_wait (self .forward_node .moe_group .id )
1231
1229
output_grad = self .backward_node .mlp_node .backward (output_grad )
1232
1230
inputs = self .forward_node .dispatch_forward (
1233
1231
inputs , previous_event = attn_fw_event , async_finish = True , allocate_on_comm_stream = True
0 commit comments