@@ -498,10 +498,14 @@ def backward(self, output_grad=None, scaler=None):

class OverlapedScheduleChunk:
    def __init__(self, forward_nodes, backward_nodes, use_fuion=True):
-        schedule_node_class = OverlapedFUsionScheduleNode if use_fuion else OverlapedScheduleNode
        assert len(forward_nodes) == len(backward_nodes)
        self.nodes = []
        for f, b in zip(forward_nodes, backward_nodes):
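+            # Choose the overlap wrapper per (forward, backward) pair; when either side
+            # is a dense decoder layer, pair it with the dense/MoE overlap scheduler.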
+            schedule_node_class = OverlapedScheduleNode
+            if use_fuion:
+                schedule_node_class = OverlapedFUsionScheduleNode
+            if isinstance(f, DenseDecoderLayerNode) or isinstance(b, DenseDecoderLayerNode):
+                schedule_node_class = OverlapedDenseFusionScheduleNode
            self.nodes.append(schedule_node_class(f, b, f"OverlapedNode_{len(self.nodes)}"))

    def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None):
@@ -941,6 +945,29 @@ def backward(self, output_grad=None, scaler=None):
        return output_grad


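+# A dense (non-MoE) decoder layer wrapped as two sub-nodes (attention, MLP) so each
+# half can be interleaved with a MoE layer's dispatch/combine phases during overlap.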
+class DenseDecoderLayerNode(ScheduleNode):
+    def __init__(
+        self,
+        attn_node,
+        mlp_node,
+        name="DenseDecoderLayerNode",
+    ):
+        super().__init__(fwd_func=None, name=name)
+        self.attn_node = attn_node
+        self.mlp_node = mlp_node
+
+    def forward(self, inputs):
+        inputs = self.attn_node.forward(inputs)
+        inputs = self.mlp_node.forward(inputs)
+        return inputs
+
+    def backward(self, output_grad=None, scaler=None):
+        assert (output_grad is not None) and (scaler is None)
+        output_grad = self.mlp_node.backward(output_grad)
+        output_grad = self.attn_node.backward(output_grad)
+        return output_grad
+
+
class OverlapedFUsionScheduleNode:
    def __init__(self, forward_node, backward_node, name=""):
        assert isinstance(forward_node, FusionFp8DecoderLayerNode) and isinstance(
@@ -1086,8 +1113,99 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None):
        return inputs, output_grad, event_to_wait


+class OverlapedDenseFusionScheduleNode:
+    def __init__(self, forward_node, backward_node, name=""):
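+        # One side of the pair must be a fused MoE layer node and the other a dense layer node.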
+        assert isinstance(forward_node, FusionFp8DecoderLayerNode) or isinstance(
+            backward_node, FusionFp8DecoderLayerNode
+        )
+        assert isinstance(forward_node, DenseDecoderLayerNode) or isinstance(
+            backward_node, DenseDecoderLayerNode
+        )
+        self.forward_node = forward_node
+        self.backward_node = backward_node
+        self.name = name
+
+    def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None):
+        # Dense forward + MoE backward
+        if isinstance(self.forward_node, DenseDecoderLayerNode):
+            paddle.base.core.nvprof_nvtx_push("dense_fw_moe_bw")
+
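+            # Launch the MoE combine backward on the communication stream, run the dense
+            # attention forward on the calc stream in the meantime, then sync before the
+            # MoE MLP backward.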
+            paddle.base.core.nvprof_nvtx_push("dense_attn_moe_combine")
+            output_grad = self.backward_node.post_process_backward(output_grad, combine_bw_event_to_wait)
+            output_grad = self.backward_node.combine_backward(
+                output_grad, previous_event=combine_bw_event_to_wait, async_finish=True, allocate_on_comm_stream=True
+            )
+            combine_bw_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id)
+            inputs = self.forward_node.attn_node.forward(inputs)
+            combine_bw_event.calc_stream_wait(self.backward_node.moe_group.id)
+            paddle.base.core.nvprof_nvtx_pop()  # dense_attn_moe_combine
+
+            paddle.base.core.nvprof_nvtx_push("moe_mlp")
+            output_grad = self.backward_node.mlp_backward(output_grad)
+            paddle.base.core.nvprof_nvtx_pop()  # moe_mlp
+
+            paddle.base.core.nvprof_nvtx_push("dense_mlp_moe_dispatch")
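+            # MoE dispatch backward runs asynchronously on the comm stream while the
+            # dense MLP forward executes on the calc stream.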
+            output_grad = self.backward_node.dispatch_backward(
+                output_grad, async_finish=True, allocate_on_comm_stream=True
+            )
+            dispatch_bw_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id)
+            inputs = self.forward_node.mlp_node.forward(inputs)
+            dispatch_bw_event.calc_stream_wait(self.backward_node.moe_group.id)
+            paddle.base.core.nvprof_nvtx_pop()  # dense_mlp_moe_dispatch
+
+            paddle.base.core.nvprof_nvtx_push("moe_attn")
+            output_grad = self.backward_node.attn_backward(output_grad)
+            paddle.base.core.nvprof_nvtx_pop()  # moe_attn
+
+            event_to_wait = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id)
+            paddle.base.core.nvprof_nvtx_pop()  # dense_fw_moe_bw
+
+        # Dense backward + MoE forward
+        else:
+            paddle.base.core.nvprof_nvtx_push("dense_bw_moe_fw")
+
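+            # Mirror case: the MoE layer's forward phases overlap the dense layer's backward.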
+            paddle.base.core.nvprof_nvtx_push("moe_attn")
+            inputs = self.forward_node.attn_forward(inputs)
+            attn_fw_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)
+            paddle.base.core.nvprof_nvtx_pop()  # moe_attn
+
+            paddle.base.core.nvprof_nvtx_push("dense_mlp_moe_dispatch")
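+            # Dense MLP backward on the calc stream overlaps the MoE dispatch forward on the comm stream.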
+            output_grad = self.backward_node.mlp_node.backward(output_grad)
+            inputs = self.forward_node.dispatch_forward(
+                inputs, previous_event=attn_fw_event, async_finish=True, allocate_on_comm_stream=True
+            )
+            dispatch_fw_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id)
+            dispatch_fw_event.calc_stream_wait(self.forward_node.moe_group.id)
+            paddle.base.core.nvprof_nvtx_pop()  # dense_mlp_moe_dispatch
+
+            paddle.base.core.nvprof_nvtx_push("moe_mlp")
+            inputs = self.forward_node.mlp_forward(inputs)
+            paddle.base.core.nvprof_nvtx_pop()  # moe_mlp
+
+            paddle.base.core.nvprof_nvtx_push("dense_attn_moe_combine")
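+            # MoE combine forward on the comm stream overlaps the dense attention backward on the calc stream.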
+            inputs = self.forward_node.combine_forward(
+                inputs, async_finish=True, allocate_on_comm_stream=True
+            )
+            combine_fw_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id)
+            output_grad = self.backward_node.attn_node.backward(output_grad)
+            combine_fw_event.calc_stream_wait(self.forward_node.moe_group.id)
+            paddle.base.core.nvprof_nvtx_pop()  # dense_attn_moe_combine
+
+            paddle.base.core.nvprof_nvtx_push("moe_post")
+            inputs = self.forward_node.post_process_forward(inputs)
+            paddle.base.core.nvprof_nvtx_pop()  # moe_post
+
+            event_to_wait = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)
+            paddle.base.core.nvprof_nvtx_pop()  # dense_bw_moe_fw
+
+        return inputs, output_grad, event_to_wait
+
+
def build_overlapped_nodes(forward_chunk, backward_chunk):
-    overlap_element_class = FusionFp8DecoderLayerNode if DSV3_USE_FP8_GEMM else DecoderLayerNode
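+    # Tuple of node types so the checks below also treat dense decoder-layer nodes as overlap elements.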
+    overlap_element_class = (
+        FusionFp8DecoderLayerNode if DSV3_USE_FP8_GEMM else DecoderLayerNode,
+        DenseDecoderLayerNode
+    )
    forward_decoder_layer_num = 0
    backward_decoder_layer_num = 0
    assert isinstance(forward_chunk, ScheduleChunk) and isinstance(backward_chunk, ScheduleChunk)
@@ -1466,6 +1584,20 @@ def post_process_compute_for_fusion(self, inputs):

        return return_args(hidden_states)

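+    # Dense-layer counterparts of the MoE attention/MLP compute splits; these let a
+    # non-MoE layer be wrapped as a DenseDecoderLayerNode with separate sub-nodes.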
+    def attn_compute_dense(self, args):
+        hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args)
+        assert attention_mask is None
+        assert attn_mask_startend_row_indices is None
+        assert position_ids is None
+        hidden_states, _ = self.self_attn_compute(hidden_states)
+        return hidden_states
+
+    def mlp_compute_dense(self, hidden_states):
+        residual = hidden_states
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+        return hidden_states
+
    def build_schedule_node(self):
        if isinstance(self.mlp, DeepseekV2MoE):
            self.mlp.update_flex_token()
@@ -1515,7 +1647,14 @@ def build_schedule_node(self):
                mlp_layer=self.mlp,
                name="DecoderLayerNode",
            )
-        return ScheduleNode(self.forward, name="DeepseekV2DecoderLayerPipe")
+
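+        # Dense (non-MoE) layer: build separate attention and MLP schedule nodes so the
+        # layer can participate in dense/MoE overlap.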
+        attn_node = ScheduleNode(self.attn_compute_dense, name="attn_node")
+        mlp_node = ScheduleNode(self.mlp_compute_dense, name="mlp_node")
+        return DenseDecoderLayerNode(
+            attn_node=attn_node,
+            mlp_node=mlp_node,
+            name="DenseDecoderLayerNode",
+        )


class DeepseekV2MTPLayerPipe(DeepseekV2MTPLayer):