Skip to content

Commit 5ccd5a8

Browse files
authored
fix dual pp event wait bug (#10829)
1 parent 8a5ea21 commit 5ccd5a8

File tree

1 file changed

+16
-5
lines changed

1 file changed

+16
-5
lines changed

paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 16 additions & 5 deletions

Original file line number | Diff line number | Diff line change
@@ -820,13 +820,17 @@ def forward_backward(self, inputs, output_grad):
820820

821821
paddle.base.core.nvprof_nvtx_push("combine_backward")
822822
output_grad = self.backward_node.combine_backward(output_grad, async_finish=True)
823+
# get combine event
824+
combine_backward_event = deep_ep.get_event_from_comm_stream( self.backward_node.moe_group.id)
823825
paddle.base.core.nvprof_nvtx_pop()
826+
824827
paddle.base.core.nvprof_nvtx_push("attn_forward")
825828
inputs = self.forward_node.attn_forward(inputs)
826829
paddle.base.core.nvprof_nvtx_pop()
827-
828-
calc_stream_wait(self.backward_node.moe_group.id)
829830
attn_compute_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)
831+
832+
833+
combine_backward_event.calc_stream_wait( self.backward_node.moe_group.id )
830834
paddle.base.core.nvprof_nvtx_push("mlp_backward_dx")
831835
output_grad = self.backward_node.mlp_backward(output_grad)
832836
paddle.base.core.nvprof_nvtx_pop()
@@ -835,28 +839,35 @@ def forward_backward(self, inputs, output_grad):
835839
inputs, previous_event=attn_compute_event, async_finish=True, allocate_on_comm_stream=True
836840
)
837841
paddle.base.core.nvprof_nvtx_pop()
842+
dispatch_forward_event = deep_ep.get_event_from_comm_stream( self.forward_node.moe_group.id )
843+
838844
paddle.base.core.nvprof_nvtx_push("dispatch_backward")
839845
output_grad = self.backward_node.dispatch_backward(output_grad, async_finish=True)
840846
paddle.base.core.nvprof_nvtx_pop()
847+
# get dispatch backward event
848+
dispatch_backward_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id)
841849

842850
paddle.base.core.nvprof_nvtx_push("dispatch_backward_dw")
843851
self.backward_node.mlp_backward_dw()
844852
paddle.base.core.nvprof_nvtx_pop()
845853

846-
calc_stream_wait(self.forward_node.moe_group.id)
854+
dispatch_forward_event.calc_stream_wait( self.forward_node.moe_group.id)
847855
paddle.base.core.nvprof_nvtx_push("mlp_forward")
848856
inputs = self.forward_node.mlp_forward(inputs)
849857
paddle.base.core.nvprof_nvtx_pop()
850858

851-
calc_stream_wait(self.backward_node.moe_group.id)
852859
paddle.base.core.nvprof_nvtx_push("combine_forward")
853860
inputs = self.forward_node.combine_forward(inputs, async_finish=True)
854861
paddle.base.core.nvprof_nvtx_pop()
862+
combine_forward_event = deep_ep.get_event_from_comm_stream( self.forward_node.moe_group.id)
863+
864+
865+
dispatch_backward_event.calc_stream_wait(self.backward_node.moe_group.id)
855866
paddle.base.core.nvprof_nvtx_push("attn_backward")
856867
output_grad = self.backward_node.attn_backward(output_grad)
857868
paddle.base.core.nvprof_nvtx_pop()
858869

859-
calc_stream_wait(self.forward_node.moe_group.id)
870+
combine_forward_event.calc_stream_wait(self.forward_node.moe_group.id)
860871
paddle.base.core.nvprof_nvtx_push("post_process_forward")
861872
inputs = self.forward_node.post_process_forward(inputs)
862873
paddle.base.core.nvprof_nvtx_pop()

0 commit comments

Comments (0)