diff --git a/paddlenlp/transformers/deepseek_v2/modeling_pp.py b/paddlenlp/transformers/deepseek_v2/modeling_pp.py
index 8769c23899d1..ffe6c3d57201 100644
--- a/paddlenlp/transformers/deepseek_v2/modeling_pp.py
+++ b/paddlenlp/transformers/deepseek_v2/modeling_pp.py
@@ -1077,7 +1077,7 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
         # get dispatch backward event
         dispatch_backward_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id)

-        paddle.base.core.nvprof_nvtx_push("dispatch_backward_dw")
+        paddle.base.core.nvprof_nvtx_push("mlp_backward_dw")
         WeightGradStore.pop()
         assert WeightGradStore.funcs_queue.empty()
         paddle.base.core.nvprof_nvtx_pop()
@@ -1108,11 +1108,7 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p

         if pp_stream is not None:
             send_recv_stream = paddle.device.Stream(stream_base=pp_stream)
-
-            # combine_forward_event.custom_stream_wait( pp_stream)
-            # final_out_event.custom_stream_wait(pp_stream)
-
-            paddle.base.core.nvprof_nvtx_push("pp stream add")
+            paddle.base.core.nvprof_nvtx_push("pp_stream_add")

             with paddle.device.stream_guard(send_recv_stream):
                 combine_forward_event.current_stream_wait()
@@ -1127,20 +1123,27 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p

         dispatch_backward_event.calc_stream_wait(self.backward_node.moe_group.id)

-        paddle.base.core.nvprof_nvtx_push("attn_backward")
+        paddle.base.core.nvprof_nvtx_push("attn_backward_dx")
         assert WeightGradStore.funcs_queue.empty()
         WeightGradStore.enabled = True
         output_grad = self.backward_node.attn_backward(output_grad)
         event_to_wait = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id)
+        paddle.base.core.nvprof_nvtx_pop()

         if EventStore is not None:
             EventStore.set(event_to_wait)

+        if pp_stream is not None:
+            # TODO(liangshuhao): this wait seems unnecessary, but there will be
+            # convergence issue without this.
+            with paddle.device.stream_guard(send_recv_stream):
+                event_to_wait.current_stream_wait()
+
+        paddle.base.core.nvprof_nvtx_push("attn_backward_dw")
         WeightGradStore.enabled = False
         WeightGradStore.flush()
         WeightGradStore.pop()
         assert WeightGradStore.funcs_queue.empty()
-
         paddle.base.core.nvprof_nvtx_pop()

         # residual add
@@ -1218,6 +1221,8 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p

         paddle.base.core.nvprof_nvtx_pop()  # moe_attn
         paddle.base.core.nvprof_nvtx_push("dense_mlp_moe_dispatch")
+        if combine_bw_event_to_wait is not None:
+            combine_bw_event_to_wait.calc_stream_wait(self.forward_node.moe_group.id)
         output_grad = self.backward_node.mlp_node.backward(output_grad)
         inputs = self.forward_node.dispatch_forward(
             inputs, previous_event=attn_fw_event, async_finish=True, allocate_on_comm_stream=True