Skip to content

Commit 8a5ea21

Browse files
authored
support bw split (#10823)
1 parent bdab888 commit 8a5ea21

File tree

3 files changed

+86
-9
lines changed

3 files changed

+86
-9
lines changed

paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -665,6 +665,9 @@ def combine_backward(self, output_grad, async_finish=False):
665665
hidden_states_out_grad,
666666
)
667667

668+
def mlp_backward_dw(self):
669+
self.fp8_fusion_moe_node.mlp_node.backward_dw()
670+
668671
def mlp_backward(self, output_grad):
669672
if self.send_mtp_embed:
670673
(
@@ -681,7 +684,9 @@ def mlp_backward(self, output_grad):
681684
l_aux_grad,
682685
hidden_states_out_grad,
683686
) = output_grad
684-
hs_dispatched_grad, dispatched_probs_grad = self.fp8_fusion_moe_node.mlp_node.backward(hidden_states_out_grad)
687+
hs_dispatched_grad, dispatched_probs_grad = self.fp8_fusion_moe_node.mlp_node.backward(
688+
hidden_states_out_grad, with_dw=False
689+
)
685690

686691
if self.send_mtp_embed:
687692
return (
@@ -790,7 +795,9 @@ def backward(self, output_grad=None, scaler=None):
790795
output_grad = self.post_process_backward(output_grad)
791796
output_grad = self.combine_backward(output_grad)
792797
output_grad = self.mlp_backward(output_grad)
798+
# todo(phlrain): overlap here
793799
output_grad = self.dispatch_backward(output_grad)
800+
self.mlp_backward_dw()
794801
output_grad = self.attn_backward(output_grad)
795802
return output_grad
796803

@@ -820,19 +827,23 @@ def forward_backward(self, inputs, output_grad):
820827

821828
calc_stream_wait(self.backward_node.moe_group.id)
822829
attn_compute_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)
823-
paddle.base.core.nvprof_nvtx_push("mlp_backward")
830+
paddle.base.core.nvprof_nvtx_push("mlp_backward_dx")
824831
output_grad = self.backward_node.mlp_backward(output_grad)
825832
paddle.base.core.nvprof_nvtx_pop()
826833
paddle.base.core.nvprof_nvtx_push("dispatch_forward")
827834
inputs = self.forward_node.dispatch_forward(
828835
inputs, previous_event=attn_compute_event, async_finish=True, allocate_on_comm_stream=True
829836
)
830837
paddle.base.core.nvprof_nvtx_pop()
831-
832-
calc_stream_wait(self.forward_node.moe_group.id)
833838
paddle.base.core.nvprof_nvtx_push("dispatch_backward")
834839
output_grad = self.backward_node.dispatch_backward(output_grad, async_finish=True)
835840
paddle.base.core.nvprof_nvtx_pop()
841+
842+
paddle.base.core.nvprof_nvtx_push("dispatch_backward_dw")
843+
self.backward_node.mlp_backward_dw()
844+
paddle.base.core.nvprof_nvtx_pop()
845+
846+
calc_stream_wait(self.forward_node.moe_group.id)
836847
paddle.base.core.nvprof_nvtx_push("mlp_forward")
837848
inputs = self.forward_node.mlp_forward(inputs)
838849
paddle.base.core.nvprof_nvtx_pop()

paddlenlp/transformers/fp8_utils.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -784,6 +784,7 @@ def bwd_dowm_input(self, expert_w2, unzipped_grad, o1, inplace_swiglu_prob=False
784784
# compute gemm
785785
if isinstance(unzipped_grad, tuple):
786786
(unzipped_grad_fp8, unzipped_grad_scale) = unzipped_grad
787+
unzipped_grad_scale = unzipped_grad_scale.T.contiguous().T
787788
else:
788789
unzipped_grad_fp8, unzipped_grad_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
789790
unzipped_grad, output_scale_transpose=True, quant_method="1x128", input_transpose=False
@@ -1015,3 +1016,56 @@ def backward(self, out_grad):
10151016

10161017
self.reset_statue()
10171018
return dx, probs_grad
1019+
1020+
@paddle.no_grad()
1021+
def backward_dx(self, out_grad):
1022+
# recompute expert_w2 and expert_w1
1023+
expert_w1 = [x.w1 for x in self.experts if x is not None]
1024+
expert_w2 = [x.w2 for x in self.experts if x is not None]
1025+
1026+
if self.recompute_fwd_gate_up:
1027+
o1 = self.fwd_gate_up(None, expert_w1, len(expert_w1), self.tokens_per_expert)
1028+
else:
1029+
o1 = self.o1
1030+
1031+
# do2
1032+
do1, o2_s, probs_grad = self.bwd_dowm_input(expert_w2, out_grad, o1, inplace_swiglu_prob=True)
1033+
del o1
1034+
self.o1 = None
1035+
1036+
self.do1 = do1
1037+
self.o2_s = o2_s
1038+
1039+
self.out_grad = out_grad
1040+
1041+
# dx
1042+
dx = self.bwd_gate_up_input(do1, expert_w1, dx=out_grad[0] if isinstance(out_grad, tuple) else out_grad)
1043+
1044+
return dx, probs_grad
1045+
1046+
@paddle.no_grad()
1047+
def backward_dw(self):
1048+
# recompute expert_w2 and expert_w1
1049+
expert_w1 = [x.w1 for x in self.experts if x is not None]
1050+
expert_w2 = [x.w2 for x in self.experts if x is not None]
1051+
1052+
# dw1
1053+
self.bwd_gate_up_weight(self.do1, None, expert_w1, clear_input=True)
1054+
self.input_fp8 = None
1055+
self.input_scale = None
1056+
self.input = None
1057+
self.do1 = None
1058+
1059+
# dw2
1060+
if isinstance(self.out_grad, tuple):
1061+
out_grad_dequant_fp16 = paddle.incubate.nn.functional.fused_act_dequant(self.out_grad[0], self.out_grad[1])
1062+
self.out_grad = None
1063+
self.bwd_down_weight(out_grad_dequant_fp16, self.o2_s, expert_w2)
1064+
del out_grad_dequant_fp16
1065+
else:
1066+
self.bwd_down_weight(self.out_grad, self.o2_s, expert_w2)
1067+
1068+
self.o2_s = None
1069+
1070+
self.reset_statue()
1071+
return

paddlenlp/transformers/moe_layer.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -747,7 +747,7 @@ def forward(self, hs_2d_dispatched, dispatched_indices, dispatched_probs):
747747
return expert_out_zipped
748748

749749
@paddle.no_grad()
750-
def backward(self, hidden_states_out_grad):
750+
def backward(self, hidden_states_out_grad, with_dw=True):
751751
"""
752752
反向传播函数。
753753
@@ -772,7 +772,10 @@ def backward(self, hidden_states_out_grad):
772772
record_stream_for_multi_input(hidden_states_out_grad)
773773

774774
# expert_grad
775-
expert_out, probs_grad = self.experts_group_gemm_node.backward(unzipped_grad)
775+
if with_dw:
776+
expert_out, probs_grad = self.experts_group_gemm_node.backward(unzipped_grad)
777+
else:
778+
expert_out, probs_grad = self.experts_group_gemm_node.backward_dx(unzipped_grad)
776779

777780
hs_dispatched_grad, dispatched_probs_grad = self.unzip_node.backward(
778781
expert_out,
@@ -781,9 +784,14 @@ def backward(self, hidden_states_out_grad):
781784
self.dispatched_indices,
782785
num_experts=len(self.tokens_per_expert),
783786
)
784-
self.reset_statue()
787+
if with_dw:
788+
self.reset_statue()
785789
return hs_dispatched_grad, dispatched_probs_grad
786790

791+
@paddle.no_grad()
792+
def backward_dw(self):
793+
self.experts_group_gemm_node.backward_dw()
794+
787795

788796
class FusionMoeNode:
789797
def __init__(
@@ -836,11 +844,11 @@ def forward(self, hidden_states, probs, routing_map):
836844
return output
837845

838846
@paddle.no_grad()
839-
def backward(self, output_grad):
847+
def backward(self, output_grad, with_dw=True):
840848
output_combine_grad = self.combine_quant_node.backward(output_grad)
841849
hidden_states_out_grad = self.combine_node.backward(output_combine_grad)
842850

843-
hs_dispatched_grad, dispatched_probs_grad = self.mlp_node.backward(hidden_states_out_grad)
851+
hs_dispatched_grad, dispatched_probs_grad = self.mlp_node.backward(hidden_states_out_grad, with_dw=with_dw)
844852

845853
if DSV3_USE_FP8_DISPATCH:
846854
hs_fp8_grad, token_probs_grad = self.dispatch_node.backward(hs_dispatched_grad, dispatched_probs_grad)
@@ -850,6 +858,10 @@ def backward(self, output_grad):
850858
hs_bf16_grad, token_probs_grad = self.dispatch_node.backward(hs_dispatched_grad, dispatched_probs_grad)
851859
return hs_bf16_grad, None, token_probs_grad
852860

861+
@paddle.no_grad()
862+
def backward_dw(self):
863+
self.mlp_node.backward_dw()
864+
853865

854866
class FusionMoe(paddle.autograd.PyLayer):
855867
@staticmethod

0 commit comments

Comments
 (0)