@@ -540,6 +540,20 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
         return inputs, output_grad, None
 
 
+class DecoderBackwardScheduleChunk:
+    def __init__(self, nodes):
+        self.nodes = nodes
+
+    def backward(self, output_grad, combine_bw_event_to_wait=None, pp_stream=None):
+        event_to_wait = combine_bw_event_to_wait
+        for i, n in enumerate(self.nodes):
+            pp_stream_t = pp_stream if i + 1 == len(self.nodes) else None
+            output_grad, event_to_wait = n.backward_for_fusion(
+                output_grad, combine_bw_event_to_wait=event_to_wait, pp_stream=pp_stream_t
+            )
+        return output_grad
+
+
 class OverlapedScheduleNode:
     def __init__(self, forward_node, backward_node, name=""):
         assert isinstance(forward_node, DecoderLayerNode) and isinstance(backward_node, DecoderLayerNode)
@@ -972,6 +986,77 @@ def attn_backward(self, output_grad):
         output_grad = self.attn_and_gate_node.backward(output_grad)
         return output_grad
 
+    def backward_for_fusion(self, output_grad, combine_bw_event_to_wait=None, pp_stream=None):
+        paddle.base.core.nvprof_nvtx_push("backward")
+        if combine_bw_event_to_wait is None:
+            combine_bw_event_to_wait = deep_ep.get_event_from_calc_stream(self.moe_group.id)
+
+        paddle.base.core.nvprof_nvtx_push("post_process_backward")
+        output_grad = self.post_process_backward(output_grad, combine_bw_event_to_wait)
+        paddle.base.core.nvprof_nvtx_pop()
+
+        paddle.base.core.nvprof_nvtx_push("combine_backward")
+        output_grad = self.combine_backward(
+            output_grad, previous_event=combine_bw_event_to_wait, async_finish=True, allocate_on_comm_stream=True
+        )
+        combine_backward_event = deep_ep.get_event_from_comm_stream(self.moe_group.id)
+        combine_backward_event.calc_stream_wait(self.moe_group.id)
+        paddle.base.core.nvprof_nvtx_pop()
+
+        if WeightGradStore.enabled:
+            paddle.base.core.nvprof_nvtx_push("mlp_backward")
+            output_grad = self.mlp_backward(output_grad)
+            paddle.base.core.nvprof_nvtx_pop()
+
+            paddle.base.core.nvprof_nvtx_push("dispatch_backward")
+            output_grad = self.dispatch_backward(output_grad)
+            paddle.base.core.nvprof_nvtx_pop()
+
+            paddle.base.core.nvprof_nvtx_push("attn_backward")
+            output_grad = self.attn_backward(output_grad)
+            paddle.base.core.nvprof_nvtx_pop()
+
+            event_to_wait = None
+
+        else:
+            paddle.base.core.nvprof_nvtx_push("mlp_backward_dx")
+            assert WeightGradStore.funcs_queue.empty()
+            WeightGradStore.enabled = True
+            output_grad = self.mlp_backward(output_grad)
+            WeightGradStore.enabled = False
+            WeightGradStore.flush()
+            output_grad_event = deep_ep.get_event_from_calc_stream(self.moe_group.id)
+            paddle.base.core.nvprof_nvtx_pop()
+
+            paddle.base.core.nvprof_nvtx_push("dispatch_backward")
+            output_grad = self.dispatch_backward(
+                output_grad, async_finish=True, previous_event=output_grad_event, allocate_on_comm_stream=True
+            )
+            dispatch_backward_event = deep_ep.get_event_from_comm_stream(self.moe_group.id)
+            paddle.base.core.nvprof_nvtx_pop()
+
+            paddle.base.core.nvprof_nvtx_push("mlp_backward_dw")
+            WeightGradStore.pop()
+            assert WeightGradStore.funcs_queue.empty()
+            paddle.base.core.nvprof_nvtx_pop()
+
+            paddle.base.core.nvprof_nvtx_push("attn_backward_dx")
+            dispatch_backward_event.calc_stream_wait(self.moe_group.id)
+            WeightGradStore.enabled = True
+            output_grad = self.attn_backward(output_grad)
+            WeightGradStore.enabled = False
+            WeightGradStore.flush()
+            event_to_wait = deep_ep.get_event_from_calc_stream(self.moe_group.id)
+            paddle.base.core.nvprof_nvtx_pop()
+
+            paddle.base.core.nvprof_nvtx_push("attn_backward_dw")
+            WeightGradStore.pop()
+            assert WeightGradStore.funcs_queue.empty()
+            paddle.base.core.nvprof_nvtx_pop()
+
+        paddle.base.core.nvprof_nvtx_pop()
+        return output_grad, event_to_wait
+
     def forward(self, inputs):
         inputs = self.attn_forward(inputs)
         inputs = self.dispatch_forward(inputs)
@@ -1310,6 +1395,11 @@ def build_overlapped_nodes(forward_chunk, backward_chunk):
     backward_pre_node = ScheduleChunk(list(reversed(backward_pre_overlap_layers)))
     backward_post_node = ScheduleChunk(list(reversed(backward_post_overlap_layers)))
 
+    if not forward_chunk.nodes and all(
+        isinstance(n, FusionFp8DecoderLayerNode) for n in backward_chunk.nodes
+    ):
+        backward_post_node = DecoderBackwardScheduleChunk(backward_post_overlap_layers)
+
     overlap_node = OverlapedScheduleChunk(forward_overlap_layers, backward_overlap_layers, use_fuion=DSV3_USE_FP8_GEMM)
     return forward_pre_node, backward_pre_node, overlap_node, forward_post_node, backward_post_node
 
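
For context on the first hunk: the new DecoderBackwardScheduleChunk simply chains backward_for_fusion across its nodes, threading the event returned by each layer into the next layer's combine_bw_event_to_wait and handing pp_stream only to the final layer. Below is a minimal, self-contained sketch of that event-threading pattern; the _Event and _Node classes are hypothetical stand-ins for deep_ep stream events and FusionFp8DecoderLayerNode, and are not part of this PR.

# usage_sketch.py -- illustrative only, assumes stand-in classes instead of paddle/deep_ep
class _Event:
    """Stand-in for a deep_ep stream event."""
    def __init__(self, tag):
        self.tag = tag

class _Node:
    """Stand-in for FusionFp8DecoderLayerNode; records call order only."""
    def __init__(self, name):
        self.name = name

    def backward_for_fusion(self, output_grad, combine_bw_event_to_wait=None, pp_stream=None):
        # The real node waits on the incoming event, runs combine/dispatch backward
        # on the comm stream, and returns a calc-stream event for the next layer.
        waited = combine_bw_event_to_wait.tag if combine_bw_event_to_wait else None
        print(f"{self.name}: waited_on={waited}, pp_stream={pp_stream}")
        return output_grad + 1, _Event(f"after_{self.name}")

class DecoderBackwardScheduleChunk:
    def __init__(self, nodes):
        self.nodes = nodes

    def backward(self, output_grad, combine_bw_event_to_wait=None, pp_stream=None):
        event_to_wait = combine_bw_event_to_wait
        for i, n in enumerate(self.nodes):
            # Only the last node may launch work on the pipeline-parallel stream.
            pp_stream_t = pp_stream if i + 1 == len(self.nodes) else None
            output_grad, event_to_wait = n.backward_for_fusion(
                output_grad, combine_bw_event_to_wait=event_to_wait, pp_stream=pp_stream_t
            )
        return output_grad

if __name__ == "__main__":
    chunk = DecoderBackwardScheduleChunk([_Node("layer2"), _Node("layer1"), _Node("layer0")])
    print("final grad:", chunk.backward(output_grad=0, pp_stream="pp_stream"))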