21 changes: 13 additions & 8 deletions paddlenlp/transformers/deepseek_v2/modeling_pp.py
@@ -1077,7 +1077,7 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
# get dispatch backward event
dispatch_backward_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id)

paddle.base.core.nvprof_nvtx_push("dispatch_backward_dw")
paddle.base.core.nvprof_nvtx_push("mlp_backward_dw")
WeightGradStore.pop()
assert WeightGradStore.funcs_queue.empty()
paddle.base.core.nvprof_nvtx_pop()
@@ -1108,11 +1108,7 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p

if pp_stream is not None:
send_recv_stream = paddle.device.Stream(stream_base=pp_stream)

# combine_forward_event.custom_stream_wait( pp_stream)
# final_out_event.custom_stream_wait(pp_stream)

paddle.base.core.nvprof_nvtx_push("pp stream add")
paddle.base.core.nvprof_nvtx_push("pp_stream_add")

with paddle.device.stream_guard(send_recv_stream):
combine_forward_event.current_stream_wait()
@@ -1127,20 +1123,27 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p

dispatch_backward_event.calc_stream_wait(self.backward_node.moe_group.id)

paddle.base.core.nvprof_nvtx_push("attn_backward")
paddle.base.core.nvprof_nvtx_push("attn_backward_dx")
assert WeightGradStore.funcs_queue.empty()
WeightGradStore.enabled = True
output_grad = self.backward_node.attn_backward(output_grad)
event_to_wait = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id)
paddle.base.core.nvprof_nvtx_pop()

if EventStore is not None:
EventStore.set(event_to_wait)

if pp_stream is not None:
# TODO(liangshuhao): this wait seems unnecessary, but there will be
# convergence issue without this.
with paddle.device.stream_guard(send_recv_stream):
event_to_wait.current_stream_wait()

paddle.base.core.nvprof_nvtx_push("attn_backward_dw")
WeightGradStore.enabled = False
WeightGradStore.flush()
WeightGradStore.pop()
assert WeightGradStore.funcs_queue.empty()

paddle.base.core.nvprof_nvtx_pop()

# residual add
@@ -1218,6 +1221,8 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
paddle.base.core.nvprof_nvtx_pop() # moe_attn

paddle.base.core.nvprof_nvtx_push("dense_mlp_moe_dispatch")
if combine_bw_event_to_wait is not None:
combine_bw_event_to_wait.calc_stream_wait(self.forward_node.moe_group.id)
output_grad = self.backward_node.mlp_node.backward(output_grad)
inputs = self.forward_node.dispatch_forward(
inputs, previous_event=attn_fw_event, async_finish=True, allocate_on_comm_stream=True
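
For readers unfamiliar with the pattern this diff applies: the attention backward pass is split into a dx phase (activation gradients, run immediately under the `attn_backward_dx` NVTX range) and a dw phase (weight gradients, deferred through `WeightGradStore` and flushed under `attn_backward_dw` so they can overlap with communication). The sketch below illustrates that deferral pattern in isolation; the `WeightGradStore` and `attn_backward` here are simplified stand-ins for illustration only, not the PaddleNLP implementations.

```python
# Minimal sketch of the dx/dw split-backward pattern, assuming a simplified
# WeightGradStore. This is NOT the PaddleNLP class, just an illustration.
import queue


class WeightGradStore:
    enabled = False                 # when True, dw work is deferred instead of run
    cache = []                      # dw callables collected while enabled
    funcs_queue = queue.Queue()     # flushed batches waiting to be executed

    @classmethod
    def put(cls, func):
        if cls.enabled:
            cls.cache.append(func)  # defer the weight-gradient computation
        else:
            func()                  # eager fallback when deferral is disabled

    @classmethod
    def flush(cls):
        cls.funcs_queue.put(cls.cache)
        cls.cache = []

    @classmethod
    def pop(cls):
        while not cls.funcs_queue.empty():
            for func in cls.funcs_queue.get():
                func()              # run the deferred dw work now


def attn_backward(grad):
    # dx part: propagate the activation gradient immediately.
    dx = grad * 2.0
    # dw part: queue the weight-gradient update so it can run later,
    # e.g. overlapped with dispatch/combine communication.
    WeightGradStore.put(lambda: print("computing attention weight grads"))
    return dx


# Usage mirroring the structure in the diff:
assert WeightGradStore.funcs_queue.empty()
WeightGradStore.enabled = True          # "attn_backward_dx" phase
out_grad = attn_backward(1.0)

WeightGradStore.enabled = False         # "attn_backward_dw" phase
WeightGradStore.flush()
WeightGradStore.pop()
assert WeightGradStore.funcs_queue.empty()
```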