16 changes: 16 additions & 0 deletions src/compressed_tensors/quantization/quant_args.py
@@ -172,6 +172,8 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
    block_structure: Optional[List[int]] = None
    dynamic: Union[DynamicType, bool] = False
    actorder: Union[ActivationOrdering, bool, None] = None
    scale_dtype: Optional[torch.dtype] = None
    zp_dtype: Optional[torch.dtype] = None
    observer: Optional[str] = Field(
        default=None,
        description=(
@@ -262,6 +264,8 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
        actorder = model.actorder
        dynamic = model.dynamic
        observer = model.observer
        zp_dtype = model.zp_dtype
        scale_dtype = model.scale_dtype

        # infer strategy
        if strategy is None:
@@ -335,9 +339,21 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
            # default to minmax for non-dynamic cases
            observer = "minmax"

        # TODO: check if fp4
        if zp_dtype is None:
            zp_dtype = model.pytorch_dtype()

        # TODO: should be fp8 / uint8 for fp4
        if scale_dtype is None:
            scale_dtype = torch.float16

        # TODO: make it obvious that fp4 does not support asym

        # write back modified values
        model.strategy = strategy
        model.observer = observer
        model.zp_dtype = zp_dtype
        model.scale_dtype = scale_dtype
        return model

    def pytorch_dtype(self) -> torch.dtype:
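A minimal sketch of how the two new fields are expected to resolve once `validate_model_after` runs (it executes as a pydantic model validator at construction). The import path and the `torch.float8_e4m3fn` literal are assumptions here, not part of the diff:

import torch

from compressed_tensors.quantization import QuantizationArgs  # assumed export path

# explicit dtypes, e.g. FP4 weights whose scales and zero points are stored as fp8
fp4_args = QuantizationArgs(
    num_bits=4,
    type="float",
    symmetric=True,
    group_size=16,
    scale_dtype=torch.float8_e4m3fn,
    zp_dtype=torch.float8_e4m3fn,
)

# omitted dtypes fall back inside validate_model_after:
#   zp_dtype    -> model.pytorch_dtype()
#   scale_dtype -> torch.float16
int8_args = QuantizationArgs(num_bits=8, type="int", symmetric=True)
assert int8_args.scale_dtype == torch.float16
assert int8_args.zp_dtype == int8_args.pytorch_dtype()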
8 changes: 8 additions & 0 deletions src/compressed_tensors/quantization/quant_scheme.py
@@ -22,6 +22,7 @@
    QuantizationArgs,
    QuantizationStrategy,
    QuantizationType,
    FP8_E4M3_DATA,
)
from pydantic import BaseModel, ConfigDict, model_validator

@@ -159,6 +160,8 @@ def is_preset_scheme(name: str) -> bool:
        symmetric=True,
        dynamic=False,
        group_size=16,
        scale_dtype=FP8_E4M3_DATA.dtype,
        zp_dtype=FP8_E4M3_DATA.dtype,
    )
)

@@ -171,6 +174,9 @@ def is_preset_scheme(name: str) -> bool:
        symmetric=True,
        dynamic=False,
        group_size=16,
        scale_dtype=FP8_E4M3_DATA.dtype,
        zp_dtype=FP8_E4M3_DATA.dtype,
    ),
    input_activations=QuantizationArgs(
        num_bits=4,
@@ -179,6 +185,8 @@ def is_preset_scheme(name: str) -> bool:
        symmetric=True,
        dynamic=DynamicType.LOCAL,
        group_size=16,
        scale_dtype=FP8_E4M3_DATA.dtype,
        zp_dtype=FP8_E4M3_DATA.dtype,
    ),
)

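To see the preset wiring end to end, a hedged usage sketch (it assumes the existing `preset_name_to_scheme` helper and that the two schemes edited above are registered as `NVFP4A16` and `NVFP4`; substitute the actual keys in `PRESET_SCHEMES` if they differ):

from compressed_tensors.quantization import preset_name_to_scheme
from compressed_tensors.quantization.quant_args import FP8_E4M3_DATA

# the weight + input-activation preset edited above
scheme = preset_name_to_scheme("NVFP4", targets=["Linear"])

# both weight and activation qparams should now advertise fp8 scale/zp dtypes
assert scheme.weights.scale_dtype == FP8_E4M3_DATA.dtype
assert scheme.weights.zp_dtype == FP8_E4M3_DATA.dtype
assert scheme.input_activations.scale_dtype == FP8_E4M3_DATA.dtype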
27 changes: 8 additions & 19 deletions src/compressed_tensors/quantization/utils/helpers.py
@@ -92,22 +92,15 @@ def calculate_qparams(
    bit_min, bit_max = calculate_range(quantization_args, device)
    bit_range = bit_max - bit_min

    if is_fp4(quantization_args=quantization_args):
        zp_dtype = FP8_E4M3_DATA.dtype
    else:
        zp_dtype = quantization_args.pytorch_dtype()

    if quantization_args.symmetric:
        max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
        scales = max_val_pos / (float(bit_range) / 2)

        if is_fp4(quantization_args=quantization_args) and global_scale is not None:
            # Conditionally scale the generated local scale by a global_scale
            scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max)
            scales = torch.clamp(scales, max=FP8_E4M3_DATA.max, min=FP8_E4M3_DATA.min)
            scales = scales.to(FP8_E4M3_DATA.dtype)

        else:
            scales = max_val_pos / (float(bit_range) / 2)
        if global_scale is not None:
            scales = global_scale * scales
        # fetch the representable range from the scale dtype (torch.dtype has no max/min attributes)
        scale_dtype_info = torch.finfo(quantization_args.scale_dtype)
        scales = torch.clamp(scales, max=scale_dtype_info.max, min=scale_dtype_info.min)
        scales = scales.to(quantization_args.scale_dtype)

        # TODO: in the case of MoEs, the global_scale may also be 0/need to be clamped
        if scales.dtype == FP8_E4M3_DATA.dtype:
@@ -123,11 +116,6 @@ def calculate_qparams(

        zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
    else:
        if is_fp4(quantization_args=quantization_args):
            raise NotImplementedError(
                "Asymmetric Quantization is not supported for FP4"
            )

        scales = (max_vals - min_vals) / float(bit_range)
        scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
        zero_points = bit_min - (min_vals / scales)
@@ -137,7 +125,8 @@ def calculate_qparams(
    # if casting to int, use round instead of truncate
    if quantization_args.type == QuantizationType.INT:
        zero_points = torch.round(zero_points)
    zero_points = zero_points.to(zp_dtype)

    zero_points = zero_points.to(quantization_args.zp_dtype)

    if scales.ndim == 0:
        scales = scales.reshape(1)
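As a standalone check on the clamp-and-cast step in `calculate_qparams`, this plain-PyTorch sketch (the helper name is hypothetical) shows why the representable range has to come from `torch.finfo` rather than from attributes on the `torch.dtype` object:

import torch


def clamp_and_cast_scales(scales: torch.Tensor, scale_dtype: torch.dtype) -> torch.Tensor:
    # torch.dtype carries no .max/.min attributes; the finite range of a floating
    # dtype comes from torch.finfo (torch.iinfo would be needed for integer dtypes)
    info = torch.finfo(scale_dtype)
    clamped = torch.clamp(scales, min=info.min, max=info.max)
    return clamped.to(scale_dtype)


scales = torch.tensor([1.0e6, 3.2, -7.5])
fp8_scales = clamp_and_cast_scales(scales, torch.float8_e4m3fn)
# 1.0e6 is clamped to the fp8 e4m3 maximum (448.0) before the cast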