
Commit 1b1e63a

Make pp_stream wait on attn_backward_dx

1 parent 73f451c · commit 1b1e63a

File tree

1 file changed (+13, -8 lines)

paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 13 additions & 8 deletions
@@ -1077,7 +1077,7 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
         # get dispatch backward event
         dispatch_backward_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id)
 
-        paddle.base.core.nvprof_nvtx_push("dispatch_backward_dw")
+        paddle.base.core.nvprof_nvtx_push("mlp_backward_dw")
         WeightGradStore.pop()
         assert WeightGradStore.funcs_queue.empty()
         paddle.base.core.nvprof_nvtx_pop()
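
The first hunk only renames the NVTX range so the label matches the work it brackets: the push/pop pair wraps the deferred MLP weight-gradient work drained by WeightGradStore.pop(), not the dispatch backward. As a minimal sketch (the wrapper name is hypothetical, the push/pop calls are the ones used in this file), these calls pair naturally into a context manager so a range cannot be left open if the body raises:

    import contextlib

    import paddle

    @contextlib.contextmanager
    def nvtx_range(name):
        # Open a named NVTX range and guarantee it is closed even on error.
        paddle.base.core.nvprof_nvtx_push(name)
        try:
            yield
        finally:
            paddle.base.core.nvprof_nvtx_pop()

    # e.g. the renamed range above:
    # with nvtx_range("mlp_backward_dw"):
    #     WeightGradStore.pop()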
@@ -1108,11 +1108,7 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
 
         if pp_stream is not None:
             send_recv_stream = paddle.device.Stream(stream_base=pp_stream)
-
-            # combine_forward_event.custom_stream_wait( pp_stream)
-            # final_out_event.custom_stream_wait(pp_stream)
-
-            paddle.base.core.nvprof_nvtx_push("pp stream add")
+            paddle.base.core.nvprof_nvtx_push("pp_stream_add")
 
             with paddle.device.stream_guard(send_recv_stream):
                 combine_forward_event.current_stream_wait()
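
This hunk drops two stale commented-out waits and renames the range to the underscore style used elsewhere. The semantics it leaves in place: inside paddle.device.stream_guard(send_recv_stream), the "current stream" is the send/recv stream built from pp_stream, so combine_forward_event.current_stream_wait() makes that stream, not the default calc stream, block on the combine forward. A standalone sketch of the same record-then-wait pattern, assuming the paddle.device Stream/Event API (the exact signatures below are assumptions, not taken from this file):

    import paddle

    paddle.set_device("gpu")  # assumes a CUDA build of Paddle

    calc = paddle.device.current_stream()
    side = paddle.device.Stream()

    x = paddle.randn([1024, 1024])
    y = x @ x                  # matmul launched on the calc stream

    ev = paddle.device.Event()
    ev.record(calc)            # mark the point right after the matmul

    with paddle.device.stream_guard(side):
        side.wait_event(ev)    # side stream blocks until the matmul finishes
        z = y * 2              # now safe to consume y on the side stream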
@@ -1127,20 +1123,27 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
 
         dispatch_backward_event.calc_stream_wait(self.backward_node.moe_group.id)
 
-        paddle.base.core.nvprof_nvtx_push("attn_backward")
+        paddle.base.core.nvprof_nvtx_push("attn_backward_dx")
         assert WeightGradStore.funcs_queue.empty()
         WeightGradStore.enabled = True
         output_grad = self.backward_node.attn_backward(output_grad)
         event_to_wait = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id)
+        paddle.base.core.nvprof_nvtx_pop()
 
         if EventStore is not None:
             EventStore.set(event_to_wait)
 
+        if pp_stream is not None:
+            # TODO(liangshuhao): this wait seems unnecessary, but there will be
+            # convergence issue without this.
+            with paddle.device.stream_guard(send_recv_stream):
+                event_to_wait.current_stream_wait()
+
+        paddle.base.core.nvprof_nvtx_push("attn_backward_dw")
         WeightGradStore.enabled = False
         WeightGradStore.flush()
         WeightGradStore.pop()
         assert WeightGradStore.funcs_queue.empty()
-
         paddle.base.core.nvprof_nvtx_pop()
 
         # residual add
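
This hunk is the core of the commit. attn_backward now runs under two ranges: attn_backward_dx covers the immediate activation-gradient (dx) pass, executed with WeightGradStore.enabled = True so weight-gradient (dw) work is queued rather than run, and attn_backward_dw covers draining that queue via flush()/pop(). Between the two, the send/recv stream is made to wait on the event recorded after the dx pass; per the in-code TODO, the wait looks redundant but is needed for convergence. A hypothetical, minimal reimplementation of the store's contract (the real PaddleNLP internals will differ) to make the enabled/flush/pop sequence concrete:

    import queue

    class WeightGradStoreSketch:
        # Hypothetical stand-in for WeightGradStore: defer dw work while
        # "enabled", then flush the batch and pop it later so weight-grad
        # compute can overlap with communication.
        enabled = False
        cache = []
        funcs_queue = queue.Queue()

        @classmethod
        def put(cls, fn):
            if cls.enabled:
                cls.cache.append(fn)      # defer the weight-gradient kernel
            else:
                fn()                      # eager fallback

        @classmethod
        def flush(cls):
            cls.funcs_queue.put(cls.cache)
            cls.cache = []

        @classmethod
        def pop(cls):
            while not cls.funcs_queue.empty():
                for fn in cls.funcs_queue.get():
                    fn()                  # run the deferred dw work now

Under that contract, the sequence above reads: enable deferral, run dx, record the event, then flush and pop the dw batch inside the attn_backward_dw range.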
@@ -1218,6 +1221,8 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
         paddle.base.core.nvprof_nvtx_pop()  # moe_attn
 
         paddle.base.core.nvprof_nvtx_push("dense_mlp_moe_dispatch")
+        if combine_bw_event_to_wait is not None:
+            combine_bw_event_to_wait.calc_stream_wait(self.forward_node.moe_group.id)
         output_grad = self.backward_node.mlp_node.backward(output_grad)
         inputs = self.forward_node.dispatch_forward(
             inputs, previous_event=attn_fw_event, async_finish=True, allocate_on_comm_stream=True
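
The last hunk closes the same kind of race on the other side: before the dense-MLP backward, the calc stream now waits on combine_bw_event_to_wait whenever the caller passes one, so the backward cannot start while the combine backward communication is still in flight. The guard pattern as a small hedged helper (the helper name is hypothetical; calc_stream_wait is the same call the hunk uses):

    def backward_after_comm(event, group_id, backward_fn, grad):
        # Gate a backward pass on an optional communication event: wait on
        # the calc stream only when an event was actually handed in.
        if event is not None:
            event.calc_stream_wait(group_id)
        return backward_fn(grad)

    # usage mirroring the hunk:
    # output_grad = backward_after_comm(
    #     combine_bw_event_to_wait,
    #     self.forward_node.moe_group.id,
    #     self.backward_node.mlp_node.backward,
    #     output_grad,
    # )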
