Merged
51 changes: 36 additions & 15 deletions auto_round/data_type/fp8.py
@@ -26,7 +26,7 @@


@register_dtype(("block_fp8_sym", "block_fp8", "block_fp8_e4m3"))
-def quant_block_fp_sym(tensor, max_scale=1.0, tensor_max=None, group_size=(128, 128), v=0, **kwargs):
+def quant_block_fp_sym(tensor, max_scale=1.0, tensor_max=None, group_size=(128, 128), v=0, tensor_min=None, **kwargs):
"""Symmetric quantization using block float8 format.

Args:
@@ -51,9 +51,12 @@ def quant_block_fp_sym(tensor, max_scale=1.0, tensor_max=None, group_size=(128,
if tensor_max is None:
max_tensor = tensor.abs().amax(dim=(-2, -1)) * max_scale
elif isinstance(tensor_max, torch.Tensor):
-        max_tensor = tensor_max.to(tensor.device) * max_scale
+        assert tensor_min is not None and isinstance(tensor_min, torch.Tensor)
+        max_tensor = torch.maximum(tensor_max.abs(), tensor_min.abs()).to(tensor.device) * max_scale
    else:
-        max_tensor = torch.tensor(tensor_max).to(tensor.device) * max_scale
+        max_tensor = (
+            torch.maximum(torch.tensor(tensor_max).abs(), torch.tensor(tensor_min).abs()).to(tensor.device) * max_scale
+        )
scale = max_tensor / info.max
assert len(scale.shape) == 2, f"Only support 2D group_size, but get {len(scale.shape)}"
min_scaling_factor = float(1.0 / (info.max * 512.0)) ##copy from vllm
@@ -71,7 +74,7 @@ def quant_block_fp_sym(tensor, max_scale=1.0, tensor_max=None, group_size=(128,


@register_dtype(("fp8_sym", "fp8", "fp8_e4m3"))
-def quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0, **kwargs):
+def quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0, tensor_min=None, **kwargs):
"""Symmetric quantization using float8 format.

Allows both dynamic per-token scaling and tensor-wide quantization depending on input.
@@ -98,9 +101,12 @@ def quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0, **
if tensor_max is None: ##dynamic per-token
max_tensor = torch.max(torch.abs(tensor), dim=-1)[0] * max_scale
elif isinstance(tensor_max, torch.Tensor):
-        max_tensor = tensor_max.to(tensor.device) * max_scale
+        assert tensor_min is not None and isinstance(tensor_min, torch.Tensor)
+        max_tensor = torch.maximum(tensor_max.abs(), tensor_min.abs()).to(tensor.device) * max_scale
    else:
-        max_tensor = torch.tensor(tensor_max).to(tensor.device) * max_scale
+        max_tensor = (
+            torch.maximum(torch.tensor(tensor_max).abs(), torch.tensor(tensor_min).abs()).to(tensor.device) * max_scale
+        )
scale = max_tensor.to(torch.float32) / info.max
min_scaling_factor = float(1.0 / (info.max * 512.0)) ##copy from vllm
scale = torch.clip(scale, min=min_scaling_factor)
@@ -117,7 +123,7 @@ def quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0, **


@register_dtype("fp8_e5m2")
-def quant_fp8_e5m2(tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0, **kwargs):
+def quant_fp8_e5m2(tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0, tensor_min=None, **kwargs):
"""Symmetric quantization using float8 format.

Allows both dynamic per-token scaling and tensor-wide quantization depending on input.
@@ -140,9 +146,12 @@ def quant_fp8_e5m2(tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0, *
if tensor_max is None: ##dynamic per-token
max_tensor = torch.max(torch.abs(tensor), dim=-1)[0] * max_scale
elif isinstance(tensor_max, torch.Tensor):
-        max_tensor = tensor_max.to(tensor.device) * max_scale
+        assert tensor_min is not None and isinstance(tensor_min, torch.Tensor)
+        max_tensor = torch.maximum(tensor_max.abs(), tensor_min.abs()).to(tensor.device) * max_scale
    else:
-        max_tensor = torch.tensor(tensor_max).to(tensor.device) * max_scale
+        max_tensor = (
+            torch.maximum(torch.tensor(tensor_max).abs(), torch.tensor(tensor_min).abs()).to(tensor.device) * max_scale
+        )
scale = max_tensor.to(torch.float32) / info.max
min_scaling_factor = float(1.0 / (info.max * 512.0)) ##copy from vllm
scale = torch.clip(scale, min=min_scaling_factor)
@@ -225,7 +234,7 @@ def quant_fp8_e5m2_unit_scale(tensor, max_scale=1.0, tensor_max=None, group_size


@register_dtype("fp8_gaudi3_sym")
-def quant_fp8_sym_gaudi3(tensor, max_scale=1.0, tensor_max=None, **kwargs):
+def quant_fp8_sym_gaudi3(tensor, max_scale=1.0, tensor_max=None, tensor_min=None, **kwargs):
"""Symmetric quantization using float8 format.

Allows both dynamic per-token scaling and tensor-wide quantization depending on input.
@@ -250,9 +259,15 @@ def quant_fp8_sym_gaudi3(tensor, max_scale=1.0, tensor_max=None, **kwargs):
tensor = tensor.reshape(-1, orig_shape[-1])
max_tensor = torch.max(torch.abs(tensor), dim=-1)[0] * max_scale
elif isinstance(tensor_max, torch.Tensor):
-        max_tensor = tensor_max.clone().detach().to(tensor.device) * max_scale
+        assert tensor_min is not None and isinstance(tensor_min, torch.Tensor)
+        max_tensor = (
+            torch.maximum(tensor_max.clone().detach().abs(), tensor_min.clone().detach().abs()).to(tensor.device)
+            * max_scale
+        )
    else:
-        max_tensor = torch.tensor(tensor_max).to(tensor.device) * max_scale
+        max_tensor = (
+            torch.maximum(torch.tensor(tensor_max).abs(), torch.tensor(tensor_min).abs()).to(tensor.device) * max_scale
+        )
scale = max_tensor.to(torch.float32) / fp8_max
min_scaling_factor = float(1.0 / (fp8_max * 512.0)) ##copy from vllm
scale = torch.clip(scale, min=min_scaling_factor)
@@ -271,7 +286,9 @@ def quant_fp8_sym_gaudi3(tensor, max_scale=1.0, tensor_max=None, **kwargs):
if is_gaudi2():

@register_dtype(("fp8_sym", "fp8", "fp8_e4m3"))
-    def quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0, **kwargs): # pylint: disable=E0102
+    def quant_fp8_sym(
+        tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0, tensor_min=None, **kwargs
+    ): # pylint: disable=E0102
"""Symmetric quantization using float8 format.

Allows both dynamic per-token scaling and tensor-wide quantization depending on input.
@@ -300,9 +317,13 @@ def quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0, **
if tensor_max is None: ##dynamic per-token
max_tensor = torch.max(torch.abs(tensor), dim=-1)[0] * max_scale
elif isinstance(tensor_max, torch.Tensor):
-            max_tensor = tensor_max.to(tensor.device) * max_scale
+            assert tensor_min is not None and isinstance(tensor_min, torch.Tensor)
+            max_tensor = torch.maximum(tensor_max.abs(), tensor_min.abs()).to(tensor.device) * max_scale
        else:
-            max_tensor = torch.tensor(tensor_max).to(tensor.device) * max_scale
+            max_tensor = (
+                torch.maximum(torch.tensor(tensor_max).abs(), torch.tensor(tensor_min).abs()).to(tensor.device)
+                * max_scale
+            )
scale = max_tensor.to(torch.float32) / info.max
min_scaling_factor = float(1.0 / (info.max * 512.0)) ##copy from vllm
scale = torch.clip(scale, min=min_scaling_factor)
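Taken together, every quantizer in this file now accepts an optional `tensor_min` alongside `tensor_max`, and the symmetric scale is driven by whichever side of the observed range has the larger magnitude instead of by `tensor_max` alone. Below is a minimal, self-contained sketch of that scale computation; the helper name `symmetric_fp8_scale` and the standalone form are illustrative assumptions, not code from the repository.

```python
import torch


def symmetric_fp8_scale(tensor_max: torch.Tensor, tensor_min: torch.Tensor, max_scale: float = 1.0) -> torch.Tensor:
    """Hypothetical helper mirroring the recurring change in this diff.

    The symmetric scale comes from max(|tensor_max|, |tensor_min|), so tensors
    whose negative range dominates are no longer clipped by a max-only scale.
    """
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3
    amax = torch.maximum(tensor_max.abs(), tensor_min.abs()) * max_scale
    scale = amax.to(torch.float32) / fp8_max
    min_scaling_factor = 1.0 / (fp8_max * 512.0)  # same floor the file copies from vLLM
    return torch.clip(scale, min=min_scaling_factor)


# Per-token example: a tensor shifted negative, where |min| > |max|,
# still quantizes without clipping its negative side.
w = torch.randn(4, 16) - 2.0
scale = symmetric_fp8_scale(w.amax(dim=-1), w.amin(dim=-1))
q = torch.clamp(w / scale.unsqueeze(-1), -448.0, 448.0).to(torch.float8_e4m3fn)
```

Because the FP8 formats targeted here are symmetric (one scale per token or block, no zero point), folding `tensor_min` into the amax is what the new argument enables; the rest of each function is unchanged.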