@@ -1334,6 +1334,7 @@ def forward(
         eps,
         kv_lora_rank,
         softmax_scale,
+        recompute_fa3=False,
     ):

         bsz = q_init.shape[0]
@@ -1439,26 +1440,50 @@ def forward(
                 softmax_scale,
             )
         elif FA_VERSION == 3:
-            ctx.save_for_backward(
-                q_init,
-                kv_init,
-                attn_out,
-                softmax_lse,
-                q_ln_weight,
-                kv_ln_weight,
-                q_up_weight,
-                kv_up_weight,
-                rotary_emb,
-                num_heads,
-                q_head_dim,
-                qk_nope_head_dim,
-                v_head_dim,
-                qk_rope_head_dim,
-                position_ids,
-                eps,
-                kv_lora_rank,
-                softmax_scale,
-            )
+            if recompute_fa3:
+                ctx.save_for_backward(
+                    q_init,
+                    kv_init,
+                    None,
+                    None,
+                    q_ln_weight,
+                    kv_ln_weight,
+                    q_up_weight,
+                    kv_up_weight,
+                    rotary_emb,
+                    num_heads,
+                    q_head_dim,
+                    qk_nope_head_dim,
+                    v_head_dim,
+                    qk_rope_head_dim,
+                    position_ids,
+                    eps,
+                    kv_lora_rank,
+                    softmax_scale,
+                    recompute_fa3,
+                )
+            else:
+                ctx.save_for_backward(
+                    q_init,
+                    kv_init,
+                    attn_out,
+                    softmax_lse,
+                    q_ln_weight,
+                    kv_ln_weight,
+                    q_up_weight,
+                    kv_up_weight,
+                    rotary_emb,
+                    num_heads,
+                    q_head_dim,
+                    qk_nope_head_dim,
+                    v_head_dim,
+                    qk_rope_head_dim,
+                    position_ids,
+                    eps,
+                    kv_lora_rank,
+                    softmax_scale,
+                    recompute_fa3,
+                )
         else:
             assert False, f"invalid {FA_VERSION=}"

@@ -1508,10 +1533,17 @@ def backward(ctx, dout):
                 eps,
                 kv_lora_rank,
                 softmax_scale,
+                recompute_fa3,
             ) = ctx.saved_tensor()
         else:
             assert False, f"invalid {FA_VERSION=}"

+        if FA_VERSION == 2:
+            assert not recompute_fa3
+            assert attn_out is not None and softmax_lse is not None
+        if FA_VERSION == 3 and not recompute_fa3:
+            assert attn_out is not None and softmax_lse is not None
+
         q_ln_t, q_ln_invar = fused_ln.fused_rms_norm(q_init, q_ln_weight, eps)

         q_ln_fp8, q_ln_scale, q_ln_trans_fp8, q_ln_trans_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
@@ -1591,6 +1623,27 @@ def backward(ctx, dout):
             v_grad = v_grad[..., :v_head_dim]
             q_grad = q_grad * softmax_scale
         elif FA_VERSION == 3:
+            # recompute fa3
+            if recompute_fa3:
+                logger.info("Enable fa3 recomputation")
+                attn_out, softmax_lse = _C_ops.flash_attn_v3(
+                    query_states,
+                    key_states,
+                    value_states,
+                    None,  # q_v_
+                    None,  # q_descale_
+                    None,  # k_descale_
+                    None,  # v_descale_
+                    softmax_scale,
+                    True,
+                    -1,  # window_size_left
+                    -1,  # window_size_right
+                    0.0,  # softcap
+                    1,  # num_splits
+                    False,  # manual_set_pack_gqa
+                    False,  # pack_gqa_
+                    0,  # sm_margin
+                )
             with paddle.no_grad():
                 q_grad, k_grad, v_grad = _C_ops.flash_attn_v3_grad(
                     query_states,
@@ -1728,6 +1781,7 @@ def __init__(
         eps,
         kv_lora_rank,
         softmax_scale,
+        recompute_fa3=False,
     ) -> None:
         super().__init__()
         self._dtype = self._helper.get_default_dtype()
@@ -1764,6 +1818,7 @@ def __init__(
             self.eps,
             self.kv_lora_rank,
             self.softmax_scale,
+            self.recompute_fa3,
         ) = (
             rotary_emb,
             num_heads,
@@ -1774,6 +1829,7 @@ def __init__(
            eps,
            kv_lora_rank,
            softmax_scale,
+           recompute_fa3,
        )
        set_parameter_color([self.q_up_weight, self.kv_up_weight], "memory_attn")

@@ -1805,6 +1861,7 @@ def forward(self, q_init, kv_init, position_ids):
            self.eps,
            self.kv_lora_rank,
            self.softmax_scale,
+           recompute_fa3=self.recompute_fa3,
        )

@@ -1962,7 +2019,7 @@ def forward(self, x):
 class DeepseekV2Attention(nn.Layer):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

-    def __init__(self, config: DeepseekV2Config, layerwise_recompute: bool = False):
+    def __init__(self, config: DeepseekV2Config, layerwise_recompute: bool = False, recompute_fa3: bool = False):
         super().__init__()
         self.config = config
         self.attention_dropout = config.attention_dropout
@@ -1987,6 +2044,8 @@ def __init__(self, config: DeepseekV2Config, layerwise_recompute: bool = False):
         self.seq_length = config.seq_length
         self.sequence_parallel = config.sequence_parallel

+        self.recompute_fa3 = recompute_fa3
+
         self.input_layernorm = DeepseekV2RMSNorm(config)

         # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True
@@ -2038,7 +2097,7 @@ def linear_dtype_gaurd():
         if DSV3_USE_ATTEN_RECOMPUTE:
             self.fused_rms_norm_linear = FusedRMSLinear(self.hidden_size, config.q_lora_rank, config.kv_lora_rank + config.qk_rope_head_dim, 1e-6)
             kv_up_dim = self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim)
-            self.memory_recompute_att = MemroyRecomputeAttn(config.q_lora_rank, config.kv_lora_rank, config.q_lora_rank, self.num_heads * self.q_head_dim, config.kv_lora_rank, kv_up_dim, self.rotary_emb, self.num_heads, self.q_head_dim, self.qk_nope_head_dim, self.v_head_dim, self.qk_rope_head_dim, 1e-6, self.kv_lora_rank, self.softmax_scale)
+            self.memory_recompute_att = MemroyRecomputeAttn(config.q_lora_rank, config.kv_lora_rank, config.q_lora_rank, self.num_heads * self.q_head_dim, config.kv_lora_rank, kv_up_dim, self.rotary_emb, self.num_heads, self.q_head_dim, self.qk_nope_head_dim, self.v_head_dim, self.qk_rope_head_dim, 1e-6, self.kv_lora_rank, self.softmax_scale, recompute_fa3=self.recompute_fa3)
             self.o_proj = FP8KeepXLinear(self.num_heads * self.v_head_dim, self.hidden_size, bias_attr=config.attention_bias)
         else:
@@ -2263,7 +2322,9 @@ def forward(


 class DeepseekV2DecoderLayer(nn.Layer):
-    def __init__(self, config: DeepseekV2Config, layer_idx: int, layerwise_recompute: bool = False):
+    def __init__(
+        self, config: DeepseekV2Config, layer_idx: int, layerwise_recompute: bool = False, recompute_fa3: bool = False
+    ):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
@@ -2274,7 +2335,9 @@ def __init__(self, config: DeepseekV2Config, layer_idx: int, layerwise_recompute

         self.hidden_size = config.hidden_size

-        self.self_attn = DeepseekV2Attention(config=config, layerwise_recompute=layerwise_recompute)
+        self.self_attn = DeepseekV2Attention(
+            config=config, layerwise_recompute=layerwise_recompute, recompute_fa3=recompute_fa3
+        )

         DeepseekV2MLPClass = FP8Mlp if DSV3_USE_FP8_GEMM else DeepseekV2MLP
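
For reference, below is a minimal sketch (not the repository's code) of the pattern this diff threads through as recompute_fa3: when the flag is set, forward() saves only the small inputs instead of the large attention output and softmax statistics, and backward() re-runs the kernel to regenerate them before computing gradients, trading extra FLOPs for activation memory. Softmax stands in for the fused flash_attn_v3 kernel because its backward also needs the forward output; the class, variable, and flag names below are illustrative assumptions, not APIs from this PR.

import paddle
from paddle.autograd import PyLayer


class SoftmaxMaybeRecompute(PyLayer):
    @staticmethod
    def forward(ctx, x, recompute=False):
        out = paddle.nn.functional.softmax(x, axis=-1)
        ctx.recompute = recompute
        if recompute:
            ctx.save_for_backward(x)       # keep only the (small) input
        else:
            ctx.save_for_backward(x, out)  # also keep the (large) activation
        return out

    @staticmethod
    def backward(ctx, dout):
        if ctx.recompute:
            (x,) = ctx.saved_tensor()
            out = paddle.nn.functional.softmax(x, axis=-1)  # re-run the "kernel"
        else:
            x, out = ctx.saved_tensor()
        # Analytic softmax backward: dx = out * (dout - sum(dout * out, axis=-1))
        return out * (dout - (dout * out).sum(axis=-1, keepdim=True))


x = paddle.randn([2, 8])
x.stop_gradient = False
y = SoftmaxMaybeRecompute.apply(x, True)  # True -> regenerate the output in backward
y.sum().backward()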