@@ -1077,7 +1077,7 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
         # get dispatch backward event
         dispatch_backward_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id)

-        paddle.base.core.nvprof_nvtx_push("dispatch_backward_dw")
+        paddle.base.core.nvprof_nvtx_push("mlp_backward_dw")
         WeightGradStore.pop()
         assert WeightGradStore.funcs_queue.empty()
         paddle.base.core.nvprof_nvtx_pop()
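The `WeightGradStore.pop()` / `funcs_queue.empty()` calls above belong to the split-backward scheme in which weight gradients (dW) are deferred while activation gradients (dX) are computed. As a rough illustration only (the real class lives elsewhere in this repo; the names mirror the calls visible in the diff), the store behaves like a queue of deferred closures:

```python
# Minimal sketch of the deferred weight-gradient store implied by the calls
# above (enabled / put / flush / pop / funcs_queue). Not the repo's actual
# implementation.
import queue


class WeightGradStoreSketch:
    enabled = False              # when True, dW closures are queued, not run
    cache = []                   # closures collected since the last flush()
    funcs_queue = queue.Queue()  # flushed batches waiting for pop()

    @classmethod
    def put(cls, func):
        if cls.enabled:
            cls.cache.append(func)  # defer the weight-gradient kernel
        else:
            func()                  # run immediately when splitting is off

    @classmethod
    def flush(cls):
        cls.funcs_queue.put(cls.cache)
        cls.cache = []

    @classmethod
    def pop(cls):
        # Drain one flushed batch: this is the "dW" half of the backward.
        if not cls.funcs_queue.empty():
            for func in cls.funcs_queue.get():
                func()
```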
@@ -1108,11 +1108,7 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p

         if pp_stream is not None:
             send_recv_stream = paddle.device.Stream(stream_base=pp_stream)
-
-            # combine_forward_event.custom_stream_wait( pp_stream)
-            # final_out_event.custom_stream_wait(pp_stream)
-
-            paddle.base.core.nvprof_nvtx_push("pp stream add")
+            paddle.base.core.nvprof_nvtx_push("pp_stream_add")

             with paddle.device.stream_guard(send_recv_stream):
                 combine_forward_event.current_stream_wait()
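The pattern in this hunk: wrap the raw pipeline-parallel stream in a `paddle.device.Stream`, then use `stream_guard` so that the event wait and the work that follows are issued on that stream instead of the compute stream. A condensed sketch of the idea (the event API follows the `current_stream_wait()` call in the diff; `pp_branch_fn` is a hypothetical stand-in for the work done under the guard):

```python
# Sketch only: redirect work onto the pipeline stream so it overlaps with
# compute. `combine_forward_event` is assumed to be a deep_ep event exposing
# current_stream_wait(); `pp_branch_fn` stands in for the residual-add / send
# work issued under the guard.
import paddle


def run_on_pp_stream(pp_stream, combine_forward_event, pp_branch_fn):
    send_recv_stream = paddle.device.Stream(stream_base=pp_stream)

    paddle.base.core.nvprof_nvtx_push("pp_stream_add")
    with paddle.device.stream_guard(send_recv_stream):
        # Inside the guard the "current" stream is send_recv_stream, so this
        # makes the pp stream wait until the combine-forward output is ready.
        combine_forward_event.current_stream_wait()
        pp_branch_fn()
    paddle.base.core.nvprof_nvtx_pop()
```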
@@ -1127,20 +1123,27 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p

         dispatch_backward_event.calc_stream_wait(self.backward_node.moe_group.id)

-        paddle.base.core.nvprof_nvtx_push("attn_backward")
+        paddle.base.core.nvprof_nvtx_push("attn_backward_dx")
         assert WeightGradStore.funcs_queue.empty()
         WeightGradStore.enabled = True
         output_grad = self.backward_node.attn_backward(output_grad)
         event_to_wait = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id)
+        paddle.base.core.nvprof_nvtx_pop()

         if EventStore is not None:
             EventStore.set(event_to_wait)

+        if pp_stream is not None:
+            # TODO(liangshuhao): this wait seems unnecessary, but there will be
+            # convergence issue without this.
+            with paddle.device.stream_guard(send_recv_stream):
+                event_to_wait.current_stream_wait()
+
+        paddle.base.core.nvprof_nvtx_push("attn_backward_dw")
         WeightGradStore.enabled = False
         WeightGradStore.flush()
         WeightGradStore.pop()
         assert WeightGradStore.funcs_queue.empty()
-
         paddle.base.core.nvprof_nvtx_pop()

         # residual add
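The renamed ranges (`attn_backward_dx`, `attn_backward_dw`) only show up correctly in Nsight if every push is matched by a pop. A small helper like the following (not part of this PR, purely illustrative) keeps the pairs balanced:

```python
# Hypothetical helper, not in this PR: guarantees nvprof_nvtx_push/pop pairs
# stay balanced even if the guarded code raises.
from contextlib import contextmanager

import paddle


@contextmanager
def nvtx_range(name):
    paddle.base.core.nvprof_nvtx_push(name)
    try:
        yield
    finally:
        paddle.base.core.nvprof_nvtx_pop()


# Example usage for the dW phase above:
# with nvtx_range("attn_backward_dw"):
#     WeightGradStore.flush()
#     WeightGradStore.pop()
```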
@@ -1218,6 +1221,8 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
         paddle.base.core.nvprof_nvtx_pop()  # moe_attn

         paddle.base.core.nvprof_nvtx_push("dense_mlp_moe_dispatch")
+        if combine_bw_event_to_wait is not None:
+            combine_bw_event_to_wait.calc_stream_wait(self.forward_node.moe_group.id)
         output_grad = self.backward_node.mlp_node.backward(output_grad)
         inputs = self.forward_node.dispatch_forward(
             inputs, previous_event=attn_fw_event, async_finish=True, allocate_on_comm_stream=True
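The added guard makes the MoE group's calculation stream wait on the combine-backward event passed in by the caller before the MLP backward launches. Roughly (event and node APIs taken from the diff; this is a sketch, not the scheduler's code path):

```python
# Sketch of the gating introduced in this hunk: only if the caller handed in
# a combine-backward event does the calc stream wait on it before the MLP
# backward consumes the corresponding gradient.
def mlp_backward_gated(backward_node, forward_node, output_grad,
                       combine_bw_event_to_wait=None):
    if combine_bw_event_to_wait is not None:
        # deep_ep-style event: block this MoE group's calc stream until the
        # combine-backward communication has finished.
        combine_bw_event_to_wait.calc_stream_wait(forward_node.moe_group.id)
    return backward_node.mlp_node.backward(output_grad)
```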