
Commit 542c1b5

Add config control for quant cache (#10926)
* add config
* refine
1 parent b93fda9 commit 542c1b5

3 files changed: +121 -80 lines changed

paddlenlp/transformers/deepseek_v2/configuration.py

Lines changed: 4 additions & 0 deletions
@@ -186,6 +186,8 @@ def __init__(
         is_split_group_gemm=False,
         fakse_gate_restrict_balance=False,
         adaptive_remained_O1_recompute_ratio=0,
+        offline_quant_expert_weight=True,
+        clear_origin_weight_when_offline_quant=True,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -241,6 +243,8 @@ def __init__(
         self.is_split_group_gemm = is_split_group_gemm
         self.fakse_gate_restrict_balance = fakse_gate_restrict_balance
         self.adaptive_remained_O1_recompute_ratio = adaptive_remained_O1_recompute_ratio
+        self.offline_quant_expert_weight = offline_quant_expert_weight
+        self.clear_origin_weight_when_offline_quant = clear_origin_weight_when_offline_quant

         super().__init__(
             pad_token_id=pad_token_id,
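A minimal usage sketch of the two new switches (both default to `True`, as shown above; the import path follows the file touched in this commit, other fields are left at their defaults):

```python
from paddlenlp.transformers.deepseek_v2.configuration import DeepseekV2Config

# Assumed intent, inferred from the flag names and the modeling.py change below:
#   offline_quant_expert_weight            -> cache FP8 copies of expert weights ahead of time
#   clear_origin_weight_when_offline_quant -> allow the original copies to be reclaimed once cached
config = DeepseekV2Config(
    offline_quant_expert_weight=True,
    clear_origin_weight_when_offline_quant=True,
)
```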

paddlenlp/transformers/deepseek_v2/modeling.py

Lines changed: 116 additions & 80 deletions
@@ -82,7 +82,13 @@
 
 from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import WeightGradStore
 
-from ..fp8_utils import FP8KeepXLinear, FP8Linear, FP8Mlp, FP8LinearFunctionBase, cache_fp8_weight
+from ..fp8_utils import (
+    FP8KeepXLinear,
+    FP8Linear,
+    FP8LinearFunctionBase,
+    FP8Mlp,
+    cache_fp8_weight,
+)
 from .fp8_linear import Linear
 
 DSV3_USE_FP8_GEMM = os.getenv("DSV3_USE_FP8_GEMM", "False").lower() == "true"
@@ -961,9 +967,10 @@ def __init__(self, config: DeepseekV2Config, norm_weight=None, norm_eps=None):
             using_post_norm_recompute=self.using_post_norm_recompute,
         )
 
-        # moe_grad_group = fleet.get_hybrid_communicate_group().expert_grad_comm_group
-        # for p in self.experts.parameters():
-        #     setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group})
+        if config.offline_quant_expert_weight and config.clear_origin_weight_when_offline_quant:
+            moe_grad_group = fleet.get_hybrid_communicate_group().expert_grad_comm_group
+            for p in self.experts.parameters():
+                setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group})
 
         self.alpha = config.aux_loss_alpha
         if config.n_shared_experts is not None:
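Tagging each expert parameter with a `color` dict ties it to the expert gradient communication group. A rough sketch of what a consumer of this tag could look like (the `sync_moe_expert_grads` helper below is hypothetical, not part of this PR):

```python
import paddle.distributed as dist


def sync_moe_expert_grads(model):
    # Hypothetical consumer: all-reduce gradients of parameters tagged as "moe_expert"
    # within the communication group attached by the constructor above.
    for p in model.parameters():
        tag = getattr(p, "color", None)
        if tag is not None and tag.get("color") == "moe_expert" and p.grad is not None:
            dist.all_reduce(p.grad, group=tag["group"])
```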
@@ -995,7 +1002,7 @@ def quantize_weights(weight_list, weight_obj=None):
             """Helper function to quantize a list of weights."""
             if weight_obj is None:
                 weight_obj = weight_list[0]
-            if hasattr( weight_obj, "fp8_weight_stacked"):
+            if hasattr(weight_obj, "fp8_weight_stacked"):
                 return
 
             # Quantize without transpose
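The `hasattr` guard above makes weight quantization idempotent: once an FP8 copy is attached to a parameter, later calls return early, so the cache is built once and reused. A simplified sketch of the idea (helper and attribute names other than `fp8_weight_stacked` are illustrative):

```python
def cache_quantized_copy(weight, quantize_fn):
    """Attach a quantized copy to `weight` once; later calls are no-ops."""
    if hasattr(weight, "fp8_weight_stacked"):  # mirrors the guard in the diff
        return
    fp8_tensor, scale = quantize_fn(weight)  # quantize_fn is a placeholder
    weight.fp8_weight_stacked = fp8_tensor
    weight.fp8_weight_scale = scale  # illustrative attribute name
```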
@@ -1027,7 +1034,7 @@ def quantize_weights(weight_list, weight_obj=None):
             if expert is not None:
                 quantize_weights([expert.w1])
                 quantize_weights([expert.w1])
-
+
         if self.config.n_shared_experts is not None:
             self.shared_experts.fp8_quant_weight()
 
@@ -1194,18 +1201,22 @@ def forward(
 
         bsz = q_init.shape[0]
         q_ln_t, q_ln_invar = fused_ln.fused_rms_norm(q_init, q_ln_weight, eps)
-        #q = paddle.matmul(q_ln_t, q_up_weight)
+        # q = paddle.matmul(q_ln_t, q_up_weight)
         q_orig_shape = q_ln_t.shape
-        q = FP8LinearFunctionBase.compute_fp8_linear(q_ln_t.reshape([-1, q_orig_shape[-1]]), q_up_weight, weight_transpose=True, return_transpose_only=True)
-        q = q.reshape( q_orig_shape[:-1] + [q_up_weight.shape[-1]])
+        q = FP8LinearFunctionBase.compute_fp8_linear(
+            q_ln_t.reshape([-1, q_orig_shape[-1]]), q_up_weight, weight_transpose=True, return_transpose_only=True
+        )
+        q = q.reshape(q_orig_shape[:-1] + [q_up_weight.shape[-1]])
 
         compressed_kv, k_pe = paddle.split(kv_init, [kv_lora_rank, qk_rope_head_dim], axis=-1)
 
         kv_ln_t, kv_ln_invar = fused_ln.fused_rms_norm(compressed_kv, kv_ln_weight, eps)
-        #kv = paddle.matmul(kv_ln_t, kv_up_weight)
+        # kv = paddle.matmul(kv_ln_t, kv_up_weight)
         kv_orig_shape = kv_ln_t.shape
-        kv = FP8LinearFunctionBase.compute_fp8_linear(kv_ln_t.reshape([-1, kv_orig_shape[-1]]), kv_up_weight, weight_transpose=True, return_transpose_only=True)
-        kv = kv.reshape( kv_orig_shape[:-1] + [kv_up_weight.shape[-1]])
+        kv = FP8LinearFunctionBase.compute_fp8_linear(
+            kv_ln_t.reshape([-1, kv_orig_shape[-1]]), kv_up_weight, weight_transpose=True, return_transpose_only=True
+        )
+        kv = kv.reshape(kv_orig_shape[:-1] + [kv_up_weight.shape[-1]])
 
         query_states, key_states, value_states = qkv_pre_process(
             q,
@@ -1366,25 +1377,34 @@ def backward(ctx, dout):
 
         q_ln_t, q_ln_invar = fused_ln.fused_rms_norm(q_init, q_ln_weight, eps)
 
-
         q_ln_fp8, q_ln_scale, q_ln_trans_fp8, q_ln_trans_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-            q_ln_t.reshape([-1, q_ln_t.shape[-1]]), output_scale_transpose=True,
-            quant_method="1x128", input_transpose=True )
-
+            q_ln_t.reshape([-1, q_ln_t.shape[-1]]),
+            output_scale_transpose=True,
+            quant_method="1x128",
+            input_transpose=True,
+        )
+
         q_orig_shape = q_ln_t.shape
-        q = FP8LinearFunctionBase.compute_fp8_linear((q_ln_fp8, q_ln_scale), q_up_weight, weight_transpose=True, return_transpose_only=True)
-        q = q.reshape( q_orig_shape[:-1] + [q_up_weight.shape[-1]])
+        q = FP8LinearFunctionBase.compute_fp8_linear(
+            (q_ln_fp8, q_ln_scale), q_up_weight, weight_transpose=True, return_transpose_only=True
+        )
+        q = q.reshape(q_orig_shape[:-1] + [q_up_weight.shape[-1]])
 
         compressed_kv, k_pe = paddle.split(kv_init, [kv_lora_rank, qk_rope_head_dim], axis=-1)
 
         kv_ln_t, kv_ln_invar = fused_ln.fused_rms_norm(compressed_kv, kv_ln_weight, eps)
-
+
         kv_ln_fp8, kv_ln_scale, kv_ln_trans_fp8, kv_ln_trans_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-            kv_ln_t.reshape([-1, kv_ln_t.shape[-1]]), output_scale_transpose=True,
-            quant_method="1x128", input_transpose=True )
+            kv_ln_t.reshape([-1, kv_ln_t.shape[-1]]),
+            output_scale_transpose=True,
+            quant_method="1x128",
+            input_transpose=True,
+        )
         kv_orig_shape = kv_ln_t.shape
-        kv = FP8LinearFunctionBase.compute_fp8_linear((kv_ln_fp8, kv_ln_scale), kv_up_weight, weight_transpose=True, return_transpose_only=True)
-        kv = kv.reshape( kv_orig_shape[:-1] + [kv_up_weight.shape[-1]])
+        kv = FP8LinearFunctionBase.compute_fp8_linear(
+            (kv_ln_fp8, kv_ln_scale), kv_up_weight, weight_transpose=True, return_transpose_only=True
+        )
+        kv = kv.reshape(kv_orig_shape[:-1] + [kv_up_weight.shape[-1]])
 
         paddle.base.core._set_has_grad(True)
         q.stop_gradient = False
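Both the forward and the recompute path above replace a plain `paddle.matmul` (kept as comments in the forward hunk) with flatten, block-wise FP8 quantization, an FP8 GEMM, and a reshape back. Shape-wise the computation is equivalent to this bf16 reference (a sketch of the tensor bookkeeping only, not of the FP8 kernels):

```python
import paddle


def up_project_reference(x, up_weight):
    """bf16 shape-equivalent of the FP8 up-projections above."""
    orig_shape = x.shape
    x2d = x.reshape([-1, orig_shape[-1]])  # flatten [bsz, seq, hidden] -> [bsz * seq, hidden]
    out = paddle.matmul(x2d, up_weight)    # the FP8 path uses compute_fp8_linear here
    return out.reshape(orig_shape[:-1] + [up_weight.shape[-1]])
```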
@@ -1465,11 +1485,16 @@ def backward(ctx, dout):
         # call up proj
         if hasattr(kv_up_weight, "main_grad"):
             d_kv_fp8, d_kv_scale, d_kv_t_fp8, d_kv_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-                d_kv.reshape([-1, d_kv.shape[-1]]), output_scale_transpose=True,
-                quant_method="1x128", input_transpose=True )
+                d_kv.reshape([-1, d_kv.shape[-1]]),
+                output_scale_transpose=True,
+                quant_method="1x128",
+                input_transpose=True,
+            )
 
-            d_kv_ln_t = FP8LinearFunctionBase.compute_fp8_linear((d_kv_fp8, d_kv_scale), kv_up_weight, weight_transpose=False)
-            d_kv_ln_t = d_kv_ln_t.reshape( d_kv.shape[:-1] + [kv_up_weight.shape[0]])
+            d_kv_ln_t = FP8LinearFunctionBase.compute_fp8_linear(
+                (d_kv_fp8, d_kv_scale), kv_up_weight, weight_transpose=False
+            )
+            d_kv_ln_t = d_kv_ln_t.reshape(d_kv.shape[:-1] + [kv_up_weight.shape[0]])
 
             def kv_up_weight_grad(kv_ln_trans_fp8, kv_ln_trans_scale, d_kv_t_fp8, d_kv_t_scale, kv_up_weight):
                 FP8LinearFunctionBase.kitchen_gemm(
@@ -1480,11 +1505,16 @@ def kv_up_weight_grad(kv_ln_trans_fp8, kv_ln_trans_scale, d_kv_t_fp8, d_kv_t_sca
                     True,
                     True,
                     kv_up_weight.main_grad,
-                    paddle.float32 )
-
+                    paddle.float32,
+                )
+
             if WeightGradStore.enabled:
-
-                WeightGradStore.put(partial(kv_up_weight_grad, kv_ln_trans_fp8, kv_ln_trans_scale, d_kv_t_fp8, d_kv_t_scale, kv_up_weight))
+
+                WeightGradStore.put(
+                    partial(
+                        kv_up_weight_grad, kv_ln_trans_fp8, kv_ln_trans_scale, d_kv_t_fp8, d_kv_t_scale, kv_up_weight
+                    )
+                )
             else:
                 kv_up_weight_grad(kv_ln_trans_fp8, kv_ln_trans_scale, d_kv_t_fp8, d_kv_t_scale, kv_up_weight)
 
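When `WeightGradStore.enabled` is set (zero-bubble pipeline scheduling), the weight-gradient GEMM is not run inline; a `functools.partial` closure that accumulates into `main_grad` is queued and flushed later. The deferral pattern reduced to its core (the `GradStore` class below is a stand-in for illustration, not Paddle's `WeightGradStore`):

```python
from functools import partial


class GradStore:
    """Stand-in: queue weight-gradient work now, run it at a later pipeline stage."""

    enabled = True
    _queue = []

    @classmethod
    def put(cls, fn):
        cls._queue.append(fn)

    @classmethod
    def flush(cls):
        while cls._queue:
            cls._queue.pop(0)()  # run the deferred weight-grad updates


def weight_grad(x_t, d_out_t, weight):
    # Accumulate into main_grad, as kitchen_gemm does in the hunk above.
    weight.main_grad += x_t @ d_out_t


def backward_step(x_t, d_out_t, weight):
    if GradStore.enabled:
        GradStore.put(partial(weight_grad, x_t, d_out_t, weight))  # defer
    else:
        weight_grad(x_t, d_out_t, weight)  # run inline
```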
@@ -1493,7 +1523,6 @@ def kv_up_weight_grad(kv_ln_trans_fp8, kv_ln_trans_scale, d_kv_t_fp8, d_kv_t_sca
         else:
             d_kv_ln_t, d_kv_up_weight = _C_ops.matmul_grad(kv_ln_t, kv_up_weight, d_kv, False, False)
 
-
         d_compressed_kv, d_kv_ln_weight = fused_ln.fused_rms_norm_grad_func(
             compressed_kv, kv_ln_weight, kv_ln_invar, d_kv_ln_t, eps
         )
@@ -1503,15 +1532,19 @@ def kv_up_weight_grad(kv_ln_trans_fp8, kv_ln_trans_scale, d_kv_t_fp8, d_kv_t_sca
         if hasattr(q_up_weight, "main_grad"):
 
             d_q_fp8, d_q_scale, d_q_t_fp8, d_q_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-                d_q.reshape([-1, d_q.shape[-1]]), output_scale_transpose=True,
-                quant_method="1x128", input_transpose=True )
-            #d_q_ln_t = paddle.matmul(d_q, q_up_weight, transpose_y=True)
+                d_q.reshape([-1, d_q.shape[-1]]),
+                output_scale_transpose=True,
+                quant_method="1x128",
+                input_transpose=True,
+            )
+            # d_q_ln_t = paddle.matmul(d_q, q_up_weight, transpose_y=True)
 
-            d_q_ln_t = FP8LinearFunctionBase.compute_fp8_linear((d_q_fp8, d_q_scale), q_up_weight, weight_transpose=False)
-            d_q_ln_t = d_q_ln_t.reshape( d_q.shape[:-1] + [q_up_weight.shape[0]])
+            d_q_ln_t = FP8LinearFunctionBase.compute_fp8_linear(
+                (d_q_fp8, d_q_scale), q_up_weight, weight_transpose=False
+            )
+            d_q_ln_t = d_q_ln_t.reshape(d_q.shape[:-1] + [q_up_weight.shape[0]])
 
-
-            def q_up_weight_grad(q_ln_trans_fp8, q_ln_trans_scale, d_q_t_fp8, d_q_t_scale, q_up_weight):
+            def q_up_weight_grad(q_ln_trans_fp8, q_ln_trans_scale, d_q_t_fp8, d_q_t_scale, q_up_weight):
                 FP8LinearFunctionBase.kitchen_gemm(
                     q_ln_trans_fp8,
                     q_ln_trans_scale,
@@ -1520,11 +1553,13 @@ def q_up_weight_grad(q_ln_trans_fp8, q_ln_trans_scale, d_q_t_fp8, d_q_t_scale, q
                     True,
                     True,
                     q_up_weight.main_grad,
-                    paddle.float32 )
-
+                    paddle.float32,
+                )
 
-            if WeightGradStore.enabled:
-                WeightGradStore.put(partial(q_up_weight_grad, q_ln_trans_fp8, q_ln_trans_scale, d_q_t_fp8, d_q_t_scale, q_up_weight))
+            if WeightGradStore.enabled:
+                WeightGradStore.put(
+                    partial(q_up_weight_grad, q_ln_trans_fp8, q_ln_trans_scale, d_q_t_fp8, d_q_t_scale, q_up_weight)
+                )
             else:
                 q_up_weight_grad(q_ln_trans_fp8, q_ln_trans_scale, d_q_t_fp8, d_q_t_scale, q_up_weight)
 
@@ -1605,17 +1640,16 @@ def __init__(
         )
 
     def fp8_quant_weight(self):
-        cache_fp8_weight( self.q_up_weight)
-        cache_fp8_weight( self.kv_up_weight)
+        cache_fp8_weight(self.q_up_weight)
+        cache_fp8_weight(self.kv_up_weight)
 
     def forward(self, q_init, kv_init, position_ids):
-
+
         seq_len = q_init.shape[1]
 
         if self.rotary_emb.max_seq_len_cached is None or seq_len > self.rotary_emb.max_seq_len_cached:
             self.rotary_emb._set_cos_sin_cache(seq_len)
 
-
         return MemroyRecomputeAttnFunc.apply(
             q_init,
             kv_init,
@@ -1641,18 +1675,19 @@ class FusedRMSLinearFunc(paddle.autograd.PyLayer):
     def forward(ctx, x, rms_norm_weight, q_down_weight, kv_down_weight, eps):
 
         hidden_states, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps)
-
+
         h_fp8, h_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-            hidden_states.reshape([-1, hidden_states.shape[-1]]), output_scale_transpose=True,
-            quant_method="1x128" )
+            hidden_states.reshape([-1, hidden_states.shape[-1]]), output_scale_transpose=True, quant_method="1x128"
+        )
 
         h_orig_shape = hidden_states.shape
-        q = FP8LinearFunctionBase.compute_fp8_linear((h_fp8, h_scale), q_down_weight, weight_transpose=True, return_transpose_only=True)
-        q = q.reshape( h_orig_shape[:-1] + [q_down_weight.shape[-1]])
-
+        q = FP8LinearFunctionBase.compute_fp8_linear(
+            (h_fp8, h_scale), q_down_weight, weight_transpose=True, return_transpose_only=True
+        )
+        q = q.reshape(h_orig_shape[:-1] + [q_down_weight.shape[-1]])
 
         kv = paddle.matmul(hidden_states, kv_down_weight)
-
+
         ctx.save_for_backward(x, rms_norm_weight, q_down_weight, kv_down_weight)
         ctx.eps = eps
         return q, kv
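`FusedRMSLinearFunc.forward` fuses an RMSNorm with the q/kv down-projections: the normalized hidden state is quantized once and fed to an FP8 GEMM for `q`, while `kv` stays as a bf16 matmul, and only `x` plus the weights are saved so the norm is recomputed in backward. A bf16 reference of what the forward computes (a sketch assuming `fused_rms_norm` is standard RMSNorm; the real path uses the fused and FP8 kernels shown above):

```python
import paddle


def fused_rms_linear_reference(x, rms_norm_weight, q_down_weight, kv_down_weight, eps=1e-6):
    # RMSNorm in float32 for stability, then cast back to the input dtype.
    variance = x.astype("float32").pow(2).mean(axis=-1, keepdim=True)
    hidden = (x.astype("float32") * paddle.rsqrt(variance + eps)).astype(x.dtype) * rms_norm_weight
    q = paddle.matmul(hidden, q_down_weight)    # FP8 compute_fp8_linear in the actual code
    kv = paddle.matmul(hidden, kv_down_weight)  # kept as a bf16 matmul in the actual code
    return q, kv
```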
@@ -1662,35 +1697,39 @@ def backward(ctx, d_q, d_kv):
         x, rms_norm_weight, q_down_weight, kv_down_weight = ctx.saved_tensor()
         eps = ctx.eps
         hidden_states, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps)
-
+
         h_t_fp8, h_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-            hidden_states.reshape([-1, hidden_states.shape[-1]]), output_scale_transpose=True,
-            quant_method="1x128", input_transpose=True, return_transpose_only=True )
+            hidden_states.reshape([-1, hidden_states.shape[-1]]),
+            output_scale_transpose=True,
+            quant_method="1x128",
+            input_transpose=True,
+            return_transpose_only=True,
+        )
 
         h_grad, d_kv_down_weight = _C_ops.matmul_grad(hidden_states, kv_down_weight, d_kv, False, False)
 
         if hasattr(q_down_weight, "main_grad"):
             d_q_fp8, d_q_scale, d_q_t_fp8, d_q_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-                d_q.reshape([-1, d_q.shape[-1]]), output_scale_transpose=True,
-                quant_method="1x128", input_transpose=True )
-            FP8LinearFunctionBase.compute_fp8_linear((d_q_fp8, d_q_scale), q_down_weight, weight_transpose=False, out=h_grad.view( [-1, h_grad.shape[-1]]))
-
+                d_q.reshape([-1, d_q.shape[-1]]),
+                output_scale_transpose=True,
+                quant_method="1x128",
+                input_transpose=True,
+            )
+            FP8LinearFunctionBase.compute_fp8_linear(
+                (d_q_fp8, d_q_scale), q_down_weight, weight_transpose=False, out=h_grad.view([-1, h_grad.shape[-1]])
+            )
 
-            def q_down_weight_grad(h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, q_down_weight):
+            def q_down_weight_grad(h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, q_down_weight):
                 FP8LinearFunctionBase.kitchen_gemm(
-                    h_t_fp8,
-                    h_t_scale,
-                    d_q_t_fp8,
-                    d_q_t_scale,
-                    True,
-                    True,
-                    q_down_weight.main_grad,
-                    paddle.float32 )
-
-            if WeightGradStore.enabled:
-                WeightGradStore.put(partial(q_down_weight_grad, h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, q_down_weight))
+                    h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, True, True, q_down_weight.main_grad, paddle.float32
+                )
+
+            if WeightGradStore.enabled:
+                WeightGradStore.put(
+                    partial(q_down_weight_grad, h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, q_down_weight)
+                )
             else:
-                q_down_weight_grad( h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, q_down_weight)
+                q_down_weight_grad(h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, q_down_weight)
 
         d_q_down_weight = None
 
@@ -1726,10 +1765,9 @@ def __init__(self, hidden_size, q_out_dim, kv_outdim, eps=1e-6) -> None:
             is_bias=False,
         )
         self.eps = eps
-
-    def fp8_quant_weight(self):
-        cache_fp8_weight( self.q_down_weight)
-
+
+    def fp8_quant_weight(self):
+        cache_fp8_weight(self.q_down_weight)
 
     def forward(self, x):
 
@@ -1898,8 +1936,6 @@ def fp8_quant_weight(self):
         self.memory_recompute_att.fp8_quant_weight()
         self.fused_rms_norm_linear.fp8_quant_weight()
 
-
-
     def _init_rope(self):
         if self.config.rope_scaling is None:
             self.rotary_emb = DeepseekV2RotaryEmbedding(

paddlenlp/transformers/moe_layer.py

Lines changed: 1 addition & 0 deletions
@@ -633,6 +633,7 @@ def forward(self, hidden_states_out, previous_event=None, async_finish=False, al
                 allocate_on_comm_stream=allocate_on_comm_stream,
             )
             output_combine.stop_gradient = False
+            self.token_dispatcher._comm_manager.handle = None
             return output_combine
 
     @paddle.no_grad()
