Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion auto_round/compressors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ def __init__(
disable_opt_rtn = False

if self.iters > 0 and is_block_wfp8(self):
logger.warning("RTN is recommended since it shows even better accuracy for block-wise fp8 quantization.")
logger.warning("RTN is recommended since it shows comparable accuracy to tuning for block-wise fp8 quantization.")

# Important note: this check is not very robust; do NOT rely on it for anything high-risk.
self.is_moe_model = is_moe_model(self.model)
Expand Down
78 changes: 63 additions & 15 deletions auto_round/data_type/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@


@register_dtype(("block_fp8_sym", "block_fp8", "block_fp8_e4m3"))
def quant_block_fp_sym(tensor, max_scale=1.0, tensor_max=None, group_size=(128, 128), v=0, **kwargs):
def quant_block_fp_sym(tensor, max_scale=1.0, tensor_max=None, group_size=(128, 128), v=0, tensor_min=None, **kwargs):
"""Symmetric quantization using block float8 format.

Args:
Expand All @@ -51,9 +51,18 @@ def quant_block_fp_sym(tensor, max_scale=1.0, tensor_max=None, group_size=(128,
if tensor_max is None:
max_tensor = tensor.abs().amax(dim=(-2, -1)) * max_scale
elif isinstance(tensor_max, torch.Tensor):
max_tensor = tensor_max.to(tensor.device) * max_scale
max_tensor = (
tensor_max.to(tensor.device) * max_scale
if tensor_min is None
else torch.maximum(tensor_max.abs(), tensor_min.abs()).to(tensor.device) * max_scale
)
else:
max_tensor = torch.tensor(tensor_max).to(tensor.device) * max_scale
max_tensor = (
torch.tensor(tensor_max).to(tensor.device) * max_scale
if tensor_min is None
else torch.maximum(torch.tensor(tensor_max).abs(), torch.tensor(tensor_min).abs()).to(tensor.device)
* max_scale
)
scale = max_tensor / info.max
assert len(scale.shape) == 2, f"Only support 2D group_size, but get {len(scale.shape)}"
min_scaling_factor = float(1.0 / (info.max * 512.0)) ##copy from vllm
Expand All @@ -71,7 +80,7 @@ def quant_block_fp_sym(tensor, max_scale=1.0, tensor_max=None, group_size=(128,


@register_dtype(("fp8_sym", "fp8", "fp8_e4m3"))
def quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0, **kwargs):
def quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0, tensor_min=None, **kwargs):
"""Symmetric quantization using float8 format.

Allows both dynamic per-token scaling and tensor-wide quantization depending on input.
Expand All @@ -98,9 +107,18 @@ def quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0, **
if tensor_max is None: ##dynamic per-token
max_tensor = torch.max(torch.abs(tensor), dim=-1)[0] * max_scale
elif isinstance(tensor_max, torch.Tensor):
max_tensor = tensor_max.to(tensor.device) * max_scale
max_tensor = (
tensor_max.to(tensor.device) * max_scale
if tensor_min is None
else torch.maximum(tensor_max.abs(), tensor_min.abs()).to(tensor.device) * max_scale
)
else:
max_tensor = torch.tensor(tensor_max).to(tensor.device) * max_scale
max_tensor = (
torch.tensor(tensor_max).to(tensor.device) * max_scale
if tensor_min is None
else torch.maximum(torch.tensor(tensor_max).abs(), torch.tensor(tensor_min).abs()).to(tensor.device)
* max_scale
)
scale = max_tensor.to(torch.float32) / info.max
min_scaling_factor = float(1.0 / (info.max * 512.0)) ##copy from vllm
scale = torch.clip(scale, min=min_scaling_factor)
Expand All @@ -117,7 +135,7 @@ def quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0, **


@register_dtype("fp8_e5m2")
def quant_fp8_e5m2(tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0, **kwargs):
def quant_fp8_e5m2(tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0, tensor_min=None, **kwargs):
"""Symmetric quantization using float8 format.

Allows both dynamic per-token scaling and tensor-wide quantization depending on input.
Expand All @@ -140,9 +158,18 @@ def quant_fp8_e5m2(tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0, *
if tensor_max is None: ##dynamic per-token
max_tensor = torch.max(torch.abs(tensor), dim=-1)[0] * max_scale
elif isinstance(tensor_max, torch.Tensor):
max_tensor = tensor_max.to(tensor.device) * max_scale
max_tensor = (
tensor_max.to(tensor.device) * max_scale
if tensor_min is None
else torch.maximum(tensor_max.abs(), tensor_min.abs()).to(tensor.device) * max_scale
)
else:
max_tensor = torch.tensor(tensor_max).to(tensor.device) * max_scale
max_tensor = (
torch.tensor(tensor_max).to(tensor.device) * max_scale
if tensor_min is None
else torch.maximum(torch.tensor(tensor_max).abs(), torch.tensor(tensor_min).abs()).to(tensor.device)
* max_scale
)
scale = max_tensor.to(torch.float32) / info.max
min_scaling_factor = float(1.0 / (info.max * 512.0)) ##copy from vllm
scale = torch.clip(scale, min=min_scaling_factor)
Expand Down Expand Up @@ -225,7 +252,7 @@ def quant_fp8_e5m2_unit_scale(tensor, max_scale=1.0, tensor_max=None, group_size


@register_dtype("fp8_gaudi3_sym")
def quant_fp8_sym_gaudi3(tensor, max_scale=1.0, tensor_max=None, **kwargs):
def quant_fp8_sym_gaudi3(tensor, max_scale=1.0, tensor_max=None, tensor_min=None, **kwargs):
"""Symmetric quantization using float8 format.

Allows both dynamic per-token scaling and tensor-wide quantization depending on input.
Expand All @@ -250,9 +277,19 @@ def quant_fp8_sym_gaudi3(tensor, max_scale=1.0, tensor_max=None, **kwargs):
tensor = tensor.reshape(-1, orig_shape[-1])
max_tensor = torch.max(torch.abs(tensor), dim=-1)[0] * max_scale
elif isinstance(tensor_max, torch.Tensor):
max_tensor = tensor_max.clone().detach().to(tensor.device) * max_scale
max_tensor = (
tensor_max.to(tensor.device) * max_scale
if tensor_min is None
else torch.maximum(tensor_max.clone().detach().abs(), tensor_min.clone().detach().abs()).to(tensor.device)
* max_scale
)
else:
max_tensor = torch.tensor(tensor_max).to(tensor.device) * max_scale
max_tensor = (
torch.tensor(tensor_max).to(tensor.device) * max_scale
if tensor_min is None
else torch.maximum(torch.tensor(tensor_max).abs(), torch.tensor(tensor_min).abs()).to(tensor.device)
* max_scale
)
scale = max_tensor.to(torch.float32) / fp8_max
min_scaling_factor = float(1.0 / (fp8_max * 512.0)) ##copy from vllm
scale = torch.clip(scale, min=min_scaling_factor)
Expand All @@ -271,7 +308,9 @@ def quant_fp8_sym_gaudi3(tensor, max_scale=1.0, tensor_max=None, **kwargs):
if is_gaudi2():

@register_dtype(("fp8_sym", "fp8", "fp8_e4m3"))
def quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0, **kwargs): # pylint: disable=E0102
def quant_fp8_sym(
tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0, tensor_min=None, **kwargs
): # pylint: disable=E0102
"""Symmetric quantization using float8 format.

Allows both dynamic per-token scaling and tensor-wide quantization depending on input.
Expand Down Expand Up @@ -300,9 +339,18 @@ def quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0, **
if tensor_max is None: ##dynamic per-token
max_tensor = torch.max(torch.abs(tensor), dim=-1)[0] * max_scale
elif isinstance(tensor_max, torch.Tensor):
max_tensor = tensor_max.to(tensor.device) * max_scale
max_tensor = (
tensor_max.to(tensor.device) * max_scale
if tensor_min is None
else torch.maximum(tensor_max.abs(), tensor_min.abs()).to(tensor.device) * max_scale
)
else:
max_tensor = torch.tensor(tensor_max).to(tensor.device) * max_scale
max_tensor = (
torch.tensor(tensor_max).to(tensor.device) * max_scale
if tensor_min is None
else torch.maximum(torch.tensor(tensor_max).abs(), torch.tensor(tensor_min).abs()).to(tensor.device)
* max_scale
)
scale = max_tensor.to(torch.float32) / info.max
min_scaling_factor = float(1.0 / (info.max * 512.0)) ##copy from vllm
scale = torch.clip(scale, min=min_scaling_factor)
Expand Down
12 changes: 6 additions & 6 deletions docs/fp8_block_acc.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,17 @@ lm_eval --model vllm --model_args pretrained=model_path,add_bos_token=true,tenso

|Scheme |arc_challenge|arc_easy|boolq |hellaswag|lambada_openai|openbookqa|piqa |truthfulqa_mc1|winogrande|gsm8k_llama|mmlu_llama|mmlu_pro_llama|Avg. |
|:---------------|:------------|:-------|:-----|:--------|:-------------|:---------|:-----|:-------------|:---------|:----------|:---------|:-------------|:-----|
|BF16 |0.5341 |0.8232 |0.8428|0.5805 |0.7258 |0.3540 |0.7758|0.3709 |0.7151 |0.6111 |0.6737 |0.3778 |0.6145|
|FP8_BLOCK RTN |0.5239 |0.8182 |0.8483|0.5780 |0.7235 |0.3520 |0.7840|0.3647 |0.7040 |0.6171 |0.6681 |0.3738 |0.6130|
|FP8_BLOCK Tuning|0.5213 |0.8047 |0.8376|0.5834 |0.7099 |0.3480 |0.7764|0.3684 |0.7245 |0.6224 |0.6654 |0.3690 |0.6109|
|BF16 |0.5358 |0.8232 |0.8459|0.5795 |0.7266 |0.3540 |0.7786|0.3721 |0.7135 |0.7369 |0.6741 |0.4334 |0.6311|
|FP8_BLOCK RTN |0.5316 |0.8173 |0.8468|0.5787 |0.7256 |0.3560 |0.7824|0.3684 |0.7103 |0.7369 |0.6661 |0.4358 |0.6297|
|FP8_BLOCK Tuning|0.5401 |0.8228 |0.8459|0.5786 |0.7275 |0.3480 |0.7769|0.3684 |0.7103 |0.7437 |0.6727 |0.4315 |0.6305|


## Qwen3-8B

|Scheme |arc_challenge|arc_easy|boolq |hellaswag|lambada_openai|openbookqa|piqa |truthfulqa_mc1|winogrande|gsm8k |mmlu |mmlu_pro|Avg. |
|:---------------|:------------|:-------|:-----|:--------|:-------------|:---------|:-----|:-------------|:---------|:-----|:-----|:-------|:-----|
|BF16 |0.5589 |0.8329 |0.8661|0.5719 |0.6501 |0.3140 |0.7688|0.3647 |0.6772 |0.8726|0.7297|0.6218 |0.6524|
|FP8_BLOCK RTN |0.5444 |0.8350 |0.8679|0.5704 |0.6466 |0.3060 |0.7644|0.3611 |0.6914 |0.8741|0.7268|0.6242 |0.6510|
|FP8_BLOCK Tuning|0.5486 |0.8308 |0.8639|0.5691 |0.6513 |0.3080 |0.7622|0.3586 |0.6882 |0.8673|0.7272|0.6139 |0.6491|
|BF16 |0.5580 |0.8342 |0.8667|0.5712 |0.6517 |0.3120 |0.7688|0.3647 |0.6788 |0.8726|0.7292|0.6214 |0.6524|
|FP8_BLOCK RTN |0.5538 |0.8350 |0.8691|0.5713 |0.6468 |0.3140 |0.7661|0.3721 |0.6780 |0.8704|0.7269|0.6204 |0.6520|
|FP8_BLOCK Tuning|0.5520 |0.8384 |0.8703|0.5710 |0.6519 |0.3260 |0.7671|0.3635 |0.6725 |0.8681|0.7302|0.6168 |0.6523|


Loading