
Commit 5b4855d

Add stepped O1 recompute (#11010)
* add_stepped_rc
* polish code
* remove stepevent
* fix bug
Parent commit: 79c421f

File tree: 2 files changed (+8 lines, -0 lines)


paddlenlp/transformers/deepseek_v2/configuration.py

Lines changed: 2 additions & 0 deletions
@@ -182,6 +182,7 @@ def __init__(
         use_dualpipev=False,
         send_mtp_embed=False,
         using_post_norm_recompute=False,
+        stepped_recompute_fwd_gate_up=False,
         recompute_fwd_gate_up=0,
         recompute_fa3=0,
         is_split_group_gemm=False,
@@ -245,6 +246,7 @@ def __init__(
         self.using_post_norm_recompute = using_post_norm_recompute
         self.recompute_fwd_gate_up = recompute_fwd_gate_up
         self.recompute_fa3 = recompute_fa3
+        self.stepped_recompute_fwd_gate_up = stepped_recompute_fwd_gate_up
         self.is_split_group_gemm = is_split_group_gemm
         self.fakse_gate_restrict_balance = fakse_gate_restrict_balance
         self.adaptive_remained_O1_recompute_ratio = adaptive_remained_O1_recompute_ratio
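
For reference, the new option is an ordinary constructor flag that defaults to False, sitting next to the existing recompute switches. A minimal sketch of enabling it, assuming the class defined in configuration.py is DeepseekV2Config (the class name is not shown in this diff; the parameter names are):

# A minimal sketch, not taken from this commit. DeepseekV2Config is an
# assumed class name; the two flags below are the ones shown in the diff.
from paddlenlp.transformers.deepseek_v2.configuration import DeepseekV2Config

config = DeepseekV2Config(
    using_post_norm_recompute=True,      # existing flag, shown as context above
    stepped_recompute_fwd_gate_up=True,  # new flag added by this commit
)
print(config.stepped_recompute_fwd_gate_up)  # True
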

paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 6 additions & 0 deletions
@@ -33,6 +33,7 @@
     from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import EventStore
 except ImportError:
     EventStore = None
+
 from paddle.distributed.fleet.recompute.recompute import recompute
 from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp

@@ -598,6 +599,7 @@ def __init__(
         mlp_layer,
         send_mtp_embed,
         using_post_norm_recompute=False,
+        stepped_recompute_fwd_gate_up=False,
         name="",
     ):
         self.attn_and_gate_node = attn_and_gate_node
@@ -606,6 +608,7 @@ def __init__(
         self.send_mtp_embed = send_mtp_embed

         self.using_post_norm_recompute = using_post_norm_recompute
+        self.stepped_recompute_fwd_gate_up = stepped_recompute_fwd_gate_up
         self.name = name

         self.moe_group = mlp_layer.moe_group
@@ -1058,6 +1061,8 @@ def backward_for_fusion(self, output_grad, combine_bw_event_to_wait=None, pp_str
         return output_grad, event_to_wait

     def forward(self, inputs):
+        if self.stepped_recompute_fwd_gate_up:
+            self.fp8_fusion_moe_node.mlp_node.set_recompute_fwd_gate_up(True)
         inputs = self.attn_forward(inputs)
         inputs = self.dispatch_forward(inputs)
         inputs = self.mlp_forward(inputs)
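
The forward hunk is the heart of the change: when the flag is set, the decoder node switches the MoE MLP node into gate/up recompute mode at the start of every step, before attention, dispatch, and MLP run. A minimal sketch of that toggle pattern follows; MLPNode and DecoderNode are hypothetical stand-ins, and only the flag name and the set_recompute_fwd_gate_up call come from the diff:

# Hypothetical stand-ins illustrating the per-step toggle above; these are
# not the real PaddleNLP classes.
class MLPNode:
    def __init__(self):
        self.recompute_fwd_gate_up = False

    def set_recompute_fwd_gate_up(self, flag):
        # When enabled, gate/up projections would be recomputed in backward
        # rather than keeping their forward activations alive.
        self.recompute_fwd_gate_up = flag


class DecoderNode:
    def __init__(self, mlp_node, stepped_recompute_fwd_gate_up=False):
        self.mlp_node = mlp_node
        self.stepped_recompute_fwd_gate_up = stepped_recompute_fwd_gate_up

    def forward(self, inputs):
        if self.stepped_recompute_fwd_gate_up:
            # Mirror of the added lines: force recompute for this step.
            self.mlp_node.set_recompute_fwd_gate_up(True)
        return inputs  # attn/dispatch/mlp phases omitted in this sketch


node = DecoderNode(MLPNode(), stepped_recompute_fwd_gate_up=True)
node.forward(None)
assert node.mlp_node.recompute_fwd_gate_up  # toggled on for the step
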
@@ -1820,6 +1825,7 @@ def build_schedule_node(self):
                 mlp_layer=self.mlp,
                 send_mtp_embed=self.config.send_mtp_embed,
                 using_post_norm_recompute=self.config.using_post_norm_recompute,
+                stepped_recompute_fwd_gate_up=self.config.stepped_recompute_fwd_gate_up,
                 name="FusionFp8DecoderLayerNode",
             )
         else:
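
Note the wiring: build_schedule_node simply forwards self.config.stepped_recompute_fwd_gate_up into the FusionFp8DecoderLayerNode constructor, so together with the configuration.py change, enabling the behavior is a single config flag; the per-step toggle in forward() does the rest.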
