
Commit 7ce548e

Fa3 recompute (#10994)
* fix
* add fa3 rc

Co-authored-by: zhangbo9674 <[email protected]>
1 parent 1ab02cd commit 7ce548e

File tree

3 files changed, +115 -24 lines:

  paddlenlp/transformers/deepseek_v2/configuration.py
  paddlenlp/transformers/deepseek_v2/modeling.py
  paddlenlp/transformers/deepseek_v2/modeling_pp.py

paddlenlp/transformers/deepseek_v2/configuration.py

Lines changed: 2 additions & 0 deletions
@@ -183,6 +183,7 @@ def __init__(
         send_mtp_embed=False,
         using_post_norm_recompute=False,
         recompute_fwd_gate_up=0,
+        recompute_fa3=0,
         is_split_group_gemm=False,
         fakse_gate_restrict_balance=False,
         adaptive_remained_O1_recompute_ratio=0,
@@ -243,6 +244,7 @@ def __init__(
         self.send_mtp_embed = send_mtp_embed
         self.using_post_norm_recompute = using_post_norm_recompute
         self.recompute_fwd_gate_up = recompute_fwd_gate_up
+        self.recompute_fa3 = recompute_fa3
         self.is_split_group_gemm = is_split_group_gemm
         self.fakse_gate_restrict_balance = fakse_gate_restrict_balance
         self.adaptive_remained_O1_recompute_ratio = adaptive_remained_O1_recompute_ratio
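
The two added lines above only thread a new `recompute_fa3` option through `DeepseekV2Config`. A minimal usage sketch (illustrative only: it assumes `DeepseekV2Config` is importable from `paddlenlp.transformers` like the other model configs, and the comments reflect how the flag is interpreted by the pipeline code later in this commit):

    from paddlenlp.transformers import DeepseekV2Config

    config = DeepseekV2Config(
        recompute_fwd_gate_up=1,  # existing recompute knob, shown for comparison
        recompute_fa3=1,          # new knob: recompute FA3 attention for 1 layer per pipeline segment (True = whole segment)
    )
    print(config.recompute_fa3)   # -> 1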

paddlenlp/transformers/deepseek_v2/modeling.py

Lines changed: 87 additions & 24 deletions
@@ -1334,6 +1334,7 @@ def forward(
         eps,
         kv_lora_rank,
         softmax_scale,
+        recompute_fa3=False,
     ):

         bsz = q_init.shape[0]
@@ -1439,26 +1440,50 @@ def forward(
                 softmax_scale,
             )
         elif FA_VERSION == 3:
-            ctx.save_for_backward(
-                q_init,
-                kv_init,
-                attn_out,
-                softmax_lse,
-                q_ln_weight,
-                kv_ln_weight,
-                q_up_weight,
-                kv_up_weight,
-                rotary_emb,
-                num_heads,
-                q_head_dim,
-                qk_nope_head_dim,
-                v_head_dim,
-                qk_rope_head_dim,
-                position_ids,
-                eps,
-                kv_lora_rank,
-                softmax_scale,
-            )
+            if recompute_fa3:
+                ctx.save_for_backward(
+                    q_init,
+                    kv_init,
+                    None,
+                    None,
+                    q_ln_weight,
+                    kv_ln_weight,
+                    q_up_weight,
+                    kv_up_weight,
+                    rotary_emb,
+                    num_heads,
+                    q_head_dim,
+                    qk_nope_head_dim,
+                    v_head_dim,
+                    qk_rope_head_dim,
+                    position_ids,
+                    eps,
+                    kv_lora_rank,
+                    softmax_scale,
+                    recompute_fa3,
+                )
+            else:
+                ctx.save_for_backward(
+                    q_init,
+                    kv_init,
+                    attn_out,
+                    softmax_lse,
+                    q_ln_weight,
+                    kv_ln_weight,
+                    q_up_weight,
+                    kv_up_weight,
+                    rotary_emb,
+                    num_heads,
+                    q_head_dim,
+                    qk_nope_head_dim,
+                    v_head_dim,
+                    qk_rope_head_dim,
+                    position_ids,
+                    eps,
+                    kv_lora_rank,
+                    softmax_scale,
+                    recompute_fa3,
+                )
         else:
             assert False, f"invalid {FA_VERSION=}"
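
Note on the hunk above: when `recompute_fa3` is set, forward stores `None` in place of `attn_out` and `softmax_lse`, so those activations are not kept alive between forward and backward; the backward pass (two hunks below) regenerates them. A minimal, generic sketch of this save-or-recompute pattern with `paddle.autograd.PyLayer`, using a toy tanh layer for illustration rather than the FA3 kernel; every name in it is made up:

    import paddle
    from paddle.autograd import PyLayer

    class SaveOrRecomputeTanh(PyLayer):
        # Either cache the activation for backward, or keep only the input
        # and rebuild the activation when backward actually needs it.
        @staticmethod
        def forward(ctx, x, recompute=False):
            y = paddle.tanh(x)
            ctx.recompute = recompute
            if recompute:
                ctx.save_for_backward(x)      # drop the activation, lower peak memory
            else:
                ctx.save_for_backward(x, y)   # usual path: keep the activation
            return y

        @staticmethod
        def backward(ctx, dy):
            if ctx.recompute:
                (x,) = ctx.saved_tensor()
                y = paddle.tanh(x)            # recompute right before use
            else:
                x, y = ctx.saved_tensor()
            return dy * (1.0 - y * y)         # d/dx tanh(x) = 1 - tanh(x)**2

    x = paddle.randn([4, 8])
    x.stop_gradient = False
    out = SaveOrRecomputeTanh.apply(x, recompute=True)
    out.sum().backward()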

@@ -1508,10 +1533,17 @@ def backward(ctx, dout):
                 eps,
                 kv_lora_rank,
                 softmax_scale,
+                recompute_fa3,
             ) = ctx.saved_tensor()
         else:
             assert False, f"invalid {FA_VERSION=}"

+        if FA_VERSION == 2:
+            assert not recompute_fa3
+            assert attn_out is not None and softmax_lse is not None
+        if FA_VERSION == 3 and not recompute_fa3:
+            assert attn_out is not None and softmax_lse is not None
+
         q_ln_t, q_ln_invar = fused_ln.fused_rms_norm(q_init, q_ln_weight, eps)

         q_ln_fp8, q_ln_scale, q_ln_trans_fp8, q_ln_trans_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
@@ -1591,6 +1623,27 @@ def backward(ctx, dout):
             v_grad = v_grad[..., :v_head_dim]
             q_grad = q_grad * softmax_scale
         elif FA_VERSION == 3:
+            # recompute fa3
+            if recompute_fa3:
+                logger.info("Enable fa3 recomputation")
+                attn_out, softmax_lse = _C_ops.flash_attn_v3(
+                    query_states,
+                    key_states,
+                    value_states,
+                    None,  # q_v_
+                    None,  # q_descale_
+                    None,  # k_descale_
+                    None,  # v_descale_
+                    softmax_scale,
+                    True,
+                    -1,  # window_size_left
+                    -1,  # window_size_right
+                    0.0,  # softcap
+                    1,  # num_splits
+                    False,  # manual_set_pack_gqa
+                    False,  # pack_gqa_
+                    0,  # sm_margin
+                )
             with paddle.no_grad():
                 q_grad, k_grad, v_grad = _C_ops.flash_attn_v3_grad(
                     query_states,
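
The recompute path above trades one extra FA3 forward kernel launch per recomputed layer for not holding `attn_out` and `softmax_lse` between forward and backward. A back-of-envelope estimate of what that saves per layer; the shapes and the fp32 log-sum-exp layout below are assumptions for illustration, not values taken from this diff:

    # Rough per-layer memory no longer cached between forward and backward
    # when attn_out / softmax_lse are recomputed. All sizes are hypothetical.
    bsz, seq_len, num_heads, v_head_dim = 1, 4096, 128, 128
    bytes_bf16, bytes_fp32 = 2, 4

    attn_out_bytes = bsz * seq_len * num_heads * v_head_dim * bytes_bf16  # bf16 activations
    softmax_lse_bytes = bsz * num_heads * seq_len * bytes_fp32            # fp32 log-sum-exp

    print(f"attn_out    ~ {attn_out_bytes / 2**20:.0f} MiB")    # ~128 MiB
    print(f"softmax_lse ~ {softmax_lse_bytes / 2**20:.0f} MiB") # ~2 MiB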
@@ -1728,6 +1781,7 @@ def __init__(
         eps,
         kv_lora_rank,
         softmax_scale,
+        recompute_fa3=False,
     ) -> None:
         super().__init__()
         self._dtype = self._helper.get_default_dtype()
@@ -1764,6 +1818,7 @@ def __init__(
             self.eps,
             self.kv_lora_rank,
             self.softmax_scale,
+            self.recompute_fa3,
         ) = (
             rotary_emb,
             num_heads,
@@ -1774,6 +1829,7 @@ def __init__(
             eps,
             kv_lora_rank,
             softmax_scale,
+            recompute_fa3,
         )
         set_parameter_color([self.q_up_weight, self.kv_up_weight], "memory_attn")

@@ -1805,6 +1861,7 @@ def forward(self, q_init, kv_init, position_ids):
             self.eps,
             self.kv_lora_rank,
             self.softmax_scale,
+            recompute_fa3=self.recompute_fa3,
         )


@@ -1962,7 +2019,7 @@ def forward(self, x):
 class DeepseekV2Attention(nn.Layer):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

-    def __init__(self, config: DeepseekV2Config, layerwise_recompute: bool = False):
+    def __init__(self, config: DeepseekV2Config, layerwise_recompute: bool = False, recompute_fa3: bool = False):
         super().__init__()
         self.config = config
         self.attention_dropout = config.attention_dropout
@@ -1987,6 +2044,8 @@ def __init__(self, config: DeepseekV2Config, layerwise_recompute: bool = False):
         self.seq_length = config.seq_length
         self.sequence_parallel = config.sequence_parallel

+        self.recompute_fa3 = recompute_fa3
+
         self.input_layernorm = DeepseekV2RMSNorm(config)

         # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True
@@ -2038,7 +2097,7 @@ def linear_dtype_gaurd():
         if DSV3_USE_ATTEN_RECOMPUTE:
             self.fused_rms_norm_linear = FusedRMSLinear(self.hidden_size, config.q_lora_rank, config.kv_lora_rank + config.qk_rope_head_dim, 1e-6)
             kv_up_dim = self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim)
-            self.memory_recompute_att = MemroyRecomputeAttn(config.q_lora_rank, config.kv_lora_rank, config.q_lora_rank, self.num_heads * self.q_head_dim, config.kv_lora_rank, kv_up_dim, self.rotary_emb, self.num_heads, self.q_head_dim, self.qk_nope_head_dim, self.v_head_dim, self.qk_rope_head_dim, 1e-6, self.kv_lora_rank, self.softmax_scale)
+            self.memory_recompute_att = MemroyRecomputeAttn(config.q_lora_rank, config.kv_lora_rank, config.q_lora_rank, self.num_heads * self.q_head_dim, config.kv_lora_rank, kv_up_dim, self.rotary_emb, self.num_heads, self.q_head_dim, self.qk_nope_head_dim, self.v_head_dim, self.qk_rope_head_dim, 1e-6, self.kv_lora_rank, self.softmax_scale, recompute_fa3=self.recompute_fa3)
             self.o_proj = FP8KeepXLinear(self.num_heads * self.v_head_dim, self.hidden_size, bias_attr=config.attention_bias)
         else:

@@ -2263,7 +2322,9 @@ def forward(


 class DeepseekV2DecoderLayer(nn.Layer):
-    def __init__(self, config: DeepseekV2Config, layer_idx: int, layerwise_recompute: bool = False):
+    def __init__(
+        self, config: DeepseekV2Config, layer_idx: int, layerwise_recompute: bool = False, recompute_fa3: bool = False
+    ):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
@@ -2274,7 +2335,9 @@ def __init__(self, config: DeepseekV2Config, layer_idx: int, layerwise_recompute

         self.hidden_size = config.hidden_size

-        self.self_attn = DeepseekV2Attention(config=config, layerwise_recompute=layerwise_recompute)
+        self.self_attn = DeepseekV2Attention(
+            config=config, layerwise_recompute=layerwise_recompute, recompute_fa3=recompute_fa3
+        )

         DeepseekV2MLPClass = FP8Mlp if DSV3_USE_FP8_GEMM else DeepseekV2MLP

paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 26 additions & 0 deletions
@@ -2023,6 +2023,27 @@ def compute_recompute_fwd_gate_up_list(pp_nums, all_dl_nums, dense_dl_nums, reco
                     ret.append(recompute_fwd_gate_up_list[i] + k)
             return ret

+        def compute_recompute_fa3_list(pp_nums, all_dl_nums, recompute_fa3):
+            all_layers_nums = all_dl_nums + 4  # embedding, rms, lm_head, mtp
+            segment_size = all_layers_nums // pp_nums
+            recompute_fa3_list = [0]
+            for idx in range(segment_size - 1, all_dl_nums, segment_size):
+                recompute_fa3_list.append(idx)
+
+            # If `recompute_fa3` is a Boolean value and is True, means all O1 will be recomputed.
+            # Otherwise `recompute_fa3` should be an integer representing how many O1 are recomputed.
+            assert isinstance(recompute_fa3, (int, bool))
+            if type(recompute_fa3) is bool:
+                enable_k_o1_rc = segment_size if recompute_fa3 is True else 0
+            else:
+                enable_k_o1_rc = recompute_fa3
+
+            ret = []
+            for i in range(len(recompute_fa3_list)):
+                for k in range(min(segment_size, enable_k_o1_rc)):
+                    ret.append(recompute_fa3_list[i] + k)
+            return ret
+
         pp_nums = (
             self.config["pipeline_parallel_degree"] * 2
             if self.config.use_dualpipev
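
For readers checking which layer indices end up with `recompute_fa3=True`: the helper above selects the first `recompute_fa3` layers of each pipeline segment (or the whole segment when the flag is the boolean `True`). A condensed, standalone restatement of the same selection logic with illustrative numbers (4 pipeline segments, 60 decoder layers, recompute 2 layers per segment):

    # Condensed restatement of compute_recompute_fa3_list, for illustration only.
    def recompute_fa3_layers(pp_nums, all_dl_nums, recompute_fa3):
        all_layers_nums = all_dl_nums + 4          # embedding, rms, lm_head, mtp
        segment_size = all_layers_nums // pp_nums  # layers handled per pipeline segment
        starts = [0] + list(range(segment_size - 1, all_dl_nums, segment_size))
        k = segment_size if recompute_fa3 is True else int(recompute_fa3)
        return [s + i for s in starts for i in range(min(segment_size, k))]

    print(recompute_fa3_layers(4, 60, 2))
    # [0, 1, 15, 16, 31, 32, 47, 48]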
@@ -2034,7 +2055,11 @@ def compute_recompute_fwd_gate_up_list(pp_nums, all_dl_nums, dense_dl_nums, reco
             self.config.first_k_dense_replace,
             self.config.recompute_fwd_gate_up,
         )
+        recompute_fa3_list = compute_recompute_fa3_list(
+            pp_nums, self.config.num_hidden_layers, self.config.recompute_fa3
+        )

+        logger.info(f"recompute_fa3_list: {recompute_fa3_list}")
         logger.info(f"recompute_fwd_gate_up_list: {recompute_fwd_gate_up_list}")
         config.recompute_fwd_gate_up_list = recompute_fwd_gate_up_list

@@ -2045,6 +2070,7 @@ def compute_recompute_fwd_gate_up_list(pp_nums, all_dl_nums, dense_dl_nums, reco
                 config=config,
                 layer_idx=i,
                 layerwise_recompute=i not in self.no_recompute_layers,
+                recompute_fa3=i in recompute_fa3_list,
             ),
             f"{self._base_model.base_model_prefix}.layers.{i}",
         )
