
Commit 0e4be4d

Add adaptive recompute of O1 (#10891)
* add adaptive recompute of O1
* add adaptive recompute ratio
1 parent c7b0059 commit 0e4be4d

File tree

3 files changed: +40 −1 lines changed

paddlenlp/transformers/deepseek_v2/configuration.py

Lines changed: 2 additions & 0 deletions
@@ -185,6 +185,7 @@ def __init__(
         recompute_fwd_gate_up=0,
         is_split_group_gemm=False,
         fakse_gate_restrict_balance=False,
+        adaptive_remained_O1_recompute_ratio=0,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -239,6 +240,7 @@ def __init__(
         self.recompute_fwd_gate_up = recompute_fwd_gate_up
         self.is_split_group_gemm = is_split_group_gemm
         self.fakse_gate_restrict_balance = fakse_gate_restrict_balance
+        self.adaptive_remained_O1_recompute_ratio = adaptive_remained_O1_recompute_ratio

         super().__init__(
             pad_token_id=pad_token_id,
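
For orientation, a minimal usage sketch of the new knob. It assumes the class defined in configuration.py is DeepseekV2Config (inferred from the file path; the class name does not appear in the diff). A nonzero ratio turns on adaptive O1 recompute; the default of 0 preserves the previous static behavior.

    # Minimal sketch, assuming the config class is DeepseekV2Config
    # (inferred from the file path) and defaults for all other arguments.
    from paddlenlp.transformers.deepseek_v2.configuration import DeepseekV2Config

    config = DeepseekV2Config(
        recompute_fwd_gate_up=0,                   # static per-layer flag, as before
        is_split_group_gemm=False,
        adaptive_remained_O1_recompute_ratio=1.2,  # nonzero enables adaptive mode; 0 disables it
    )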

paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 7 additions & 1 deletion
@@ -1189,9 +1189,15 @@ def build_schedule_node(self):
         if self.mlp.using_flex_token:
             if DSV3_USE_FP8_GEMM:
                 attn_and_gate_node = ScheduleNode(self.attn_compute_for_fusion, name="attn_and_gate_node")
+
+                recompute_fwd_gate_up_ = 1 if self.layer_idx in self.config.recompute_fwd_gate_up_list else 0
+                recompute_fwd_gate_up_ = (
+                    -1 if self.config.adaptive_remained_O1_recompute_ratio else recompute_fwd_gate_up_
+                )
+
                 fp8_fusion_moe_node = FusionMoeNode(
                     self.mlp,
-                    recompute_fwd_gate_up=self.config.recompute_fwd_gate_up,
+                    recompute_fwd_gate_up=recompute_fwd_gate_up_,
                     is_split_group_gemm=self.config.is_split_group_gemm,
                     name="fp8_fusion_moe_node",
                 )
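
Read together, the new lines resolve recompute_fwd_gate_up into a tri-state value: 1 when the layer index is listed in config.recompute_fwd_gate_up_list, 0 otherwise, and -1 whenever a nonzero adaptive_remained_O1_recompute_ratio is configured, which overrides the static choice and defers the decision to runtime. A standalone restatement of that resolution (illustrative only, not the PaddleNLP source):

    # Sketch of the tri-state flag computed in build_schedule_node.
    # The two config fields mirror the diff; the helper itself is ours.
    def resolve_recompute_fwd_gate_up(layer_idx: int, config) -> int:
        flag = 1 if layer_idx in config.recompute_fwd_gate_up_list else 0
        if config.adaptive_remained_O1_recompute_ratio:
            # -1 tells FusionMoeNode to decide per forward pass from the
            # observed token imbalance instead of a fixed setting.
            flag = -1
        return flag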

paddlenlp/transformers/moe_layer.py

Lines changed: 31 additions & 0 deletions
@@ -704,11 +704,20 @@ def __init__(self, custom_map, max_topk, recompute_fwd_gate_up=False, is_split_g
             recompute_fwd_gate_up=recompute_fwd_gate_up,
             is_split_group_gemm=is_split_group_gemm,
         )
+
+        self.seq_length = custom_map.config.seq_length
+        self.num_experts_per_tok = custom_map.config.num_experts_per_tok
+        self.adaptive_remained_O1_recompute_ratio = custom_map.config.adaptive_remained_O1_recompute_ratio
+
+        self.recompute_fwd_gate_up = recompute_fwd_gate_up
         self.dispatched_indices = None
         self.dispatched_probs = None
         self.tokens_per_expert = None
         self.router_topk = max_topk

+    def set_recompute_fwd_gate_up(self, recompute_fwd_gate_up):
+        self.experts_group_gemm_node.recompute_fwd_gate_up = recompute_fwd_gate_up
+
     def reset_statue(self, with_dw=False):
         """
         Reset all state variables.
@@ -771,6 +780,18 @@ def forward(self, hs_2d_dispatched, dispatched_indices, dispatched_probs):
         dispatched_indices._record_stream()
         dispatched_probs._record_stream()

+        # If adaptive O1 recompute is enabled, determine whether to enable recompute O1 based on the degree of imbalance
+        if self.recompute_fwd_gate_up == -1:
+            if (
+                unzipped_tokens.shape[0]
+                > self.seq_length * self.num_experts_per_tok * self.adaptive_remained_O1_recompute_ratio
+            ):
+                # logger.debug(f"recompute_fwd_gate_up changed to True because the {unzipped_tokens.shape[0]} received tokens are greater than {self.seq_length * self.num_experts_per_tok * self.adaptive_remained_O1_recompute_ratio}.")
+                self.set_recompute_fwd_gate_up(True)
+            else:
+                # logger.debug(f"recompute_fwd_gate_up changed to False because the {unzipped_tokens.shape[0]} received tokens are less than {self.seq_length * self.num_experts_per_tok * self.adaptive_remained_O1_recompute_ratio}.")
+                self.set_recompute_fwd_gate_up(False)
+
         # 2 experts
         padding_token_per_experts = [(x + 127) // 128 * 128 for x in self.tokens_per_expert]
         expert_out = self.experts_group_gemm_node.forward(
@@ -792,6 +813,16 @@ def forward(self, hs_2d_dispatched, dispatched_indices, dispatched_probs):
         dispatched_indices._record_stream()
         dispatched_probs._record_stream()

+        # If adaptive O1 recompute is enabled, determine whether to enable recompute O1 based on the degree of imbalance
+        if self.recompute_fwd_gate_up == -1:
+            if (
+                unzipped_tokens.shape[0]
+                > self.seq_length * self.num_experts_per_tok * self.adaptive_remained_O1_recompute_ratio
+            ):
+                self.set_recompute_fwd_gate_up(True)
+            else:
+                self.set_recompute_fwd_gate_up(False)
+
         # 2 experts
         padding_token_per_experts = [(x + 127) // 128 * 128 for x in self.tokens_per_expert]
         expert_out = self.experts_group_gemm_node.forward(
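
Both forward paths run the same test, so the adaptive decision reduces to one predicate: under balanced routing a rank receives roughly seq_length * num_experts_per_tok dispatched tokens, and recompute is enabled only when the observed count exceeds that baseline scaled by the configured ratio. A hedged restatement (the helper name is ours; the arithmetic is taken from the diff):

    # Illustrative restatement of the adaptive check in FusionMoeNode.forward.
    def should_recompute_o1(
        num_received_tokens: int, seq_length: int, num_experts_per_tok: int, ratio: float
    ) -> bool:
        # seq_length * num_experts_per_tok approximates the per-rank token count
        # under perfectly balanced expert routing; ratio sets how much imbalance
        # is tolerated before trading recompute time for activation memory.
        return num_received_tokens > seq_length * num_experts_per_tok * ratio

A ratio below 1 makes recompute kick in even under mild imbalance; larger values reserve it for heavily overloaded ranks.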

0 commit comments
