Skip to content

Commit e91f55a

Browse files
authored
Implement overlapping of Dense and MoE DecoderLayers (PaddlePaddle#10959)
1 parent 8426847 commit e91f55a

File tree

1 file changed

+142
-3
lines changed

1 file changed

+142
-3
lines changed

paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 142 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -498,10 +498,14 @@ def backward(self, output_grad=None, scaler=None):
498498

499499
class OverlapedScheduleChunk:
500500
def __init__(self, forward_nodes, backward_nodes, use_fuion=True):
    """Pair up forward and backward schedule nodes into overlapped 1F1B units.

    Args:
        forward_nodes: schedule nodes for the forward chunk; must be the same
            length as ``backward_nodes`` (paired element-wise by ``zip``).
        backward_nodes: schedule nodes for the backward chunk.
        use_fuion: when True, wrap pairs in the FP8-fusion overlap node.
            NOTE(review): parameter name is a typo of "use_fusion" kept for
            caller compatibility.
    """
    assert len(forward_nodes) == len(backward_nodes)
    self.nodes = []
    for f, b in zip(forward_nodes, backward_nodes):
        # Choose the wrapper per pair: plain overlap by default, fusion
        # overlap when enabled, and the Dense/MoE mixed wrapper when either
        # side of the pair is a dense decoder layer (the mixed wrapper
        # asserts one side is a FusionFp8DecoderLayerNode, so the dense
        # override only applies in fusion mode).
        schedule_node_class = OverlapedScheduleNode
        if use_fuion:
            schedule_node_class = OverlapedFUsionScheduleNode
            if isinstance(f, DenseDecoderLayerNode) or isinstance(b, DenseDecoderLayerNode):
                schedule_node_class = OverlapedDenseFusionScheduleNode
        self.nodes.append(schedule_node_class(f, b, f"OverlapedNode_{len(self.nodes)}"))
506510

507511
def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None):
@@ -941,6 +945,29 @@ def backward(self, output_grad=None, scaler=None):
941945
return output_grad
942946

943947

948+
class DenseDecoderLayerNode(ScheduleNode):
    """Schedule node for a dense (non-MoE) decoder layer.

    Composes two sub-nodes — attention followed by MLP — so the overlap
    scheduler can interleave each half with communication phases of a
    concurrently running MoE layer.
    """

    def __init__(
        self,
        attn_node,
        mlp_node,
        name="DenseDecoderLayerNode",
    ):
        # No fused forward function: forward/backward delegate to sub-nodes.
        super().__init__(fwd_func=None, name=name)
        self.attn_node = attn_node
        self.mlp_node = mlp_node

    def forward(self, inputs):
        # Attention first, then MLP — the layer's natural compute order.
        attn_out = self.attn_node.forward(inputs)
        return self.mlp_node.forward(attn_out)

    def backward(self, output_grad=None, scaler=None):
        # Gradients flow in reverse: MLP backward, then attention backward.
        # Loss scaling is not supported by this node.
        assert (output_grad is not None) and (scaler is None)
        mlp_grad = self.mlp_node.backward(output_grad)
        return self.attn_node.backward(mlp_grad)
969+
970+
944971
class OverlapedFUsionScheduleNode:
945972
def __init__(self, forward_node, backward_node, name=""):
946973
assert isinstance(forward_node, FusionFp8DecoderLayerNode) and isinstance(
@@ -1086,8 +1113,99 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
10861113
return inputs, output_grad, event_to_wait
10871114

10881115

1116+
class OverlapedDenseFusionScheduleNode:
    """Overlap scheduler for a mixed pair: one dense decoder layer and one
    FP8-fusion MoE decoder layer, run in opposite directions (1F1B).

    The dense layer's compute is interleaved with the MoE layer's all-to-all
    communication (dispatch/combine), which runs on the MoE group's comm
    stream, so communication is hidden behind dense compute.
    """

    def __init__(self, forward_node, backward_node, name=""):
        # Exactly one side must be the fusion MoE node and one side the
        # dense node (the asserts together enforce the mixed pairing).
        assert isinstance(forward_node, FusionFp8DecoderLayerNode) or isinstance(
            backward_node, FusionFp8DecoderLayerNode
        )
        assert isinstance(forward_node, DenseDecoderLayerNode) or isinstance(
            backward_node, DenseDecoderLayerNode
        )
        self.forward_node = forward_node
        self.backward_node = backward_node
        self.name = name

    def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None):
        """Run one overlapped forward+backward step for the mixed pair.

        Args:
            inputs: activations entering the forward node.
            output_grad: gradient entering the backward node.
            combine_bw_event_to_wait: optional event the MoE combine-backward
                must wait on before launching on the comm stream.
            pp_stream: accepted for interface parity with the other overlap
                nodes; unused here.

        Returns:
            Tuple of (forward outputs, backward input grads, calc-stream
            event for the next overlapped node to wait on).
        """
        # Dense forward + MoE backward
        if isinstance(self.forward_node, DenseDecoderLayerNode):
            paddle.base.core.nvprof_nvtx_push("dense_fw_moe_bw")

            paddle.base.core.nvprof_nvtx_push("dense_attn_moe_combine")
            output_grad = self.backward_node.post_process_backward(output_grad, combine_bw_event_to_wait)
            # Launch MoE combine-backward async on the comm stream, then
            # overlap it with dense attention forward on the calc stream.
            output_grad = self.backward_node.combine_backward(
                output_grad, previous_event=combine_bw_event_to_wait, async_finish=True, allocate_on_comm_stream=True
            )
            combine_bw_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id)
            inputs = self.forward_node.attn_node.forward(inputs)
            # Calc stream must not consume combine results before comm done.
            combine_bw_event.calc_stream_wait(self.backward_node.moe_group.id)
            paddle.base.core.nvprof_nvtx_pop()  # dense_attn_moe_combine

            paddle.base.core.nvprof_nvtx_push("moe_mlp")
            output_grad = self.backward_node.mlp_backward(output_grad)
            paddle.base.core.nvprof_nvtx_pop()  # moe_mlp

            paddle.base.core.nvprof_nvtx_push("dense_mlp_moe_dispatch")
            # Same pattern: MoE dispatch-backward on comm stream overlapped
            # with dense MLP forward on the calc stream.
            output_grad = self.backward_node.dispatch_backward(
                output_grad, async_finish=True, allocate_on_comm_stream=True
            )
            dispatch_bw_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id)
            inputs = self.forward_node.mlp_node.forward(inputs)
            dispatch_bw_event.calc_stream_wait(self.backward_node.moe_group.id)
            paddle.base.core.nvprof_nvtx_pop()  # dense_mlp_moe_dispatch

            paddle.base.core.nvprof_nvtx_push("moe_attn")
            output_grad = self.backward_node.attn_backward(output_grad)
            paddle.base.core.nvprof_nvtx_pop()  # moe_attn

            event_to_wait = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id)
            paddle.base.core.nvprof_nvtx_pop()  # dense_fw_moe_bw

        # Dense backward + MoE forward
        else:
            paddle.base.core.nvprof_nvtx_push("dense_bw_moe_fw")

            paddle.base.core.nvprof_nvtx_push("moe_attn")
            inputs = self.forward_node.attn_forward(inputs)
            attn_fw_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)
            paddle.base.core.nvprof_nvtx_pop()  # moe_attn

            paddle.base.core.nvprof_nvtx_push("dense_mlp_moe_dispatch")
            # Dense MLP backward (calc stream) overlaps MoE dispatch-forward
            # (comm stream); the dispatch waits on attention completion.
            output_grad = self.backward_node.mlp_node.backward(output_grad)
            inputs = self.forward_node.dispatch_forward(
                inputs, previous_event=attn_fw_event, async_finish=True, allocate_on_comm_stream=True
            )
            dispatch_fw_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id)
            dispatch_fw_event.calc_stream_wait(self.forward_node.moe_group.id)
            paddle.base.core.nvprof_nvtx_pop()  # dense_mlp_moe_dispatch

            paddle.base.core.nvprof_nvtx_push("moe_mlp")
            inputs = self.forward_node.mlp_forward(inputs)
            paddle.base.core.nvprof_nvtx_pop()  # moe_mlp

            paddle.base.core.nvprof_nvtx_push("dense_attn_moe_combine")
            # MoE combine-forward on comm stream overlapped with dense
            # attention backward on the calc stream.
            inputs = self.forward_node.combine_forward(
                inputs, async_finish=True, allocate_on_comm_stream=True
            )
            combine_fw_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id)
            output_grad = self.backward_node.attn_node.backward(output_grad)
            combine_fw_event.calc_stream_wait(self.forward_node.moe_group.id)
            paddle.base.core.nvprof_nvtx_pop()  # dense_attn_moe_combine

            paddle.base.core.nvprof_nvtx_push("moe_post")
            inputs = self.forward_node.post_process_forward(inputs)
            paddle.base.core.nvprof_nvtx_pop()  # moe_post

            event_to_wait = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)
            paddle.base.core.nvprof_nvtx_pop()  # dense_bw_moe_fw

        return inputs, output_grad, event_to_wait
1202+
1203+
10891204
def build_overlapped_nodes(forward_chunk, backward_chunk):
1090-
overlap_element_class = FusionFp8DecoderLayerNode if DSV3_USE_FP8_GEMM else DecoderLayerNode
1205+
overlap_element_class = (
1206+
FusionFp8DecoderLayerNode if DSV3_USE_FP8_GEMM else DecoderLayerNode,
1207+
DenseDecoderLayerNode
1208+
)
10911209
forward_decoder_layer_num = 0
10921210
backward_decoder_layer_num = 0
10931211
assert isinstance(forward_chunk, ScheduleChunk) and isinstance(backward_chunk, ScheduleChunk)
@@ -1466,6 +1584,20 @@ def post_process_compute_for_fusion(self, inputs):
14661584

14671585
return return_args(hidden_states)
14681586

1587+
def attn_compute_dense(self, args):
    """Attention half of a dense decoder layer for the overlap scheduler.

    Unpacks the pipeline argument tuple and runs self-attention only; the
    MLP half runs separately in ``mlp_compute_dense`` so the scheduler can
    interleave MoE communication between the two.
    """
    hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args)
    # The dense overlap path only supports the default-mask / no-rope-ids
    # configuration; anything else must take the non-overlapped path.
    assert attention_mask is None
    assert attn_mask_startend_row_indices is None
    assert position_ids is None
    # self_attn_compute returns (hidden_states, residual-ish extra); the
    # extra output is unused on the dense path.
    hidden_states, _ = self.self_attn_compute(hidden_states)
    return hidden_states
1594+
1595+
def mlp_compute_dense(self, hidden_states):
    """MLP half of a dense decoder layer: residual connection around self.mlp."""
    residual = hidden_states
    mlp_out = self.mlp(hidden_states)
    return residual + mlp_out
1600+
14691601
def build_schedule_node(self):
14701602
if isinstance(self.mlp, DeepseekV2MoE):
14711603
self.mlp.update_flex_token()
@@ -1515,7 +1647,14 @@ def build_schedule_node(self):
15151647
mlp_layer=self.mlp,
15161648
name="DecoderLayerNode",
15171649
)
1518-
return ScheduleNode(self.forward, name="DeepseekV2DecoderLayerPipe")
1650+
1651+
attn_node = ScheduleNode(self.attn_compute_dense, name="attn_node")
1652+
mlp_node = ScheduleNode(self.mlp_compute_dense, name="mlp_node")
1653+
return DenseDecoderLayerNode(
1654+
attn_node=attn_node,
1655+
mlp_node=mlp_node,
1656+
name="DenseDecoderLayerNode",
1657+
)
15191658

15201659

15211660
class DeepseekV2MTPLayerPipe(DeepseekV2MTPLayer):

0 commit comments

Comments
 (0)