Skip to content

Commit 9b1c166

Browse files
authored
Fix post_process_backward calc stream wait (#11063)
1 parent 349c92f commit 9b1c166

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -766,6 +766,9 @@ def post_process_forward(self, inputs, with_residual=True):
766766
return inputs
767767

768768
def post_process_backward(self, output_grad, event_to_wait=None):
769+
if event_to_wait is not None:
770+
event_to_wait.calc_stream_wait(self.moe_group.id)
771+
769772
grad = self.post_process_node.backward(output_grad)
770773

771774
if self.using_post_norm_recompute:

0 commit comments

Comments
 (0)