
 from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import WeightGradStore

-from ..fp8_utils import FP8KeepXLinear, FP8Linear, FP8Mlp, FP8LinearFunctionBase, cache_fp8_weight
+from ..fp8_utils import (
+    FP8KeepXLinear,
+    FP8Linear,
+    FP8LinearFunctionBase,
+    FP8Mlp,
+    cache_fp8_weight,
+)
 from .fp8_linear import Linear

 DSV3_USE_FP8_GEMM = os.getenv("DSV3_USE_FP8_GEMM", "False").lower() == "true"
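
Note: DSV3_USE_FP8_GEMM is read once at module import time, so the switch has to be in the environment before this file is imported. A minimal way to enable the FP8 GEMM paths (illustrative usage, not part of the diff):

    import os

    os.environ["DSV3_USE_FP8_GEMM"] = "true"  # set before the modeling module is imported
    # import the model code only after the flag is in place
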
@@ -961,9 +967,10 @@ def __init__(self, config: DeepseekV2Config, norm_weight=None, norm_eps=None):
             using_post_norm_recompute=self.using_post_norm_recompute,
         )

-        # moe_grad_group = fleet.get_hybrid_communicate_group().expert_grad_comm_group
-        # for p in self.experts.parameters():
-        #     setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group})
+        if config.offline_quant_expert_weight and config.clear_origin_weight_when_offline_quant:
+            moe_grad_group = fleet.get_hybrid_communicate_group().expert_grad_comm_group
+            for p in self.experts.parameters():
+                setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group})

         self.alpha = config.aux_loss_alpha
         if config.n_shared_experts is not None:
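
Note: the new branch above tags each expert parameter with a `color` dict naming the expert gradient communication group, so later gradient handling can pick these parameters out and reduce them over expert_grad_comm_group instead of the default group. A rough sketch of how such tags might be bucketed downstream (the helper below is hypothetical, not part of this diff):

    from collections import defaultdict

    def bucket_params_by_color(parameters):
        """Group tagged parameters so each bucket can be reduced over its own comm group."""
        buckets = defaultdict(list)
        for p in parameters:
            tag = getattr(p, "color", None)
            if tag is not None:
                # e.g. "moe_expert" -> (param, expert_grad_comm_group)
                buckets[tag["color"]].append((p, tag["group"]))
        return buckets
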
@@ -995,7 +1002,7 @@ def quantize_weights(weight_list, weight_obj=None):
             """Helper function to quantize a list of weights."""
             if weight_obj is None:
                 weight_obj = weight_list[0]
-            if hasattr( weight_obj, "fp8_weight_stacked"):
+            if hasattr(weight_obj, "fp8_weight_stacked"):
                 return

             # Quantize without transpose
@@ -1027,7 +1034,7 @@ def quantize_weights(weight_list, weight_obj=None):
             if expert is not None:
                 quantize_weights([expert.w1])
                 quantize_weights([expert.w1])
-
+
         if self.config.n_shared_experts is not None:
             self.shared_experts.fp8_quant_weight()

@@ -1194,18 +1201,22 @@ def forward(

         bsz = q_init.shape[0]
         q_ln_t, q_ln_invar = fused_ln.fused_rms_norm(q_init, q_ln_weight, eps)
-        #q = paddle.matmul(q_ln_t, q_up_weight)
+        # q = paddle.matmul(q_ln_t, q_up_weight)
         q_orig_shape = q_ln_t.shape
-        q = FP8LinearFunctionBase.compute_fp8_linear(q_ln_t.reshape([-1, q_orig_shape[-1]]), q_up_weight, weight_transpose=True, return_transpose_only=True)
-        q = q.reshape( q_orig_shape[:-1] + [q_up_weight.shape[-1]])
+        q = FP8LinearFunctionBase.compute_fp8_linear(
+            q_ln_t.reshape([-1, q_orig_shape[-1]]), q_up_weight, weight_transpose=True, return_transpose_only=True
+        )
+        q = q.reshape(q_orig_shape[:-1] + [q_up_weight.shape[-1]])

         compressed_kv, k_pe = paddle.split(kv_init, [kv_lora_rank, qk_rope_head_dim], axis=-1)

         kv_ln_t, kv_ln_invar = fused_ln.fused_rms_norm(compressed_kv, kv_ln_weight, eps)
-        #kv = paddle.matmul(kv_ln_t, kv_up_weight)
+        # kv = paddle.matmul(kv_ln_t, kv_up_weight)
         kv_orig_shape = kv_ln_t.shape
-        kv = FP8LinearFunctionBase.compute_fp8_linear(kv_ln_t.reshape([-1, kv_orig_shape[-1]]), kv_up_weight, weight_transpose=True, return_transpose_only=True)
-        kv = kv.reshape( kv_orig_shape[:-1] + [kv_up_weight.shape[-1]])
+        kv = FP8LinearFunctionBase.compute_fp8_linear(
+            kv_ln_t.reshape([-1, kv_orig_shape[-1]]), kv_up_weight, weight_transpose=True, return_transpose_only=True
+        )
+        kv = kv.reshape(kv_orig_shape[:-1] + [kv_up_weight.shape[-1]])

         query_states, key_states, value_states = qkv_pre_process(
             q,
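
Note: the FP8 helpers used above and in the backward pass rely on "1x128" blockwise quantization, i.e. one scale per 128-element block of each row. A toy numpy sketch of the idea (illustrative only, not the Paddle kernel; the float8_e4m3 max of 448 and the block size are assumptions carried over from the quant_method name):

    import numpy as np

    def quant_blockwise_1x128(x, block=128, fp8_max=448.0):
        """One scale per 1x128 block of a [m, k] activation; k must be divisible by block."""
        m, k = x.shape
        blocks = x.reshape(m, k // block, block)
        scale = np.abs(blocks).max(axis=-1, keepdims=True) / fp8_max
        scale = np.where(scale == 0.0, 1.0, scale)         # avoid divide-by-zero for all-zero blocks
        q = np.clip(blocks / scale, -fp8_max, fp8_max)     # would be cast to float8_e4m3 on device
        return q.reshape(m, k), scale.squeeze(-1)
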
@@ -1366,25 +1377,34 @@ def backward(ctx, dout):

         q_ln_t, q_ln_invar = fused_ln.fused_rms_norm(q_init, q_ln_weight, eps)

-
         q_ln_fp8, q_ln_scale, q_ln_trans_fp8, q_ln_trans_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-            q_ln_t.reshape([-1, q_ln_t.shape[-1]]), output_scale_transpose=True,
-            quant_method="1x128", input_transpose=True)
-
+            q_ln_t.reshape([-1, q_ln_t.shape[-1]]),
+            output_scale_transpose=True,
+            quant_method="1x128",
+            input_transpose=True,
+        )
+
         q_orig_shape = q_ln_t.shape
-        q = FP8LinearFunctionBase.compute_fp8_linear((q_ln_fp8, q_ln_scale), q_up_weight, weight_transpose=True, return_transpose_only=True)
-        q = q.reshape( q_orig_shape[:-1] + [q_up_weight.shape[-1]])
+        q = FP8LinearFunctionBase.compute_fp8_linear(
+            (q_ln_fp8, q_ln_scale), q_up_weight, weight_transpose=True, return_transpose_only=True
+        )
+        q = q.reshape(q_orig_shape[:-1] + [q_up_weight.shape[-1]])

         compressed_kv, k_pe = paddle.split(kv_init, [kv_lora_rank, qk_rope_head_dim], axis=-1)

         kv_ln_t, kv_ln_invar = fused_ln.fused_rms_norm(compressed_kv, kv_ln_weight, eps)
-
+
         kv_ln_fp8, kv_ln_scale, kv_ln_trans_fp8, kv_ln_trans_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-            kv_ln_t.reshape([-1, kv_ln_t.shape[-1]]), output_scale_transpose=True,
-            quant_method="1x128", input_transpose=True)
+            kv_ln_t.reshape([-1, kv_ln_t.shape[-1]]),
+            output_scale_transpose=True,
+            quant_method="1x128",
+            input_transpose=True,
+        )
         kv_orig_shape = kv_ln_t.shape
-        kv = FP8LinearFunctionBase.compute_fp8_linear((kv_ln_fp8, kv_ln_scale), kv_up_weight, weight_transpose=True, return_transpose_only=True)
-        kv = kv.reshape( kv_orig_shape[:-1] + [kv_up_weight.shape[-1]])
+        kv = FP8LinearFunctionBase.compute_fp8_linear(
+            (kv_ln_fp8, kv_ln_scale), kv_up_weight, weight_transpose=True, return_transpose_only=True
+        )
+        kv = kv.reshape(kv_orig_shape[:-1] + [kv_up_weight.shape[-1]])

         paddle.base.core._set_has_grad(True)
         q.stop_gradient = False
@@ -1465,11 +1485,16 @@ def backward(ctx, dout):
         # call up proj
         if hasattr(kv_up_weight, "main_grad"):
             d_kv_fp8, d_kv_scale, d_kv_t_fp8, d_kv_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-                d_kv.reshape([-1, d_kv.shape[-1]]), output_scale_transpose=True,
-                quant_method="1x128", input_transpose=True)
+                d_kv.reshape([-1, d_kv.shape[-1]]),
+                output_scale_transpose=True,
+                quant_method="1x128",
+                input_transpose=True,
+            )

-            d_kv_ln_t = FP8LinearFunctionBase.compute_fp8_linear((d_kv_fp8, d_kv_scale), kv_up_weight, weight_transpose=False)
-            d_kv_ln_t = d_kv_ln_t.reshape( d_kv.shape[:-1] + [kv_up_weight.shape[0]])
+            d_kv_ln_t = FP8LinearFunctionBase.compute_fp8_linear(
+                (d_kv_fp8, d_kv_scale), kv_up_weight, weight_transpose=False
+            )
+            d_kv_ln_t = d_kv_ln_t.reshape(d_kv.shape[:-1] + [kv_up_weight.shape[0]])

             def kv_up_weight_grad(kv_ln_trans_fp8, kv_ln_trans_scale, d_kv_t_fp8, d_kv_t_scale, kv_up_weight):
                 FP8LinearFunctionBase.kitchen_gemm(
@@ -1480,11 +1505,16 @@ def kv_up_weight_grad(kv_ln_trans_fp8, kv_ln_trans_scale, d_kv_t_fp8, d_kv_t_sca
                     True,
                     True,
                     kv_up_weight.main_grad,
-                    paddle.float32)
-
+                    paddle.float32,
+                )
+
             if WeightGradStore.enabled:
-
-                WeightGradStore.put(partial(kv_up_weight_grad, kv_ln_trans_fp8, kv_ln_trans_scale, d_kv_t_fp8, d_kv_t_scale, kv_up_weight))
+
+                WeightGradStore.put(
+                    partial(
+                        kv_up_weight_grad, kv_ln_trans_fp8, kv_ln_trans_scale, d_kv_t_fp8, d_kv_t_scale, kv_up_weight
+                    )
+                )
             else:
                 kv_up_weight_grad(kv_ln_trans_fp8, kv_ln_trans_scale, d_kv_t_fp8, d_kv_t_scale, kv_up_weight)

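
Note: WeightGradStore.put defers the weight-gradient GEMM as a callable so that, under zero-bubble pipeline scheduling, it can run later in a pipeline bubble instead of on the critical path; when the store is disabled the gradient is computed immediately, as in the else branch. A minimal sketch of the store pattern (illustrative; the real paddle.distributed.fleet.meta_parallel.zero_bubble_utils implementation may differ):

    from functools import partial

    class ToyWeightGradStore:
        """Queue of deferred weight-gradient callables (illustrative only)."""

        enabled = True
        _queue = []

        @classmethod
        def put(cls, fn):
            cls._queue.append(fn)  # defer instead of running now

        @classmethod
        def flush(cls):
            while cls._queue:
                cls._queue.pop(0)()  # drain the deferred weight-grad work later

    def weight_grad(x_t, dy_t):
        print("accumulate dW from", x_t, dy_t)

    if ToyWeightGradStore.enabled:
        ToyWeightGradStore.put(partial(weight_grad, "x_t", "dy_t"))
    else:
        weight_grad("x_t", "dy_t")
    ToyWeightGradStore.flush()
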
@@ -1493,7 +1523,6 @@ def kv_up_weight_grad(kv_ln_trans_fp8, kv_ln_trans_scale, d_kv_t_fp8, d_kv_t_sca

         else:
             d_kv_ln_t, d_kv_up_weight = _C_ops.matmul_grad(kv_ln_t, kv_up_weight, d_kv, False, False)

-
         d_compressed_kv, d_kv_ln_weight = fused_ln.fused_rms_norm_grad_func(
             compressed_kv, kv_ln_weight, kv_ln_invar, d_kv_ln_t, eps
         )
@@ -1503,15 +1532,19 @@ def kv_up_weight_grad(kv_ln_trans_fp8, kv_ln_trans_scale, d_kv_t_fp8, d_kv_t_sca
         if hasattr(q_up_weight, "main_grad"):

             d_q_fp8, d_q_scale, d_q_t_fp8, d_q_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-                d_q.reshape([-1, d_q.shape[-1]]), output_scale_transpose=True,
-                quant_method="1x128", input_transpose=True)
-            #d_q_ln_t = paddle.matmul(d_q, q_up_weight, transpose_y=True)
+                d_q.reshape([-1, d_q.shape[-1]]),
+                output_scale_transpose=True,
+                quant_method="1x128",
+                input_transpose=True,
+            )
+            # d_q_ln_t = paddle.matmul(d_q, q_up_weight, transpose_y=True)

-            d_q_ln_t = FP8LinearFunctionBase.compute_fp8_linear((d_q_fp8, d_q_scale), q_up_weight, weight_transpose=False)
-            d_q_ln_t = d_q_ln_t.reshape( d_q.shape[:-1] + [q_up_weight.shape[0]])
+            d_q_ln_t = FP8LinearFunctionBase.compute_fp8_linear(
+                (d_q_fp8, d_q_scale), q_up_weight, weight_transpose=False
+            )
+            d_q_ln_t = d_q_ln_t.reshape(d_q.shape[:-1] + [q_up_weight.shape[0]])

-
-            def q_up_weight_grad(q_ln_trans_fp8, q_ln_trans_scale, d_q_t_fp8, d_q_t_scale, q_up_weight):
+            def q_up_weight_grad(q_ln_trans_fp8, q_ln_trans_scale, d_q_t_fp8, d_q_t_scale, q_up_weight):
                 FP8LinearFunctionBase.kitchen_gemm(
                     q_ln_trans_fp8,
                     q_ln_trans_scale,
@@ -1520,11 +1553,13 @@ def q_up_weight_grad(q_ln_trans_fp8, q_ln_trans_scale, d_q_t_fp8, d_q_t_scale, q
                     True,
                     True,
                     q_up_weight.main_grad,
-                    paddle.float32)
-
+                    paddle.float32,
+                )

-            if WeightGradStore.enabled:
-                WeightGradStore.put(partial(q_up_weight_grad, q_ln_trans_fp8, q_ln_trans_scale, d_q_t_fp8, d_q_t_scale, q_up_weight))
+            if WeightGradStore.enabled:
+                WeightGradStore.put(
+                    partial(q_up_weight_grad, q_ln_trans_fp8, q_ln_trans_scale, d_q_t_fp8, d_q_t_scale, q_up_weight)
+                )
             else:
                 q_up_weight_grad(q_ln_trans_fp8, q_ln_trans_scale, d_q_t_fp8, d_q_t_scale, q_up_weight)

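
Note: kitchen_gemm in the two weight-grad helpers above feeds the parameter's float32 main_grad buffer. Conceptually, the weight gradient of y = x @ W is dW = x^T @ dy, accumulated into a full-precision master gradient; a plain numpy sketch of that accumulation (names illustrative, accumulation semantics assumed):

    import numpy as np

    def accumulate_weight_grad(x_t, dy, main_grad):
        """dW = x^T @ dy, accumulated into a float32 master-gradient buffer."""
        main_grad += (x_t @ dy).astype(np.float32)  # accumulate, do not overwrite
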
@@ -1605,17 +1640,16 @@ def __init__(
         )

     def fp8_quant_weight(self):
-        cache_fp8_weight( self.q_up_weight)
-        cache_fp8_weight( self.kv_up_weight)
+        cache_fp8_weight(self.q_up_weight)
+        cache_fp8_weight(self.kv_up_weight)

     def forward(self, q_init, kv_init, position_ids):
-
+
         seq_len = q_init.shape[1]

         if self.rotary_emb.max_seq_len_cached is None or seq_len > self.rotary_emb.max_seq_len_cached:
             self.rotary_emb._set_cos_sin_cache(seq_len)

-
         return MemroyRecomputeAttnFunc.apply(
             q_init,
             kv_init,
@@ -1641,18 +1675,19 @@ class FusedRMSLinearFunc(paddle.autograd.PyLayer):
     def forward(ctx, x, rms_norm_weight, q_down_weight, kv_down_weight, eps):

         hidden_states, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps)
-
+
         h_fp8, h_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-            hidden_states.reshape([-1, hidden_states.shape[-1]]), output_scale_transpose=True,
-            quant_method="1x128")
+            hidden_states.reshape([-1, hidden_states.shape[-1]]), output_scale_transpose=True, quant_method="1x128"
+        )

         h_orig_shape = hidden_states.shape
-        q = FP8LinearFunctionBase.compute_fp8_linear((h_fp8, h_scale), q_down_weight, weight_transpose=True, return_transpose_only=True)
-        q = q.reshape( h_orig_shape[:-1] + [q_down_weight.shape[-1]])
-
+        q = FP8LinearFunctionBase.compute_fp8_linear(
+            (h_fp8, h_scale), q_down_weight, weight_transpose=True, return_transpose_only=True
+        )
+        q = q.reshape(h_orig_shape[:-1] + [q_down_weight.shape[-1]])

         kv = paddle.matmul(hidden_states, kv_down_weight)
-
+
         ctx.save_for_backward(x, rms_norm_weight, q_down_weight, kv_down_weight)
         ctx.eps = eps
         return q, kv
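
Note: FusedRMSLinearFunc saves only x and the weights, then recomputes hidden_states with fused_rms_norm again in backward, trading a small amount of recompute for activation memory. A stripped-down numpy sketch of the save-input-and-recompute-in-backward idea, with a cheap elementwise op standing in for the RMS norm (illustrative only):

    import numpy as np

    def forward(x, W, saved):
        h = np.maximum(x, 0.0)          # stand-in for the cheap-to-recompute transform
        saved["x"], saved["W"] = x, W   # save only x and W, not h
        return h @ W

    def backward(dy, saved):
        x, W = saved["x"], saved["W"]
        h = np.maximum(x, 0.0)          # recompute h instead of having stored it
        dW = h.T @ dy                   # weight gradient
        dh = dy @ W.T
        dx = dh * (x > 0.0)             # gradient through the recomputed transform
        return dx, dW
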
@@ -1662,35 +1697,39 @@ def backward(ctx, d_q, d_kv):
         x, rms_norm_weight, q_down_weight, kv_down_weight = ctx.saved_tensor()
         eps = ctx.eps
         hidden_states, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps)
-
+
         h_t_fp8, h_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-            hidden_states.reshape([-1, hidden_states.shape[-1]]), output_scale_transpose=True,
-            quant_method="1x128", input_transpose=True, return_transpose_only=True)
+            hidden_states.reshape([-1, hidden_states.shape[-1]]),
+            output_scale_transpose=True,
+            quant_method="1x128",
+            input_transpose=True,
+            return_transpose_only=True,
+        )

         h_grad, d_kv_down_weight = _C_ops.matmul_grad(hidden_states, kv_down_weight, d_kv, False, False)

         if hasattr(q_down_weight, "main_grad"):
             d_q_fp8, d_q_scale, d_q_t_fp8, d_q_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-                d_q.reshape([-1, d_q.shape[-1]]), output_scale_transpose=True,
-                quant_method="1x128", input_transpose=True)
-            FP8LinearFunctionBase.compute_fp8_linear((d_q_fp8, d_q_scale), q_down_weight, weight_transpose=False, out=h_grad.view( [-1, h_grad.shape[-1]]))
-
+                d_q.reshape([-1, d_q.shape[-1]]),
+                output_scale_transpose=True,
+                quant_method="1x128",
+                input_transpose=True,
+            )
+            FP8LinearFunctionBase.compute_fp8_linear(
+                (d_q_fp8, d_q_scale), q_down_weight, weight_transpose=False, out=h_grad.view([-1, h_grad.shape[-1]])
+            )

-            def q_down_weight_grad(h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, q_down_weight):
+            def q_down_weight_grad(h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, q_down_weight):
                 FP8LinearFunctionBase.kitchen_gemm(
-                    h_t_fp8,
-                    h_t_scale,
-                    d_q_t_fp8,
-                    d_q_t_scale,
-                    True,
-                    True,
-                    q_down_weight.main_grad,
-                    paddle.float32)
-
-            if WeightGradStore.enabled:
-                WeightGradStore.put(partial(q_down_weight_grad, h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, q_down_weight))
+                    h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, True, True, q_down_weight.main_grad, paddle.float32
+                )
+
+            if WeightGradStore.enabled:
+                WeightGradStore.put(
+                    partial(q_down_weight_grad, h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, q_down_weight)
+                )
             else:
-                q_down_weight_grad( h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, q_down_weight)
+                q_down_weight_grad(h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, q_down_weight)

         d_q_down_weight = None

@@ -1726,10 +1765,9 @@ def __init__(self, hidden_size, q_out_dim, kv_outdim, eps=1e-6) -> None:
             is_bias=False,
         )
         self.eps = eps
-
-    def fp8_quant_weight(self):
-        cache_fp8_weight( self.q_down_weight)
-
+
+    def fp8_quant_weight(self):
+        cache_fp8_weight(self.q_down_weight)


     def forward(self, x):
@@ -1898,8 +1936,6 @@ def fp8_quant_weight(self):
         self.memory_recompute_att.fp8_quant_weight()
         self.fused_rms_norm_linear.fp8_quant_weight()

-
-
     def _init_rope(self):
         if self.config.rope_scaling is None:
             self.rotary_emb = DeepseekV2RotaryEmbedding(