From 6fdedc1c7c087e4565b49c974223fe57c4969521 Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Tue, 9 Sep 2025 15:16:18 +0000
Subject: [PATCH] add dtype support

---
 .../quantization/quant_args.py    | 16 +++++++++++
 .../quantization/quant_scheme.py  |  8 ++++++
 .../quantization/utils/helpers.py | 27 ++++++-------------
 3 files changed, 32 insertions(+), 19 deletions(-)

diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py
index d9e88353b..d6778bf67 100644
--- a/src/compressed_tensors/quantization/quant_args.py
+++ b/src/compressed_tensors/quantization/quant_args.py
@@ -172,6 +172,8 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
     block_structure: Optional[List[int]] = None
     dynamic: Union[DynamicType, bool] = False
     actorder: Union[ActivationOrdering, bool, None] = None
+    scale_dtype: Optional[torch.dtype] = None
+    zp_dtype: Optional[torch.dtype] = None
     observer: Optional[str] = Field(
         default=None,
         description=(
@@ -262,6 +264,8 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
         actorder = model.actorder
         dynamic = model.dynamic
         observer = model.observer
+        zp_dtype = model.zp_dtype
+        scale_dtype = model.scale_dtype
 
         # infer strategy
         if strategy is None:
@@ -335,9 +339,21 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
                 # default to minmax for non-dynamic cases
                 observer = "minmax"
 
+        # TODO: check if fp4
+        if zp_dtype is None:
+            zp_dtype = model.pytorch_dtype()
+
+        # TODO: should be fp8 / uint8 for fp4
+        if scale_dtype is None:
+            scale_dtype = torch.float16
+
+        # TODO: make it obvious that fp4 does not support asym
+
         # write back modified values
         model.strategy = strategy
         model.observer = observer
+        model.zp_dtype = zp_dtype
+        model.scale_dtype = scale_dtype
         return model
 
     def pytorch_dtype(self) -> torch.dtype:
diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py
index b11e3c0c0..1972f3661 100644
--- a/src/compressed_tensors/quantization/quant_scheme.py
+++ b/src/compressed_tensors/quantization/quant_scheme.py
@@ -22,6 +22,7 @@
     QuantizationArgs,
     QuantizationStrategy,
     QuantizationType,
+    FP8_E4M3_DATA,
 )
 from pydantic import BaseModel, ConfigDict, model_validator
 
@@ -159,6 +160,8 @@ def is_preset_scheme(name: str) -> bool:
         symmetric=True,
         dynamic=False,
         group_size=16,
+        scale_dtype=FP8_E4M3_DATA.dtype,
+        zp_dtype=FP8_E4M3_DATA.dtype,
     )
 )
 
@@ -171,6 +174,9 @@ def is_preset_scheme(name: str) -> bool:
         symmetric=True,
         dynamic=False,
         group_size=16,
+        scale_dtype=FP8_E4M3_DATA.dtype,
+        zp_dtype=FP8_E4M3_DATA.dtype,
+
     ),
     input_activations=QuantizationArgs(
         num_bits=4,
@@ -179,6 +185,8 @@ def is_preset_scheme(name: str) -> bool:
         symmetric=True,
         dynamic=DynamicType.LOCAL,
         group_size=16,
+        scale_dtype=FP8_E4M3_DATA.dtype,
+        zp_dtype=FP8_E4M3_DATA.dtype,
     ),
 )
 
diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py
index 73d545193..518d16a12 100644
--- a/src/compressed_tensors/quantization/utils/helpers.py
+++ b/src/compressed_tensors/quantization/utils/helpers.py
@@ -92,22 +92,15 @@ def calculate_qparams(
     bit_min, bit_max = calculate_range(quantization_args, device)
     bit_range = bit_max - bit_min
 
-    if is_fp4(quantization_args=quantization_args):
-        zp_dtype = FP8_E4M3_DATA.dtype
-    else:
-        zp_dtype = quantization_args.pytorch_dtype()
-
     if quantization_args.symmetric:
         max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
+        scales = max_val_pos / (float(bit_range) / 2)
 
-        if is_fp4(quantization_args=quantization_args) and global_scale is not None:
-            # Conditionally scale the generated local scale by a global_scale
-            scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max)
-            scales = torch.clamp(scales, max=FP8_E4M3_DATA.max, min=FP8_E4M3_DATA.min)
-            scales = scales.to(FP8_E4M3_DATA.dtype)
-
-        else:
-            scales = max_val_pos / (float(bit_range) / 2)
+        if global_scale is not None:
+            scales = global_scale * scales
+        # clamp to the representable range of scale_dtype (assumes a float dtype)
+        scales = torch.clamp(scales, max=torch.finfo(quantization_args.scale_dtype).max, min=torch.finfo(quantization_args.scale_dtype).min)
+        scales = scales.to(quantization_args.scale_dtype)
 
         # TODO: in the case of MoEs, the global_scale may also be 0/need to be clamped
         if scales.dtype == FP8_E4M3_DATA.dtype:
@@ -123,11 +116,6 @@
 
         zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
     else:
-        if is_fp4(quantization_args=quantization_args):
-            raise NotImplementedError(
-                "Asymmetric Quantization is not supported for FP4"
-            )
-
         scales = (max_vals - min_vals) / float(bit_range)
         scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
         zero_points = bit_min - (min_vals / scales)
@@ -137,7 +125,8 @@
     # if casting to int, use round instead of truncate
     if quantization_args.type == QuantizationType.INT:
         zero_points = torch.round(zero_points)
-    zero_points = zero_points.to(zp_dtype)
+
+    zero_points = zero_points.to(quantization_args.zp_dtype)
 
     if scales.ndim == 0:
         scales = scales.reshape(1)
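
A minimal usage sketch (not part of the patch) of the new scale_dtype / zp_dtype fields, assuming the package's top-level re-exports under compressed_tensors.quantization and the validator defaults added in this diff; the tensor values and the standalone clamp at the end only mirror the helper logic for illustration.

import torch

from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationStrategy,
    QuantizationType,
)

# Pin the storage dtypes for scales and zero-points explicitly, mirroring the
# NVFP4 presets above (FP8_E4M3_DATA.dtype in this repo is torch.float8_e4m3fn).
nvfp4_args = QuantizationArgs(
    num_bits=4,
    type=QuantizationType.FLOAT,
    strategy=QuantizationStrategy.TENSOR_GROUP,
    symmetric=True,
    dynamic=False,
    group_size=16,
    scale_dtype=torch.float8_e4m3fn,
    zp_dtype=torch.float8_e4m3fn,
)

# When the fields are omitted, the validator added in this patch fills in the
# defaults: zp_dtype falls back to pytorch_dtype(), scale_dtype to torch.float16.
int8_args = QuantizationArgs(num_bits=8, symmetric=True)
print(int8_args.scale_dtype, int8_args.zp_dtype)  # torch.float16 torch.int8

# calculate_qparams then clamps the computed scales to the representable range
# of scale_dtype before casting; roughly (values here are illustrative):
max_val_pos = torch.tensor([3.2, 0.5])  # per-group absolute maxima
bit_range = 12.0                        # FP4 E2M1 spans [-6, 6]
scales = max_val_pos / (bit_range / 2)
finfo = torch.finfo(nvfp4_args.scale_dtype)
scales = torch.clamp(scales, min=finfo.min, max=finfo.max).to(nvfp4_args.scale_dtype)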