Commit c52eb6f

Refine fp8 quant and combine backward overlap (#10830)
* refine
* refine
* refine
1 parent c9ec40a commit c52eb6f

File tree: 3 files changed, +65 -31 lines


paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 30 additions & 18 deletions
@@ -396,9 +396,9 @@ def __init__(self, forward_nodes, backward_nodes, use_fuion=True):
         for f, b in zip(forward_nodes, backward_nodes):
             self.nodes.append(schedule_node_class(f, b, f"OverlapedNode_{len(self.nodes)}"))
 
-    def forward_backward(self, inputs, output_grad):
+    def forward_backward(self, inputs, output_grad, event_to_wait=None):
         for n in self.nodes:
-            inputs, output_grad = n.forward_backward(inputs, output_grad)
+            inputs, output_grad, event_to_wait = n.forward_backward(inputs, output_grad, event_to_wait)
         return inputs, output_grad
 
 
@@ -409,7 +409,7 @@ def __init__(self, forward_node, backward_node, name=""):
         self.backward_node = backward_node
         self.name = name
 
-    def forward_backward(self, inputs, output_grad):
+    def forward_backward(self, inputs, output_grad, event_to_wait=None):
         paddle.base.core.nvprof_nvtx_push("forward_backward")
         output_grad = self.backward_node.post_process_node.backward(output_grad)
 
@@ -594,7 +594,7 @@ def post_process_forward(self, inputs):
 
         return inputs
 
-    def post_process_backward(self, output_grad):
+    def post_process_backward(self, output_grad, event_to_wait=None):
         if self.send_mtp_embed:
             (
                 inputs_embeds_mtp_grad,
@@ -610,43 +610,51 @@ def post_process_backward(self, output_grad):
                 l_aux_grad,
                 final_hidden_states_grad,
             ) = self.post_process_node.backward(output_grad)
-        output_combine_grad = self.fp8_fusion_moe_node.combine_quant_node.backward(final_hidden_states_grad)
+        output_combine_grad, quant_event = self.fp8_fusion_moe_node.combine_quant_node.backward(
+            final_hidden_states_grad, event_to_wait
+        )
         if self.send_mtp_embed:
             return (
                 inputs_embeds_mtp_grad,
                 hidden_states_grad,
                 residual_grad,
                 l_aux_grad,
                 output_combine_grad,
+                quant_event,
             )
         else:
             return (
                 hidden_states_grad,
                 residual_grad,
                 l_aux_grad,
                 output_combine_grad,
+                quant_event,
             )
 
-    def combine_backward(self, output_grad, async_finish=False):
+    def combine_backward(self, output_grad, async_finish=False, allocate_on_comm_stream=False):
         if self.send_mtp_embed:
             (
                 inputs_embeds_mtp_grad,
                 hidden_states_grad,
                 residual_grad,
                 l_aux_grad,
                 output_combine_grad,
+                quant_event,
             ) = output_grad
         else:
             (
                 hidden_states_grad,
                 residual_grad,
                 l_aux_grad,
                 output_combine_grad,
+                quant_event,
             ) = output_grad
 
         hidden_states_out_grad = self.fp8_fusion_moe_node.combine_node.backward(
             output_combine_grad,
             async_finish=async_finish,
+            previous_event=quant_event,
+            allocate_on_comm_stream=allocate_on_comm_stream,
         )
 
         if self.send_mtp_embed:
@@ -811,35 +819,36 @@ def __init__(self, forward_node, backward_node, name=""):
         self.backward_node = backward_node
         self.name = name
 
-    def forward_backward(self, inputs, output_grad):
+    def forward_backward(self, inputs, output_grad, event_to_wait=None):
         paddle.base.core.nvprof_nvtx_push("forward_backward")
 
         paddle.base.core.nvprof_nvtx_push("post_process_backward")
-        output_grad = self.backward_node.post_process_backward(output_grad)
+        output_grad = self.backward_node.post_process_backward(output_grad, event_to_wait)
         paddle.base.core.nvprof_nvtx_pop()
 
         paddle.base.core.nvprof_nvtx_push("combine_backward")
-        output_grad = self.backward_node.combine_backward(output_grad, async_finish=True)
+        output_grad = self.backward_node.combine_backward(output_grad, async_finish=True, allocate_on_comm_stream=True)
         # get combine event
-        combine_backward_event = deep_ep.get_event_from_comm_stream( self.backward_node.moe_group.id)
+        combine_backward_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id)
         paddle.base.core.nvprof_nvtx_pop()
 
         paddle.base.core.nvprof_nvtx_push("attn_forward")
         inputs = self.forward_node.attn_forward(inputs)
         paddle.base.core.nvprof_nvtx_pop()
-        attn_compute_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)
 
+        attn_compute_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)
 
-        combine_backward_event.calc_stream_wait( self.backward_node.moe_group.id )
+        combine_backward_event.calc_stream_wait(self.backward_node.moe_group.id)
         paddle.base.core.nvprof_nvtx_push("mlp_backward_dx")
         output_grad = self.backward_node.mlp_backward(output_grad)
         paddle.base.core.nvprof_nvtx_pop()
+
         paddle.base.core.nvprof_nvtx_push("dispatch_forward")
         inputs = self.forward_node.dispatch_forward(
             inputs, previous_event=attn_compute_event, async_finish=True, allocate_on_comm_stream=True
         )
         paddle.base.core.nvprof_nvtx_pop()
-        dispatch_forward_event = deep_ep.get_event_from_comm_stream( self.forward_node.moe_group.id )
+        dispatch_forward_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id)
 
         paddle.base.core.nvprof_nvtx_push("dispatch_backward")
         output_grad = self.backward_node.dispatch_backward(output_grad, async_finish=True)
@@ -851,28 +860,28 @@ def forward_backward(self, inputs, output_grad):
         self.backward_node.mlp_backward_dw()
         paddle.base.core.nvprof_nvtx_pop()
 
-        dispatch_forward_event.calc_stream_wait( self.forward_node.moe_group.id)
+        dispatch_forward_event.calc_stream_wait(self.forward_node.moe_group.id)
         paddle.base.core.nvprof_nvtx_push("mlp_forward")
         inputs = self.forward_node.mlp_forward(inputs)
         paddle.base.core.nvprof_nvtx_pop()
 
         paddle.base.core.nvprof_nvtx_push("combine_forward")
         inputs = self.forward_node.combine_forward(inputs, async_finish=True)
         paddle.base.core.nvprof_nvtx_pop()
-        combine_forward_event = deep_ep.get_event_from_comm_stream( self.forward_node.moe_group.id)
-
+        combine_forward_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id)
 
         dispatch_backward_event.calc_stream_wait(self.backward_node.moe_group.id)
         paddle.base.core.nvprof_nvtx_push("attn_backward")
         output_grad = self.backward_node.attn_backward(output_grad)
+        event_to_wait = paddle.device.current_stream().record_event()
         paddle.base.core.nvprof_nvtx_pop()
 
         combine_forward_event.calc_stream_wait(self.forward_node.moe_group.id)
         paddle.base.core.nvprof_nvtx_push("post_process_forward")
         inputs = self.forward_node.post_process_forward(inputs)
         paddle.base.core.nvprof_nvtx_pop()
         paddle.base.core.nvprof_nvtx_pop()
-        return inputs, output_grad
+        return inputs, output_grad, event_to_wait
 
 
 def build_overlapped_nodes(forward_chunk, backward_chunk):
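Read end to end, the rewritten per-node forward_backward above interleaves one micro-batch's forward with another's backward and uses deep_ep events so the calc stream only waits on the MoE communication stream at the points where its results are consumed. The condensed sketch below restates that ordering with the nvtx ranges dropped; it is an illustration, not code from the commit, and the capture of dispatch_backward_event (which falls between the hunks shown) is assumed to mirror the other comm-stream events.

# Condensed sketch of the per-node forward_backward above (nvtx ranges omitted;
# the dispatch_backward_event capture is assumed, since it lies outside the hunks shown).
def forward_backward(self, inputs, output_grad, event_to_wait=None):
    # Backward post-process: the combine-grad FP8 quant can run on the comm
    # stream, gated by the event recorded by the previous node (see moe_layer.py).
    output_grad = self.backward_node.post_process_backward(output_grad, event_to_wait)
    output_grad = self.backward_node.combine_backward(
        output_grad, async_finish=True, allocate_on_comm_stream=True
    )
    combine_backward_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id)

    # Attention forward overlaps with the asynchronous combine backward.
    inputs = self.forward_node.attn_forward(inputs)
    attn_compute_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)

    # The calc stream waits for the combined gradients only once MLP backward needs them.
    combine_backward_event.calc_stream_wait(self.backward_node.moe_group.id)
    output_grad = self.backward_node.mlp_backward(output_grad)

    # Dispatch forward overlaps with dispatch backward and the weight-gradient pass.
    inputs = self.forward_node.dispatch_forward(
        inputs, previous_event=attn_compute_event, async_finish=True, allocate_on_comm_stream=True
    )
    dispatch_forward_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id)
    output_grad = self.backward_node.dispatch_backward(output_grad, async_finish=True)
    dispatch_backward_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id)  # assumed
    self.backward_node.mlp_backward_dw()

    dispatch_forward_event.calc_stream_wait(self.forward_node.moe_group.id)
    inputs = self.forward_node.mlp_forward(inputs)
    inputs = self.forward_node.combine_forward(inputs, async_finish=True)
    combine_forward_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id)

    dispatch_backward_event.calc_stream_wait(self.backward_node.moe_group.id)
    output_grad = self.backward_node.attn_backward(output_grad)
    # New in this commit: record the event the next node's combine-grad quant waits on.
    event_to_wait = paddle.device.current_stream().record_event()

    combine_forward_event.calc_stream_wait(self.forward_node.moe_group.id)
    inputs = self.forward_node.post_process_forward(inputs)
    return inputs, output_grad, event_to_wait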
@@ -1579,6 +1588,7 @@ def overlapped_forward_backward(
         backward_loss_fn_node,
         backward_input_grads,
         scaler,
+        event_to_wait=None,
     ):
         if backward_loss_fn_node is not None:
             if scaler:
@@ -1595,7 +1605,9 @@
         ) = build_overlapped_nodes(forward_chunk, backward_chunk)
         forward_inputs = forward_pre_node.forward(forward_inputs)
         backward_input_grads = backward_pre_node.backward(backward_input_grads)
-        forward_inputs, backward_input_grads = overlap_node.forward_backward(forward_inputs, backward_input_grads)
+        forward_inputs, backward_input_grads = overlap_node.forward_backward(
+            forward_inputs, backward_input_grads, event_to_wait
+        )
         forward_inputs = forward_post_node.forward(forward_inputs)
         backward_input_grads = backward_post_node.backward(backward_input_grads)
 
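At the scheduler level, overlapped_forward_backward forwards its new event_to_wait argument into the loop from the first hunk of this file, and each node returns the event it recorded after attn_backward. That is how the combine-grad quant of node i+1 ends up waiting on node i's attention backward. A comment-annotated restatement of that loop, as a sketch rather than the commit's code:

# Sketch of the schedule-chunk forward_backward (see the first hunk in this file).
def forward_backward(self, inputs, output_grad, event_to_wait=None):
    for n in self.nodes:
        # Each node returns the event it recorded after attn_backward; the next
        # node passes it to post_process_backward, where the combine-grad FP8
        # quant waits on it before running on the communication stream.
        inputs, output_grad, event_to_wait = n.forward_backward(inputs, output_grad, event_to_wait)
    return inputs, output_grad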

paddlenlp/transformers/fused_a2a.py

Lines changed: 7 additions & 2 deletions
@@ -343,10 +343,15 @@ def forward(self, x, group, handle, previous_event=None, async_finish=False):
 
         return combined_x
 
-    def backward(self, grad_output, previous_event=None, async_finish=False):
+    def backward(self, grad_output, previous_event=None, async_finish=False, allocate_on_comm_stream=False):
         """Backward pass of fused combine."""
         out = fused_combine_backward_func(
-            grad_output, self.group, self.handle, previous_event=previous_event, async_finish=async_finish
+            grad_output,
+            self.group,
+            self.handle,
+            previous_event=previous_event,
+            async_finish=async_finish,
+            allocate_on_comm_stream=allocate_on_comm_stream,
         )
         self.reset_statue()
         return out
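The new flag is forwarded straight into fused_combine_backward_func, so a caller that has already recorded an event on the communication stream can chain it in without stalling the calc stream. A usage sketch (the function and argument names below are placeholders; combine_node is assumed to be the combine wrapper used in moe_layer.py):

# Usage sketch: chain an FP8-quant event into the fused combine backward so the
# all-to-all kernel waits on it on the comm stream instead of blocking the calc stream.
def combine_backward_on_comm_stream(combine_node, output_combine_grad, quant_event):
    return combine_node.backward(
        output_combine_grad,
        previous_event=quant_event,       # event recorded after the quant
        async_finish=True,                # return immediately on the calc stream
        allocate_on_comm_stream=True,     # new flag added in this commit
    )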

paddlenlp/transformers/moe_layer.py

Lines changed: 28 additions & 11 deletions
@@ -27,11 +27,16 @@
 
 from ..utils.log import logger
 from .fp8_utils import FP8GroupGemmMlpFunctionNode
-from .fused_a2a import CombineNode, DispatchNode
+from .fused_a2a import CombineNode, DispatchNode, get_buffer, get_hidden_bytes
 from .moe_gate import PretrainedMoEGate
 from .moe_utils import UnZipNode, ZipNode
 from .token_dispatcher import MoEFlexTokenDispatcher, PreDispatchNode
 
+try:
+    import paddle.distributed.communication.deep_ep as deep_ep
+except ImportError:
+    deep_ep = None
+
 DSV3_USE_FP8_GEMM = os.getenv("DSV3_USE_FP8_GEMM", "False").lower() == "true"
 
 DSV3_USE_FP8_GROUP_GEMM = os.getenv("DSV3_USE_FP8_GROUP_GEMM", "False").lower() == "true"
@@ -622,12 +627,13 @@ def forward(self, hidden_states_out, previous_event=None, async_finish=False):
         return output_combine
 
     @paddle.no_grad()
-    def backward(self, output_combine_grad, previous_event=None, async_finish=False):
+    def backward(self, output_combine_grad, previous_event=None, async_finish=False, allocate_on_comm_stream=False):
         # combine grad -> fp8
         hidden_states_out_grad = self.combine_node.backward(
             output_combine_grad,
             previous_event=previous_event,
             async_finish=async_finish,
+            allocate_on_comm_stream=allocate_on_comm_stream,
         )
         return hidden_states_out_grad
 
@@ -647,18 +653,29 @@ def forward(self, output_combine):
         return output
 
     @paddle.no_grad()
-    def backward(self, output_grad):
+    def backward(self, output_grad, event_to_wait=None):
         # post combine grad
-        output_combine_grad = paddle.reshape(output_grad, self.output_combine_shape)
-
         if DSV3_USE_FP8_DISPATCH:
-            # output_combine_grad quant to fp8
-            output_combine_grad_fp8, output_combine_grad_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-                output_combine_grad, output_scale_transpose=False, quant_method="1x128", input_transpose=False
-            )
-            return (output_combine_grad_fp8, output_combine_grad_scale)
+            if event_to_wait is not None:
+                buffer = get_buffer(self.token_dispatcher._comm_manager.group, get_hidden_bytes(output_grad))
+                custom_stream = paddle.device.Stream(stream_base=buffer.runtime.get_comm_stream())
+                custom_stream.wait_event(event_to_wait)
+            else:
+                custom_stream = paddle.device.current_stream()
+            with paddle.device.stream_guard(custom_stream):
+                output_combine_grad = paddle.reshape(output_grad, self.output_combine_shape)
+                # output_combine_grad quant to fp8
+                output_combine_grad_fp8, output_combine_grad_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+                    output_combine_grad, output_scale_transpose=False, quant_method="1x128", input_transpose=False
+                )
+                output_grad._record_stream()
+            quant_event = None
+            if event_to_wait is not None:
+                quant_event = deep_ep.get_event_from_custom_stream(custom_stream.stream_base)
+            return (output_combine_grad_fp8, output_combine_grad_scale), quant_event
         else:
-            return output_combine_grad
+            output_combine_grad = paddle.reshape(output_grad, self.output_combine_shape)
+            return output_combine_grad, None
 
 
 class FusionMlpNode:
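The net effect of this hunk is that, when an event is supplied, the reshape and blockwise FP8 quant of the combine gradient run on the deep_ep communication stream rather than the calc stream, and a quant event is handed back so combine_backward can chain on it. A condensed, self-contained sketch of that path; group, shape, and the function name are stand-ins for the attributes used above, not names from the commit:

# Sketch of the new combine-grad quant path, built only from calls that appear in this diff.
import paddle
import paddle.distributed.communication.deep_ep as deep_ep

from paddlenlp.transformers.fused_a2a import get_buffer, get_hidden_bytes


def quant_combine_grad_on_comm_stream(output_grad, group, shape, event_to_wait=None):
    if event_to_wait is not None:
        # Run the quant on the deep_ep communication stream, after the event
        # recorded at the end of the previous node's attn_backward.
        buffer = get_buffer(group, get_hidden_bytes(output_grad))
        stream = paddle.device.Stream(stream_base=buffer.runtime.get_comm_stream())
        stream.wait_event(event_to_wait)
    else:
        stream = paddle.device.current_stream()

    with paddle.device.stream_guard(stream):
        grad = paddle.reshape(output_grad, shape)
        grad_fp8, grad_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
            grad, output_scale_transpose=False, quant_method="1x128", input_transpose=False
        )
        # Keep output_grad alive until work queued on this stream has consumed it.
        output_grad._record_stream()

    quant_event = None
    if event_to_wait is not None:
        # Hand the event back so the fused combine backward can wait on it.
        quant_event = deep_ep.get_event_from_custom_stream(stream.stream_base)
    return (grad_fp8, grad_scale), quant_event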
