Skip to content

Commit 0e9051a

Browse files
authored
Add sequence and output dimension subbatch (#10947)
* add subbatch * remove backward_dw * deal 0 size * fix subbatch bug * fix cuda 700 * polish code * fix nan
1 parent 6a3fb15 commit 0e9051a

File tree

6 files changed

+1064
-214
lines changed

6 files changed

+1064
-214
lines changed

paddlenlp/transformers/deepseek_v2/configuration.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,9 @@ def __init__(
188188
adaptive_remained_O1_recompute_ratio=0,
189189
offline_quant_expert_weight=True,
190190
clear_origin_weight_when_offline_quant=True,
191+
mlp_bwd_subbatch_rows=0,
192+
mlp_fwd_subbatch_rows=0,
193+
output_subbatch_rows=0,
191194
**kwargs,
192195
):
193196
self.vocab_size = vocab_size
@@ -245,6 +248,9 @@ def __init__(
245248
self.adaptive_remained_O1_recompute_ratio = adaptive_remained_O1_recompute_ratio
246249
self.offline_quant_expert_weight = offline_quant_expert_weight
247250
self.clear_origin_weight_when_offline_quant = clear_origin_weight_when_offline_quant
251+
self.mlp_bwd_subbatch_rows = mlp_bwd_subbatch_rows
252+
self.mlp_fwd_subbatch_rows = mlp_fwd_subbatch_rows
253+
self.output_subbatch_rows = output_subbatch_rows
248254

249255
super().__init__(
250256
pad_token_id=pad_token_id,

paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1641,15 +1641,18 @@ def build_schedule_node(self):
16411641
if DSV3_USE_FP8_GEMM:
16421642
attn_and_gate_node = ScheduleNode(self.attn_compute_for_fusion, name="attn_and_gate_node")
16431643

1644+
# recompute_fwd_gate_up_ may be 1, 0 or -1, 1 means recompute, 0 means disable recompute, -1 means adaptive recompute.
16441645
recompute_fwd_gate_up_ = 1 if self.layer_idx in self.config.recompute_fwd_gate_up_list else 0
1645-
recompute_fwd_gate_up_ = (
1646-
-1 if self.config.adaptive_remained_O1_recompute_ratio else recompute_fwd_gate_up_
1647-
)
1646+
if recompute_fwd_gate_up_ == 0 and self.config.adaptive_remained_O1_recompute_ratio:
1647+
recompute_fwd_gate_up_ = -1
16481648

16491649
fp8_fusion_moe_node = FusionMoeNode(
16501650
self.mlp,
16511651
recompute_fwd_gate_up=recompute_fwd_gate_up_,
16521652
is_split_group_gemm=self.config.is_split_group_gemm,
1653+
mlp_fwd_subbatch_rows=self.config.mlp_fwd_subbatch_rows,
1654+
mlp_bwd_subbatch_rows=self.config.mlp_bwd_subbatch_rows,
1655+
output_subbatch_rows=self.config.output_subbatch_rows,
16531656
name="fp8_fusion_moe_node",
16541657
)
16551658
post_process_node = PostProcessNode(

0 commit comments

Comments
 (0)