Skip to content

Commit de9c524

Browse files
authored
Fix slow convergence issue when enable overlap_p2p_comm (#10992)
1 parent 76cb4b7 commit de9c524

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1086,7 +1086,6 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
10861086
paddle.base.core.nvprof_nvtx_push("mlp_forward")
10871087
inputs = self.forward_node.mlp_forward(inputs)
10881088
paddle.base.core.nvprof_nvtx_pop()
1089-
mlp_fwd_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)
10901089

10911090
if pp_stream is not None:
10921091
paddle.base.core.nvprof_nvtx_push("post_process_forward")
@@ -1098,7 +1097,7 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
10981097

10991098
paddle.base.core.nvprof_nvtx_push("combine_forward")
11001099
inputs = self.forward_node.combine_forward(
1101-
inputs, previous_event=mlp_fwd_event, async_finish=True, allocate_on_comm_stream=True
1100+
inputs, previous_event=final_out_event, async_finish=True, allocate_on_comm_stream=True
11021101
)
11031102
paddle.base.core.nvprof_nvtx_pop()
11041103

@@ -1109,9 +1108,6 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
11091108
if pp_stream is not None:
11101109
send_recv_stream = paddle.device.Stream(stream_base=pp_stream)
11111110

1112-
# combine_forward_event.custom_stream_wait( pp_stream)
1113-
# final_out_event.custom_stream_wait(pp_stream)
1114-
11151111
paddle.base.core.nvprof_nvtx_push("pp stream add")
11161112

11171113
with paddle.device.stream_guard(send_recv_stream):
@@ -1228,6 +1224,8 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
12281224
paddle.base.core.nvprof_nvtx_pop() # moe_attn
12291225

12301226
paddle.base.core.nvprof_nvtx_push("dense_mlp_moe_dispatch")
1227+
if combine_bw_event_to_wait is not None:
1228+
combine_bw_event_to_wait.calc_stream_wait(self.forward_node.moe_group.id)
12311229
output_grad = self.backward_node.mlp_node.backward(output_grad)
12321230
inputs = self.forward_node.dispatch_forward(
12331231
inputs, previous_event=attn_fw_event, async_finish=True, allocate_on_comm_stream=True

0 commit comments

Comments
 (0)