Commit b8faf66

support send recv overlap pre code (#10872)
1 parent b099c88 commit b8faf66

2 files changed: 11 additions (+), 8 deletions (-)

paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 4 additions & 3 deletions
@@ -801,7 +801,7 @@ def forward_backward(self, inputs, output_grad, event_to_wait=None):
         dispatch_backward_event.calc_stream_wait(self.backward_node.moe_group.id)
         paddle.base.core.nvprof_nvtx_push("attn_backward")
         output_grad = self.backward_node.attn_backward(output_grad)
-        event_to_wait = paddle.device.current_stream().record_event()
+        event_to_wait = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id)
         paddle.base.core.nvprof_nvtx_pop()

         combine_forward_event.calc_stream_wait(self.forward_node.moe_group.id)
@@ -1519,7 +1519,8 @@ def overlapped_forward_backward(
         backward_loss_fn_node,
         backward_input_grads,
         scaler,
-        event_to_wait=None,
+        combine_bw_event_to_wait=None,
+        pp_stream=None,
     ):
         if backward_loss_fn_node is not None:
             if scaler:
@@ -1537,7 +1538,7 @@ def overlapped_forward_backward(
         forward_inputs = forward_pre_node.forward(forward_inputs)
         backward_input_grads = backward_pre_node.backward(backward_input_grads)
         forward_inputs, backward_input_grads, _ = overlap_node.forward_backward(
-            forward_inputs, backward_input_grads, event_to_wait
+            forward_inputs, backward_input_grads, combine_bw_event_to_wait
         )
         forward_inputs = forward_post_node.forward(forward_inputs)
         backward_input_grads = backward_post_node.backward(backward_input_grads)
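Net effect of the modeling_pp.py changes: the attention-backward completion event is now taken from DeepEP's calculation stream for the MoE group instead of being recorded on the current Paddle stream, and the overlapped schedule entry point takes it as combine_bw_event_to_wait plus a new pp_stream argument. A minimal sketch of that producer-side handoff, assuming an illustrative helper name (run_attn_backward_and_record) and an already-initialized moe_group, neither of which is part of the diff:

import deep_ep

def run_attn_backward_and_record(backward_node, output_grad, moe_group):
    # Run the attention backward on the calculation stream, as before.
    output_grad = backward_node.attn_backward(output_grad)
    # Capture the completion event directly from the calc stream of the MoE
    # group (instead of paddle.device.current_stream().record_event()), so the
    # combine-backward stage can later wait on it from the comm stream.
    combine_bw_event_to_wait = deep_ep.get_event_from_calc_stream(moe_group.id)
    return output_grad, combine_bw_event_to_wait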

paddlenlp/transformers/moe_layer.py

Lines changed: 7 additions & 5 deletions
@@ -648,9 +648,10 @@ def backward(self, output_combine_grad, previous_event=None, async_finish=False,


 class Fp8CombineQuantNode:
-    def __init__(self, token_dispatcher, name="fp8_combine_quant_node"):
+    def __init__(self, token_dispatcher, moe_group=None, name="fp8_combine_quant_node"):
         self.token_dispatcher = token_dispatcher
         self.name = name
+        self.moe_group = moe_group

     @paddle.no_grad()
     def forward(self, output_combine):
@@ -666,13 +667,14 @@ def backward(self, output_grad, event_to_wait=None):
         # post combine grad
         if DSV3_USE_FP8_DISPATCH:
             if event_to_wait is not None:
+                assert self.moe_group is not None
+                event_to_wait.comm_stream_wait(self.moe_group.id)
                 buffer = get_buffer(self.token_dispatcher._comm_manager.group, get_hidden_bytes(output_grad))
                 custom_stream = paddle.device.Stream(stream_base=buffer.runtime.get_comm_stream())
-                custom_stream.wait_event(event_to_wait)
             else:
                 custom_stream = paddle.device.current_stream()
             with paddle.device.stream_guard(custom_stream):
-                output_combine_grad = paddle.reshape(output_grad, self.output_combine_shape)
+                output_combine_grad = paddle.reshape(output_grad, [-1, output_grad.shape[-1]])
                 # output_combine_grad quant to fp8
                 output_combine_grad_fp8, output_combine_grad_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
                     output_combine_grad, output_scale_transpose=False, quant_method="1x128", input_transpose=False
@@ -683,7 +685,7 @@ def backward(self, output_grad, event_to_wait=None):
             quant_event = deep_ep.get_event_from_custom_stream(custom_stream.stream_base)
             return (output_combine_grad_fp8, output_combine_grad_scale), quant_event
         else:
-            output_combine_grad = paddle.reshape(output_grad, self.output_combine_shape)
+            output_combine_grad = paddle.reshape(output_grad, [-1, output_grad.shape[-1]])
             return output_combine_grad, None


@@ -873,7 +875,7 @@ def __init__(
             is_split_group_gemm=is_split_group_gemm,
         )
         self.combine_node = Fp8CombineNode(self.token_dispatcher)
-        self.combine_quant_node = Fp8CombineQuantNode(self.token_dispatcher)
+        self.combine_quant_node = Fp8CombineQuantNode(self.token_dispatcher, custom_map.moe_group)
         self.name = name

     @paddle.no_grad()
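On the moe_layer.py side, Fp8CombineQuantNode now receives the MoE group so that its backward can make the group's communication stream wait on that calc-stream event before quantizing the combine gradient, and the reshape is driven by the trailing hidden dimension rather than a cached self.output_combine_shape. A minimal consumer-side sketch, assuming an illustrative function name (quant_combine_grad), an event produced as in the previous sketch, and a Paddle build that ships fp8_quant_blockwise:

import paddle

def quant_combine_grad(output_grad, event, moe_group):
    # Make the MoE group's communication stream wait on the calc-stream event,
    # replacing the old custom_stream.wait_event(event) call.
    assert moe_group is not None
    event.comm_stream_wait(moe_group.id)
    # Flatten by the trailing hidden dim instead of a shape cached in forward.
    grad_2d = paddle.reshape(output_grad, [-1, output_grad.shape[-1]])
    # Blockwise (1x128) FP8 quantization of the combine gradient, as in the diff.
    grad_fp8, grad_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
        grad_2d, output_scale_transpose=False, quant_method="1x128", input_transpose=False
    )
    return grad_fp8, grad_scale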
