
Commit e661a3e

phlrain authored
optimize mlp bw split (#10931)
Co-authored-by: phlrain <[email protected]>
1 parent b5ebfdd commit e661a3e

3 files changed (+111, -91 lines)


paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 9 additions & 18 deletions
@@ -27,9 +27,7 @@
     ScheduleNode,
     SharedLayerDesc,
 )
-from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import (
-    WeightGradStore
-)
+from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import WeightGradStore
 
 try:
     from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import EventStore
@@ -721,17 +719,12 @@ def combine_backward(self, output_grad, previous_event=None, async_finish=False,
         ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret
         return ret
 
-    def mlp_backward_dw(self):
-        self.fp8_fusion_moe_node.mlp_node.backward_dw()
-
     def mlp_backward(self, output_grad):
         if self.send_mtp_embed:
             inputs_embeds_mtp_grad, hidden_states_grad, residual_grad, l_aux_grad, hidden_states_out_grad = output_grad
         else:
             hidden_states_grad, residual_grad, l_aux_grad, hidden_states_out_grad = output_grad
-        hs_dispatched_grad, dispatched_probs_grad = self.fp8_fusion_moe_node.mlp_node.backward(
-            hidden_states_out_grad, with_dw=False
-        )
+        hs_dispatched_grad, dispatched_probs_grad = self.fp8_fusion_moe_node.mlp_node.backward(hidden_states_out_grad)
 
         ret = (hidden_states_grad, residual_grad, l_aux_grad, hs_dispatched_grad, dispatched_probs_grad)
         ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret
@@ -809,7 +802,6 @@ def backward(self, output_grad=None, scaler=None):
         output_grad = self.mlp_backward(output_grad)
         # todo(phlrain): overlap here
         output_grad = self.dispatch_backward(output_grad)
-        self.mlp_backward_dw()
         output_grad = self.attn_backward(output_grad)
         return output_grad
 

@@ -853,7 +845,11 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
 
         combine_backward_event.calc_stream_wait(self.backward_node.moe_group.id)
         paddle.base.core.nvprof_nvtx_push("mlp_backward_dx")
+        assert WeightGradStore.funcs_queue.empty()
+        WeightGradStore.enabled = True
         output_grad = self.backward_node.mlp_backward(output_grad)
+        WeightGradStore.enabled = False
+        WeightGradStore.flush()
         paddle.base.core.nvprof_nvtx_pop()
 
         output_grad_event = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id)
@@ -889,7 +885,6 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
             final_out = self.forward_node.post_process_node.forward_without_residual(inputs)
         paddle.base.core.nvprof_nvtx_pop()
 
-
         final_out_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)
 
         paddle.base.core.nvprof_nvtx_push("combine_forward")
@@ -922,13 +917,13 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
         paddle.base.core.nvprof_nvtx_pop()
 
         dispatch_backward_event.calc_stream_wait(self.backward_node.moe_group.id)
-
+
         paddle.base.core.nvprof_nvtx_push("attn_backward")
         assert WeightGradStore.funcs_queue.empty()
         WeightGradStore.enabled = True
         output_grad = self.backward_node.attn_backward(output_grad)
         event_to_wait = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id)
-
+
         if EventStore is not None:
             EventStore.set(event_to_wait)
 

@@ -937,16 +932,12 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
         WeightGradStore.pop()
         assert WeightGradStore.funcs_queue.empty()
 
-        WeightGradStore.enabled = False
-        WeightGradStore.flush()
-        WeightGradStore.pop()
-        assert WeightGradStore.funcs_queue.empty()
         paddle.base.core.nvprof_nvtx_pop()
 
         # residual add
         if pp_stream is None:
             combine_forward_event.calc_stream_wait(self.forward_node.moe_group.id)
-
+
             final_out = self.forward_node.post_process_node.forward_without_residual(inputs)
             if final_out.shape[-1] != combine_fwd_out.shape[-1]:
                 final_out[:, :, : combine_fwd_out.shape[-1]] += combine_fwd_out  # broadcast and add directly

0 commit comments