    from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import EventStore
except ImportError:
    EventStore = None
+
from paddle.distributed.fleet.recompute.recompute import recompute
from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp

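For context, the EventStore import above is wrapped in try/except so the module still imports on Paddle builds that do not ship the zero-bubble pipeline utilities. A minimal sketch of how such a None fallback is typically gated at its use sites (the helper function below is illustrative and not part of this diff):

```python
try:
    from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import EventStore
except ImportError:  # older Paddle builds without zero-bubble pipeline utilities
    EventStore = None


def zero_bubble_events_available():
    # Illustrative feature gate: callers check the fallback value instead of
    # assuming the optional import succeeded.
    return EventStore is not None
```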
@@ -598,6 +599,7 @@ def __init__(
        mlp_layer,
        send_mtp_embed,
        using_post_norm_recompute=False,
+       stepped_recompute_fwd_gate_up=False,
        name="",
    ):
        self.attn_and_gate_node = attn_and_gate_node
@@ -606,6 +608,7 @@ def __init__(
        self.send_mtp_embed = send_mtp_embed

        self.using_post_norm_recompute = using_post_norm_recompute
+       self.stepped_recompute_fwd_gate_up = stepped_recompute_fwd_gate_up
        self.name = name

        self.moe_group = mlp_layer.moe_group
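Because the new constructor argument defaults to False, call sites that do not pass it keep their previous behavior. A reduced sketch under that assumption (the class below is a stand-in, not the real schedule node):

```python
class DecoderNodeCtorSketch:
    # Reduced constructor: only the attributes touched by these hunks.
    def __init__(self, using_post_norm_recompute=False,
                 stepped_recompute_fwd_gate_up=False, name=""):
        self.using_post_norm_recompute = using_post_norm_recompute
        # Stored here, consumed later at the top of forward().
        self.stepped_recompute_fwd_gate_up = stepped_recompute_fwd_gate_up
        self.name = name


legacy_node = DecoderNodeCtorSketch(name="decoder")  # flag stays False
stepped_node = DecoderNodeCtorSketch(stepped_recompute_fwd_gate_up=True, name="decoder")
```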
@@ -1058,6 +1061,8 @@ def backward_for_fusion(self, output_grad, combine_bw_event_to_wait=None, pp_str
        return output_grad, event_to_wait

    def forward(self, inputs):
+       if self.stepped_recompute_fwd_gate_up:
+           self.fp8_fusion_moe_node.mlp_node.set_recompute_fwd_gate_up(True)
        inputs = self.attn_forward(inputs)
        inputs = self.dispatch_forward(inputs)
        inputs = self.mlp_forward(inputs)
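The behavioral change is confined to the top of forward(): when the flag is set, the node tells the fused FP8 MoE MLP node to recompute its gate/up projections during backward rather than keeping those activations. A runnable sketch of that toggle (the classes below are stand-ins; the real set_recompute_fwd_gate_up lives on the MLP schedule node and its internals are an assumption here):

```python
from types import SimpleNamespace


class MoEMLPNodeSketch:
    """Illustrative stand-in for the fused FP8 MoE MLP schedule node."""

    def __init__(self):
        self.recompute_fwd_gate_up = False

    def set_recompute_fwd_gate_up(self, flag):
        # Assumption: when enabled, gate/up projection activations are dropped
        # after forward and recomputed during backward, trading compute for memory.
        self.recompute_fwd_gate_up = flag


def forward_sketch(node, inputs):
    # Mirrors the hunk above: the flag is consumed once, before the
    # attention / dispatch / MLP stages run.
    if node.stepped_recompute_fwd_gate_up:
        node.fp8_fusion_moe_node.mlp_node.set_recompute_fwd_gate_up(True)
    return inputs


# Just enough structure to exercise the toggle.
mlp_node = MoEMLPNodeSketch()
node = SimpleNamespace(
    stepped_recompute_fwd_gate_up=True,
    fp8_fusion_moe_node=SimpleNamespace(mlp_node=mlp_node),
)
forward_sketch(node, inputs=None)
assert mlp_node.recompute_fwd_gate_up is True
```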
@@ -1820,6 +1825,7 @@ def build_schedule_node(self):
            mlp_layer=self.mlp,
            send_mtp_embed=self.config.send_mtp_embed,
            using_post_norm_recompute=self.config.using_post_norm_recompute,
+           stepped_recompute_fwd_gate_up=self.config.stepped_recompute_fwd_gate_up,
            name="FusionFp8DecoderLayerNode",
        )
    else:
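End to end, the feature is driven by a single config field that build_schedule_node() forwards into the decoder-layer node. A hedged sketch of that wiring (the dataclass below only models the fields visible in this hunk; the real config class has many more):

```python
from dataclasses import dataclass


@dataclass
class DecoderConfigSketch:
    # Only the fields visible in the hunk above.
    send_mtp_embed: bool = False
    using_post_norm_recompute: bool = False
    stepped_recompute_fwd_gate_up: bool = False  # new switch added by this PR


# Enabling stepped recompute of the gate/up projections is then a
# config-level change; the flag flows into the node constructor unchanged.
cfg = DecoderConfigSketch(stepped_recompute_fwd_gate_up=True)
```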