
Commit 475942a

Implement Decoder backward self-overlapping (#11003)
1 parent de9c524 commit 475942a

File tree

1 file changed: +90 -0


paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 90 additions & 0 deletions
@@ -540,6 +540,20 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None):
         return inputs, output_grad, None
 
 
+class DecoderBackwardScheduleChunk:
+    def __init__(self, nodes):
+        self.nodes = nodes
+
+    def backward(self, output_grad, combine_bw_event_to_wait=None, pp_stream=None):
+        event_to_wait = combine_bw_event_to_wait
+        for i, n in enumerate(self.nodes):
+            pp_stream_t = pp_stream if i + 1 == len(self.nodes) else None
+            output_grad, event_to_wait = n.backward_for_fusion(
+                output_grad, combine_bw_event_to_wait=event_to_wait, pp_stream=pp_stream_t
+            )
+        return output_grad
+
+
 class OverlapedScheduleNode:
     def __init__(self, forward_node, backward_node, name=""):
         assert isinstance(forward_node, DecoderLayerNode) and isinstance(backward_node, DecoderLayerNode)
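The new DecoderBackwardScheduleChunk simply chains backward_for_fusion across its nodes: the event returned by one layer is what the next layer waits on, and only the final layer is handed pp_stream. A minimal sketch of that contract, assuming the class from the hunk above is in scope (the toy node below is illustrative only, not part of the commit):

# Illustrative stand-in node showing the interface the chunk expects.
class _ToyNode:
    def __init__(self, name):
        self.name = name

    def backward_for_fusion(self, output_grad, combine_bw_event_to_wait=None, pp_stream=None):
        # A real FusionFp8DecoderLayerNode runs combine/mlp/dispatch/attn backward here;
        # the toy only shows the (grad, event) chaining and where pp_stream lands.
        print(self.name, "waits on", combine_bw_event_to_wait, "| pp_stream:", pp_stream)
        return output_grad, "event_after_" + self.name

chunk = DecoderBackwardScheduleChunk([_ToyNode("layer2"), _ToyNode("layer1"), _ToyNode("layer0")])
grad = chunk.backward("dL/dy", pp_stream="pp_stream")
# Only the last node in the list ("layer0") receives pp_stream; the others get None.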
@@ -972,6 +986,77 @@ def attn_backward(self, output_grad):
         output_grad = self.attn_and_gate_node.backward(output_grad)
         return output_grad
 
+    def backward_for_fusion(self, output_grad, combine_bw_event_to_wait=None, pp_stream=None):
+        paddle.base.core.nvprof_nvtx_push("backward")
+        if combine_bw_event_to_wait is None:
+            combine_bw_event_to_wait = deep_ep.get_event_from_calc_stream(self.moe_group.id)
+
+        paddle.base.core.nvprof_nvtx_push("post_process_backward")
+        output_grad = self.post_process_backward(output_grad, combine_bw_event_to_wait)
+        paddle.base.core.nvprof_nvtx_pop()
+
+        paddle.base.core.nvprof_nvtx_push("combine_backward")
+        output_grad = self.combine_backward(
+            output_grad, previous_event=combine_bw_event_to_wait, async_finish=True, allocate_on_comm_stream=True
+        )
+        combine_backward_event = deep_ep.get_event_from_comm_stream(self.moe_group.id)
+        combine_backward_event.calc_stream_wait(self.moe_group.id)
+        paddle.base.core.nvprof_nvtx_pop()
+
+        if WeightGradStore.enabled:
+            paddle.base.core.nvprof_nvtx_push("mlp_backward")
+            output_grad = self.mlp_backward(output_grad)
+            paddle.base.core.nvprof_nvtx_pop()
+
+            paddle.base.core.nvprof_nvtx_push("dispatch_backward")
+            output_grad = self.dispatch_backward(output_grad)
+            paddle.base.core.nvprof_nvtx_pop()
+
+            paddle.base.core.nvprof_nvtx_push("attn_backward")
+            output_grad = self.attn_backward(output_grad)
+            paddle.base.core.nvprof_nvtx_pop()
+
+            event_to_wait = None
+
+        else:
+            paddle.base.core.nvprof_nvtx_push("mlp_backward_dx")
+            assert WeightGradStore.funcs_queue.empty()
+            WeightGradStore.enabled = True
+            output_grad = self.mlp_backward(output_grad)
+            WeightGradStore.enabled = False
+            WeightGradStore.flush()
+            output_grad_event = deep_ep.get_event_from_calc_stream(self.moe_group.id)
+            paddle.base.core.nvprof_nvtx_pop()
+
+            paddle.base.core.nvprof_nvtx_push("dispatch_backward")
+            output_grad = self.dispatch_backward(
+                output_grad, async_finish=True, previous_event=output_grad_event, allocate_on_comm_stream=True
+            )
+            dispatch_backward_event = deep_ep.get_event_from_comm_stream(self.moe_group.id)
+            paddle.base.core.nvprof_nvtx_pop()
+
+            paddle.base.core.nvprof_nvtx_push("mlp_backward_dw")
+            WeightGradStore.pop()
+            assert WeightGradStore.funcs_queue.empty()
+            paddle.base.core.nvprof_nvtx_pop()
+
+            paddle.base.core.nvprof_nvtx_push("attn_backward_dx")
+            dispatch_backward_event.calc_stream_wait(self.moe_group.id)
+            WeightGradStore.enabled = True
+            output_grad = self.attn_backward(output_grad)
+            WeightGradStore.enabled = False
+            WeightGradStore.flush()
+            event_to_wait = deep_ep.get_event_from_calc_stream(self.moe_group.id)
+            paddle.base.core.nvprof_nvtx_pop()
+
+            paddle.base.core.nvprof_nvtx_push("attn_backward_dw")
+            WeightGradStore.pop()
+            assert WeightGradStore.funcs_queue.empty()
+            paddle.base.core.nvprof_nvtx_pop()
+
+        paddle.base.core.nvprof_nvtx_pop()
+        return output_grad, event_to_wait
+
     def forward(self, inputs):
         inputs = self.attn_forward(inputs)
         inputs = self.dispatch_forward(inputs)
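In the else branch the backward pass is split into a dx part, which must run immediately because dispatch_backward and the previous layer consume it, and a dw part that WeightGradStore defers while enabled is set, seals with flush(), and replays later via pop() in the mlp_backward_dw / attn_backward_dw phases, so the weight-gradient work can overlap the asynchronous dispatch on the communication stream. A toy sketch of that deferred-dw mechanism (an assumption for illustration; PaddleNLP's actual WeightGradStore differs in detail):

import queue

class ToyWeightGradStore:
    # Stand-in for WeightGradStore: while `enabled`, weight-gradient work is queued, not run.
    enabled = False
    cache = []
    funcs_queue = queue.Queue()

    @classmethod
    def put(cls, func):
        cls.cache.append(func)              # defer a dw closure

    @classmethod
    def flush(cls):
        cls.funcs_queue.put(cls.cache)      # seal this layer's deferred closures
        cls.cache = []

    @classmethod
    def pop(cls):
        for func in cls.funcs_queue.get():  # replay them in the *_dw phase
            func()

def toy_linear_backward(x, w, grad_out):
    dx = grad_out * w                        # dx is needed right away by the next backward step
    if ToyWeightGradStore.enabled:
        ToyWeightGradStore.put(lambda: print("dw =", grad_out * x))  # dw can wait
    else:
        print("dw =", grad_out * x)
    return dx

ToyWeightGradStore.enabled = True
dx = toy_linear_backward(x=2.0, w=3.0, grad_out=1.0)   # only dx computed here
ToyWeightGradStore.enabled = False
ToyWeightGradStore.flush()
# ... communication could overlap here ...
ToyWeightGradStore.pop()                                # deferred dw runs now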
@@ -1310,6 +1395,11 @@ def build_overlapped_nodes(forward_chunk, backward_chunk):
     backward_pre_node = ScheduleChunk(list(reversed(backward_pre_overlap_layers)))
     backward_post_node = ScheduleChunk(list(reversed(backward_post_overlap_layers)))
 
+    if not forward_chunk.nodes and all(
+        isinstance(n, FusionFp8DecoderLayerNode) for n in backward_chunk.nodes
+    ):
+        backward_post_node = DecoderBackwardScheduleChunk(backward_post_overlap_layers)
+
     overlap_node = OverlapedScheduleChunk(forward_overlap_layers, backward_overlap_layers, use_fuion=DSV3_USE_FP8_GEMM)
     return forward_pre_node, backward_pre_node, overlap_node, forward_post_node, backward_post_node
 
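The fused backward path is selected only when the forward chunk is empty (a pure-backward step of the schedule) and every backward node is a FusionFp8DecoderLayerNode, i.e. supports backward_for_fusion. A hedged, caller-side sketch of how the returned chunk would then be driven (the driver variables below are assumptions, not part of this commit):

# Assumed scheduler-side usage; only build_overlapped_nodes and DecoderBackwardScheduleChunk come
# from this file. forward_chunk, backward_chunk, output_grad and pp_stream are assumed in scope.
fwd_pre, bwd_pre, overlap_node, fwd_post, bwd_post = build_overlapped_nodes(forward_chunk, backward_chunk)
if isinstance(bwd_post, DecoderBackwardScheduleChunk):
    # Pure-backward phase: per-layer backward_for_fusion with event chaining,
    # pp_stream handed only to the last layer.
    output_grad = bwd_post.backward(output_grad, combine_bw_event_to_wait=None, pp_stream=pp_stream)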