Skip to content

Commit b78453c

Browse files
fix fp8 accuracy drop issue (#1611)
Signed-off-by: Mengni Wang <mengni.wang@intel.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 72d5c4d commit b78453c

File tree

3 files changed

+73
-22
lines changed

3 files changed

+73
-22
lines changed

auto_round/compressors/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,10 @@ def __init__(
474474
disable_opt_rtn = False
475475

476476
if self.iters > 0 and is_block_wfp8(self):
477-
logger.warning("RTN is recommended since it shows even better accuracy for block-wise fp8 quantization.")
477+
logger.warning(
478+
"RTN is recommended as it achieves accuracy comparable to tuning for block-wise FP8 quantization "
479+
"while being significantly faster. You can set `--iters 0 --disable_opt_rtn` to enable RTN mode."
480+
)
478481

479482
# Important Note! This is not very robust, do NOT rely on it to do high risky thing
480483
self.is_moe_model = is_moe_model(self.model)

auto_round/data_type/fp8.py

Lines changed: 63 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727

2828
@register_dtype(("block_fp8_sym", "block_fp8", "block_fp8_e4m3"))
29-
def quant_block_fp_sym(tensor, max_scale=1.0, tensor_max=None, group_size=(128, 128), v=0, **kwargs):
29+
def quant_block_fp_sym(tensor, max_scale=1.0, tensor_max=None, group_size=(128, 128), v=0, tensor_min=None, **kwargs):
3030
"""Symmetric quantization using block float8 format.
3131
3232
Args:
@@ -51,9 +51,18 @@ def quant_block_fp_sym(tensor, max_scale=1.0, tensor_max=None, group_size=(128,
5151
if tensor_max is None:
5252
max_tensor = tensor.abs().amax(dim=(-2, -1)) * max_scale
5353
elif isinstance(tensor_max, torch.Tensor):
54-
max_tensor = tensor_max.to(tensor.device) * max_scale
54+
max_tensor = (
55+
tensor_max.to(tensor.device) * max_scale
56+
if tensor_min is None
57+
else torch.maximum(tensor_max.abs(), tensor_min.abs()).to(tensor.device) * max_scale
58+
)
5559
else:
56-
max_tensor = torch.tensor(tensor_max).to(tensor.device) * max_scale
60+
max_tensor = (
61+
torch.tensor(tensor_max).to(tensor.device) * max_scale
62+
if tensor_min is None
63+
else torch.maximum(torch.tensor(tensor_max).abs(), torch.tensor(tensor_min).abs()).to(tensor.device)
64+
* max_scale
65+
)
5766
scale = max_tensor / info.max
5867
assert len(scale.shape) == 2, f"Only support 2D group_size, but get {len(scale.shape)}"
5968
min_scaling_factor = float(1.0 / (info.max * 512.0)) ##copy from vllm
@@ -71,7 +80,7 @@ def quant_block_fp_sym(tensor, max_scale=1.0, tensor_max=None, group_size=(128,
7180

7281

7382
@register_dtype(("fp8_sym", "fp8", "fp8_e4m3"))
74-
def quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0, **kwargs):
83+
def quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0, tensor_min=None, **kwargs):
7584
"""Symmetric quantization using float8 format.
7685
7786
Allows both dynamic per-token scaling and tensor-wide quantization depending on input.
@@ -98,9 +107,18 @@ def quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0, **
98107
if tensor_max is None: ##dynamic per-token
99108
max_tensor = torch.max(torch.abs(tensor), dim=-1)[0] * max_scale
100109
elif isinstance(tensor_max, torch.Tensor):
101-
max_tensor = tensor_max.to(tensor.device) * max_scale
110+
max_tensor = (
111+
tensor_max.to(tensor.device) * max_scale
112+
if tensor_min is None
113+
else torch.maximum(tensor_max.abs(), tensor_min.abs()).to(tensor.device) * max_scale
114+
)
102115
else:
103-
max_tensor = torch.tensor(tensor_max).to(tensor.device) * max_scale
116+
max_tensor = (
117+
torch.tensor(tensor_max).to(tensor.device) * max_scale
118+
if tensor_min is None
119+
else torch.maximum(torch.tensor(tensor_max).abs(), torch.tensor(tensor_min).abs()).to(tensor.device)
120+
* max_scale
121+
)
104122
scale = max_tensor.to(torch.float32) / info.max
105123
min_scaling_factor = float(1.0 / (info.max * 512.0)) ##copy from vllm
106124
scale = torch.clip(scale, min=min_scaling_factor)
@@ -117,7 +135,7 @@ def quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0, **
117135

118136

119137
@register_dtype("fp8_e5m2")
120-
def quant_fp8_e5m2(tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0, **kwargs):
138+
def quant_fp8_e5m2(tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0, tensor_min=None, **kwargs):
121139
"""Symmetric quantization using float8 format.
122140
123141
Allows both dynamic per-token scaling and tensor-wide quantization depending on input.
@@ -140,9 +158,18 @@ def quant_fp8_e5m2(tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0, *
140158
if tensor_max is None: ##dynamic per-token
141159
max_tensor = torch.max(torch.abs(tensor), dim=-1)[0] * max_scale
142160
elif isinstance(tensor_max, torch.Tensor):
143-
max_tensor = tensor_max.to(tensor.device) * max_scale
161+
max_tensor = (
162+
tensor_max.to(tensor.device) * max_scale
163+
if tensor_min is None
164+
else torch.maximum(tensor_max.abs(), tensor_min.abs()).to(tensor.device) * max_scale
165+
)
144166
else:
145-
max_tensor = torch.tensor(tensor_max).to(tensor.device) * max_scale
167+
max_tensor = (
168+
torch.tensor(tensor_max).to(tensor.device) * max_scale
169+
if tensor_min is None
170+
else torch.maximum(torch.tensor(tensor_max).abs(), torch.tensor(tensor_min).abs()).to(tensor.device)
171+
* max_scale
172+
)
146173
scale = max_tensor.to(torch.float32) / info.max
147174
min_scaling_factor = float(1.0 / (info.max * 512.0)) ##copy from vllm
148175
scale = torch.clip(scale, min=min_scaling_factor)
@@ -225,7 +252,7 @@ def quant_fp8_e5m2_unit_scale(tensor, max_scale=1.0, tensor_max=None, group_size
225252

226253

227254
@register_dtype("fp8_gaudi3_sym")
228-
def quant_fp8_sym_gaudi3(tensor, max_scale=1.0, tensor_max=None, **kwargs):
255+
def quant_fp8_sym_gaudi3(tensor, max_scale=1.0, tensor_max=None, tensor_min=None, **kwargs):
229256
"""Symmetric quantization using float8 format.
230257
231258
Allows both dynamic per-token scaling and tensor-wide quantization depending on input.
@@ -250,9 +277,19 @@ def quant_fp8_sym_gaudi3(tensor, max_scale=1.0, tensor_max=None, **kwargs):
250277
tensor = tensor.reshape(-1, orig_shape[-1])
251278
max_tensor = torch.max(torch.abs(tensor), dim=-1)[0] * max_scale
252279
elif isinstance(tensor_max, torch.Tensor):
253-
max_tensor = tensor_max.clone().detach().to(tensor.device) * max_scale
280+
max_tensor = (
281+
tensor_max.to(tensor.device) * max_scale
282+
if tensor_min is None
283+
else torch.maximum(tensor_max.clone().detach().abs(), tensor_min.clone().detach().abs()).to(tensor.device)
284+
* max_scale
285+
)
254286
else:
255-
max_tensor = torch.tensor(tensor_max).to(tensor.device) * max_scale
287+
max_tensor = (
288+
torch.tensor(tensor_max).to(tensor.device) * max_scale
289+
if tensor_min is None
290+
else torch.maximum(torch.tensor(tensor_max).abs(), torch.tensor(tensor_min).abs()).to(tensor.device)
291+
* max_scale
292+
)
256293
scale = max_tensor.to(torch.float32) / fp8_max
257294
min_scaling_factor = float(1.0 / (fp8_max * 512.0)) ##copy from vllm
258295
scale = torch.clip(scale, min=min_scaling_factor)
@@ -271,7 +308,9 @@ def quant_fp8_sym_gaudi3(tensor, max_scale=1.0, tensor_max=None, **kwargs):
271308
if is_gaudi2():
272309

273310
@register_dtype(("fp8_sym", "fp8", "fp8_e4m3"))
274-
def quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0, **kwargs): # pylint: disable=E0102
311+
def quant_fp8_sym(
312+
tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0, tensor_min=None, **kwargs
313+
): # pylint: disable=E0102
275314
"""Symmetric quantization using float8 format.
276315
277316
Allows both dynamic per-token scaling and tensor-wide quantization depending on input.
@@ -300,9 +339,18 @@ def quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0, **
300339
if tensor_max is None: ##dynamic per-token
301340
max_tensor = torch.max(torch.abs(tensor), dim=-1)[0] * max_scale
302341
elif isinstance(tensor_max, torch.Tensor):
303-
max_tensor = tensor_max.to(tensor.device) * max_scale
342+
max_tensor = (
343+
tensor_max.to(tensor.device) * max_scale
344+
if tensor_min is None
345+
else torch.maximum(tensor_max.abs(), tensor_min.abs()).to(tensor.device) * max_scale
346+
)
304347
else:
305-
max_tensor = torch.tensor(tensor_max).to(tensor.device) * max_scale
348+
max_tensor = (
349+
torch.tensor(tensor_max).to(tensor.device) * max_scale
350+
if tensor_min is None
351+
else torch.maximum(torch.tensor(tensor_max).abs(), torch.tensor(tensor_min).abs()).to(tensor.device)
352+
* max_scale
353+
)
306354
scale = max_tensor.to(torch.float32) / info.max
307355
min_scaling_factor = float(1.0 / (info.max * 512.0)) ##copy from vllm
308356
scale = torch.clip(scale, min=min_scaling_factor)

docs/fp8_block_acc.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,17 @@ lm_eval --model vllm --model_args pretrained=model_path,add_bos_token=true,tenso
1818

1919
|Scheme |arc_challenge|arc_easy|boolq |hellaswag|lambada_openai|openbookqa|piqa |truthfulqa_mc1|winogrande|gsm8k_llama|mmlu_llama|mmlu_pro_llama|Avg. |
2020
|:---------------|:------------|:-------|:-----|:--------|:-------------|:---------|:-----|:-------------|:---------|:----------|:---------|:-------------|:-----|
21-
|BF16 |0.5341 |0.8232 |0.8428|0.5805 |0.7258 |0.3540 |0.7758|0.3709 |0.7151 |0.6111 |0.6737 |0.3778 |0.6145|
22-
|FP8_BLOCK RTN |0.5239 |0.8182 |0.8483|0.5780 |0.7235 |0.3520 |0.7840|0.3647 |0.7040 |0.6171 |0.6681 |0.3738 |0.6130|
23-
|FP8_BLOCK Tuning|0.5213 |0.8047 |0.8376|0.5834 |0.7099 |0.3480 |0.7764|0.3684 |0.7245 |0.6224 |0.6654 |0.3690 |0.6109|
21+
|BF16 |0.5358 |0.8232 |0.8459|0.5795 |0.7266 |0.3540 |0.7786|0.3721 |0.7135 |0.7369 |0.6741 |0.4334 |0.6311|
22+
|FP8_BLOCK RTN |0.5316 |0.8173 |0.8468|0.5787 |0.7256 |0.3560 |0.7824|0.3684 |0.7103 |0.7369 |0.6661 |0.4358 |0.6297|
23+
|FP8_BLOCK Tuning|0.5401 |0.8228 |0.8459|0.5786 |0.7275 |0.3480 |0.7769|0.3684 |0.7103 |0.7437 |0.6727 |0.4315 |0.6305|
2424

2525

2626
## Qwen3-8B
2727

2828
|Scheme |arc_challenge|arc_easy|boolq |hellaswag|lambada_openai|openbookqa|piqa |truthfulqa_mc1|winogrande|gsm8k |mmlu |mmlu_pro|Avg. |
2929
|:---------------|:------------|:-------|:-----|:--------|:-------------|:---------|:-----|:-------------|:---------|:-----|:-----|:-------|:-----|
30-
|BF16 |0.5589 |0.8329 |0.8661|0.5719 |0.6501 |0.3140 |0.7688|0.3647 |0.6772 |0.8726|0.7297|0.6218 |0.6524|
31-
|FP8_BLOCK RTN |0.5444 |0.8350 |0.8679|0.5704 |0.6466 |0.3060 |0.7644|0.3611 |0.6914 |0.8741|0.7268|0.6242 |0.6510|
32-
|FP8_BLOCK Tuning|0.5486 |0.8308 |0.8639|0.5691 |0.6513 |0.3080 |0.7622|0.3586 |0.6882 |0.8673|0.7272|0.6139 |0.6491|
30+
|BF16 |0.5580 |0.8342 |0.8667|0.5712 |0.6517 |0.3120 |0.7688|0.3647 |0.6788 |0.8726|0.7292|0.6214 |0.6524|
31+
|FP8_BLOCK RTN |0.5538 |0.8350 |0.8691|0.5713 |0.6468 |0.3140 |0.7661|0.3721 |0.6780 |0.8704|0.7269|0.6204 |0.6520|
32+
|FP8_BLOCK Tuning|0.5520 |0.8384 |0.8703|0.5710 |0.6519 |0.3260 |0.7671|0.3635 |0.6725 |0.8681|0.7302|0.6168 |0.6523|
3333

3434

0 commit comments

Comments (0)