Commit 79c421f

Update quant cache (#11007)

* fix
* update quant cache

Co-authored-by: zhangbo9674 <[email protected]>

1 parent: 17cc8a4

3 files changed, +61 -23 lines

paddlenlp/trainer/trainer_callback.py (1 addition, 1 deletion)

@@ -645,7 +645,7 @@ def on_step_begin(self, args, state, control, **kwargs):
         global skip_count

         if (not g_shard_bypass_dygraph_optimizer or skip_count == 0) and hasattr(model, "fp8_quant_weight"):
-            model.fp8_quant_weight(True, quant_transpose=False)
+            model.fp8_quant_weight(True, quant_transpose=True)
             optimizer.clear_param_storage("moe_expert")
             optimizer.clear_param_storage("rms_linear")
             optimizer.clear_param_storage("memory_attn")

paddlenlp/transformers/deepseek_v2/modeling.py (29 additions, 15 deletions)

@@ -1073,35 +1073,49 @@ def __init__(self, config: DeepseekV2Config, norm_weight=None, norm_eps=None):
             )
             set_parameter_color([self.shared_experts.w1, self.shared_experts.w2], "shared_expert")

-    def fp8_quant_weight(self, batch_mode=False, quant_transpose=True):
+    def fp8_quant_weight(self, batch_mode=False, quant_transpose=None):
         """Quantize weights in FP8 format.

         Args:
             batch_mode: If True, quantize all weights in batch mode using the first expert's weights.
                         If False, quantize each expert's weights individually.
         """

-        def quantize_weights(weight_list, weight_obj=None, quant_transpose=True):
+        def quantize_weights(weight_list, weight_obj=None, quant_transpose=None):
             """Helper function to quantize a list of weights."""
             if weight_obj is None:
                 weight_obj = weight_list[0]
-            if hasattr(weight_obj, "fp8_weight_stacked"):
+            if hasattr(weight_obj, "fp8_weight_stacked") or hasattr(weight_obj, "fp8_weight_stacked_transpose"):
                 return

-            # Quantize without transpose
-            fp8_weight, fp8_scale = paddle.incubate.nn.functional.fused_stack_transpose_quant(
-                weight_list, transpose=False
-            )
-            setattr(weight_obj, "fp8_weight_stacked", fp8_weight)
-            setattr(weight_obj, "fp8_scale_stacked", fp8_scale)
+            if quant_transpose is None:
+                fp8_weight, fp8_scale = paddle.incubate.nn.functional.fused_stack_transpose_quant(
+                    weight_list, transpose=False
+                )
+                setattr(weight_obj, "fp8_weight_stacked", fp8_weight)
+                setattr(weight_obj, "fp8_scale_stacked", fp8_scale)

-            if quant_transpose:
-                # Quantize with transpose
                 fp8_weight_t, fp8_scale_t = paddle.incubate.nn.functional.fused_stack_transpose_quant(
                     weight_list, transpose=True
                 )
                 setattr(weight_obj, "fp8_weight_stacked_transpose", fp8_weight_t)
                 setattr(weight_obj, "fp8_scale_stacked_transpose", fp8_scale_t)
+            elif quant_transpose is False:
+                # Only quantize without transpose
+                fp8_weight, fp8_scale = paddle.incubate.nn.functional.fused_stack_transpose_quant(
+                    weight_list, transpose=False
+                )
+                setattr(weight_obj, "fp8_weight_stacked", fp8_weight)
+                setattr(weight_obj, "fp8_scale_stacked", fp8_scale)
+            elif quant_transpose is True:
+                # Only quantize with transpose
+                fp8_weight_t, fp8_scale_t = paddle.incubate.nn.functional.fused_stack_transpose_quant(
+                    weight_list, transpose=True
+                )
+                setattr(weight_obj, "fp8_weight_stacked_transpose", fp8_weight_t)
+                setattr(weight_obj, "fp8_scale_stacked_transpose", fp8_scale_t)
+            else:
+                raise ValueError("Invalid value for `quant_transpose`.")

         if batch_mode:
             # Batch mode: process all experts' weights together
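As a reading aid, the following self-contained sketch (not part of the patch) summarizes which cache attributes the reworked `quantize_weights` helper attaches for each `quant_transpose` value; the attribute names come from the diff above, while the helper itself is hypothetical:

    def stacked_cache_attrs(quant_transpose=None):
        # Hypothetical helper mirroring the branches above, for illustration only.
        if quant_transpose is None:   # cache both copies (the pre-commit default behaviour)
            return ("fp8_weight_stacked", "fp8_scale_stacked",
                    "fp8_weight_stacked_transpose", "fp8_scale_stacked_transpose")
        if quant_transpose is False:  # plain copy only
            return ("fp8_weight_stacked", "fp8_scale_stacked")
        if quant_transpose is True:   # transposed copy only
            return ("fp8_weight_stacked_transpose", "fp8_scale_stacked_transpose")
        raise ValueError("Invalid value for `quant_transpose`.")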
@@ -1830,7 +1844,7 @@ def __init__(
         )
         set_parameter_color([self.q_up_weight, self.kv_up_weight], "memory_attn")

-    def fp8_quant_weight(self, quant_transpose=True):
+    def fp8_quant_weight(self, quant_transpose=None):
         cache_fp8_weight(self.q_up_weight, quant_transpose=quant_transpose)
         cache_fp8_weight(self.kv_up_weight, quant_transpose=quant_transpose)

@@ -1959,7 +1973,7 @@ def __init__(self, hidden_size, q_out_dim, kv_outdim, eps=1e-6) -> None:
         self.eps = eps
         set_parameter_color([self.q_down_weight], "rms_linear")

-    def fp8_quant_weight(self, quant_transpose=True):
+    def fp8_quant_weight(self, quant_transpose=None):
         cache_fp8_weight(self.q_down_weight, quant_transpose=quant_transpose)

     def forward(self, x):

@@ -2124,7 +2138,7 @@ def linear_dtype_gaurd():

         self.attn_func = scaled_dot_product_attention

-    def fp8_quant_weight(self, quant_transpose=True):
+    def fp8_quant_weight(self, quant_transpose=None):

         if DSV3_USE_ATTEN_RECOMPUTE:
             self.o_proj.fp8_quant_weight(quant_transpose=quant_transpose)

@@ -2356,7 +2370,7 @@ def __init__(
         else:
             self.mlp = DeepseekV2MLPClass(config, recompute_fwd_gate_up=True)

-    def fp8_quant_weight(self, batch_mode=False, quant_transpose=True):
+    def fp8_quant_weight(self, batch_mode=False, quant_transpose=None):
         """fp8_quant_weight"""
         if isinstance(self.mlp, DeepseekV2MoE):
             # logger.info(f"fp8 quant weight for mlp {type(self.mlp)}")

paddlenlp/transformers/fp8_utils.py (31 additions, 7 deletions)

@@ -92,6 +92,8 @@ def fused_stack_quant(expert_weight_list, transpose=False):
         w, scale = _get_fp8_weight_and_scale(expert_weight_list[0], stacked=True, transpose=True)
     elif transpose is True and hasattr(expert_weight_list[0], "fp8_weight_stacked"):
         w, scale = _get_fp8_weight_and_scale(expert_weight_list[0], stacked=True, transpose=False)
+    elif transpose is False and hasattr(expert_weight_list[0], "fp8_weight_stacked_transpose"):
+        w, scale = _get_fp8_weight_and_scale(expert_weight_list[0], stacked=True, transpose=True)
     else:
         w, scale = paddle.incubate.nn.functional.fused_stack_transpose_quant(expert_weight_list, transpose=transpose)
     return w, scale
@@ -114,6 +116,8 @@ def weight_quant(weight, transpose=False):
     else:
         if hasattr(weight, "fp8_weight"):
            return weight.fp8_weight, weight.fp8_scale
+        elif hasattr(weight, "fp8_weight_transpose"):
+            return weight.fp8_weight_transpose.T.contiguous(), weight.fp8_scale_transpose.T.contiguous()
         else:
             return paddle.incubate.nn.functional.fp8_quant_blockwise(
                 weight,
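The new `elif` branch lets `weight_quant` serve a request for the non-transposed weight from a transpose-only cache. A minimal sketch of why that works, using ordinary float tensors instead of FP8 and made-up shapes:

    import paddle

    w = paddle.randn([256, 128])             # stands in for the original weight layout
    w_t_cached = w.T.contiguous()            # layout of the cached fp8_weight_transpose
    w_recovered = w_t_cached.T.contiguous()  # transposing back recovers the original layout
    assert w_recovered.shape == w.shape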
@@ -596,23 +600,33 @@ def forward(self, x):
         return FP8LinearFunction.apply(x, self, keep_x=False)


-def cache_fp8_weight(weight, quant_transpose=True):
-    if hasattr(weight, "fp8_weight"):
+def cache_fp8_weight(weight, quant_transpose=None):
+    if hasattr(weight, "fp8_weight") or hasattr(weight, "fp8_weight_transpose"):
         return
-
-    if quant_transpose:
+    if quant_transpose is None:
         w_fp8, w_scale, w_t_fp8, w_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
             weight,
             output_scale_transpose=False,
             quant_method="128x128",
             input_transpose=True,
             return_transpose_only=False,
         )
+
         setattr(weight, "fp8_weight_transpose", w_t_fp8)
         setattr(weight, "fp8_scale_transpose", w_t_scale)
         setattr(weight, "fp8_weight", w_fp8)
         setattr(weight, "fp8_scale", w_scale)
-    else:
+    elif quant_transpose is True:
+        w_t_fp8, w_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+            weight,
+            output_scale_transpose=False,
+            quant_method="128x128",
+            input_transpose=True,
+            return_transpose_only=True,
+        )
+        setattr(weight, "fp8_weight_transpose", w_t_fp8)
+        setattr(weight, "fp8_scale_transpose", w_t_scale)
+    elif quant_transpose is False:
         w_fp8, w_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
             weight,
             output_scale_transpose=False,

@@ -622,6 +636,8 @@ def cache_fp8_weight(weight, quant_transpose=True):
         )
         setattr(weight, "fp8_weight", w_fp8)
         setattr(weight, "fp8_scale", w_scale)
+    else:
+        raise ValueError("quant_transpose must be either True, False or None.")


 class FP8KeepXLinear(paddle.nn.Layer):
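A hedged usage sketch for the reworked `cache_fp8_weight` (the layer construction and sizes below are made up, and the blockwise FP8 quantization is assumed to run on FP8-capable hardware):

    layer = FP8KeepXLinear(in_features=1024, out_features=1024)

    cache_fp8_weight(layer.weight, quant_transpose=True)   # attaches fp8_weight_transpose / fp8_scale_transpose only
    cache_fp8_weight(layer.weight, quant_transpose=None)   # no-op: the hasattr() guard returns early once either cache exists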
@@ -636,7 +652,7 @@ def __init__(self, in_features: int, out_features: int, bias_attr: bool = False)
         )
         set_parameter_color([self.weight], "attn_out_project")

-    def fp8_quant_weight(self, quant_transpose=True):
+    def fp8_quant_weight(self, quant_transpose=None):
         cache_fp8_weight(self.weight, quant_transpose=quant_transpose)

     def forward(self, x):

@@ -798,7 +814,7 @@ def __init__(
             is_bias=False,
         )

-    def fp8_quant_weight(self, quant_transpose=True):
+    def fp8_quant_weight(self, quant_transpose=None):
         cache_fp8_weight(self.w1, quant_transpose)
         cache_fp8_weight(self.w2, quant_transpose)
@@ -980,6 +996,10 @@ def bwd_dowm_input(self, expert_w2, unzipped_grad, o1, tokens_per_expert, m_indi
         bw_w2_quant = bw_w2_quant.reshape([len(expert_w2), -1, bw_w2_quant.shape[-1]])
         bw_w2_scale = bw_w2_scale.reshape([len(expert_w2), -1, bw_w2_scale.shape[-1]])

+        if hasattr(expert_w2[0], "fp8_weight_stacked_transpose") and not hasattr(expert_w2[0], "fp8_weight_stacked"):
+            bw_w2_quant = bw_w2_quant.contiguous().transpose([0, 2, 1]).contiguous()
+            bw_w2_scale = bw_w2_scale.contiguous().transpose([0, 2, 1]).contiguous()
+
         # compute gemm
         if isinstance(unzipped_grad, tuple):
             (unzipped_grad_fp8, unzipped_grad_scale) = unzipped_grad

@@ -1024,6 +1044,10 @@ def bwd_gate_up_input(self, do1, expert_w1, tokens_per_expert, m_indices=None, d
         bw_w1_quant = bw_w1_quant.reshape([len(expert_w1), -1, bw_w1_quant.shape[-1]])
         bw_w1_scale = bw_w1_scale.reshape([len(expert_w1), -1, bw_w1_scale.shape[-1]])

+        if hasattr(expert_w1[0], "fp8_weight_stacked_transpose") and not hasattr(expert_w1[0], "fp8_weight_stacked"):
+            bw_w1_quant = bw_w1_quant.contiguous().transpose([0, 2, 1]).contiguous()
+            bw_w1_scale = bw_w1_scale.contiguous().transpose([0, 2, 1]).contiguous()
+
         # quant do1
         do1_fp8, do1_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
             do1, output_scale_transpose=True, quant_method="1x128", input_transpose=False
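The two guarded blocks above compensate for a transpose-only cache in the backward pass: when only `fp8_weight_stacked_transpose` exists, the stacked weights come back with their last two axes swapped, so they are permuted back before the grouped GEMM. An illustrative sketch with assumed shapes:

    import paddle

    num_experts, k, n = 8, 1024, 4096
    bw_w_quant = paddle.randn([num_experts, n, k])                           # layout produced from the transposed cache
    bw_w_quant = bw_w_quant.contiguous().transpose([0, 2, 1]).contiguous()   # swap the last two axes back
    assert bw_w_quant.shape == [num_experts, k, n]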
