diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py
index 58a16dfba..9d1970156 100644
--- a/src/compressed_tensors/quantization/lifecycle/forward.py
+++ b/src/compressed_tensors/quantization/lifecycle/forward.py
@@ -468,6 +468,7 @@ def _quantize(
     if global_scale is not None:
         scale = scale.to(global_scale.dtype) / global_scale
+    scale = scale.to(x.dtype) / torch.iinfo(torch.uint8).max

     scaled = x / scale

     if zero_point is not None:
@@ -501,6 +502,8 @@ def _dequantize(

     if global_scale is not None:
         scale = scale.to(global_scale.dtype) / global_scale
+    scale = scale.to(torch.float16) / torch.iinfo(torch.uint8).max
+
     dequant_value = x_q.to(scale.dtype)

     if zero_point is not None:
@@ -510,5 +513,4 @@

     if dtype is not None:
         dequant_value = dequant_value.to(dtype)
-
     return dequant_value
diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py
index c9430e9ec..2fd0a6f8f 100644
--- a/src/compressed_tensors/quantization/lifecycle/initialize.py
+++ b/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -216,7 +216,11 @@ def _initialize_scale_zero_point(
     scale_dtype = scale_dtype if scale_dtype is not None else module.weight.dtype

     if is_fp4(quantization_args=quantization_args):
-        scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype
+        if quantization_args.group_size == 16:
+            scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype
+        else:
+            # group_size 32
+            scale_dtype = zp_dtype = torch.uint8
     else:
         # TODO: consider erroring out in the future as if the dtype if not one of these,
         # there is likely bug
diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py
index 73d545193..bbdaadfb5 100644
--- a/src/compressed_tensors/quantization/utils/helpers.py
+++ b/src/compressed_tensors/quantization/utils/helpers.py
@@ -63,6 +63,18 @@ def is_fp4(quantization_args: QuantizationArgs):
         and quantization_args.type == QuantizationType.FLOAT
     )
+def get_power_of_two(x):
+    powers = torch.tensor([0, 1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8).to(x.device)
+
+    # Expand and compute distances
+    diff = (x.unsqueeze(-1).to(torch.int16) - powers.to(torch.int16)).abs()
+
+    # Find nearest index
+    nearest_idx = diff.argmin(dim=-1)
+
+    return powers[nearest_idx]
+
+

 def calculate_qparams(
@@ -93,33 +105,50 @@ def calculate_qparams(
     bit_range = bit_max - bit_min

     if is_fp4(quantization_args=quantization_args):
-        zp_dtype = FP8_E4M3_DATA.dtype
+        if quantization_args.group_size == 16:
+            zp_dtype = FP8_E4M3_DATA.dtype
+        else:
+            # group_size 32
+            zp_dtype = torch.uint8
     else:
         zp_dtype = quantization_args.pytorch_dtype()

     if quantization_args.symmetric:
         max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))

-        if is_fp4(quantization_args=quantization_args) and global_scale is not None:
-            # Conditionally scale the generated local scale by a global_scale
-            scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max)
-            scales = torch.clamp(scales, max=FP8_E4M3_DATA.max, min=FP8_E4M3_DATA.min)
-            scales = scales.to(FP8_E4M3_DATA.dtype)
+        if is_fp4(quantization_args=quantization_args):
+            if global_scale is not None:
+                # Conditionally scale the generated local scale by a global_scale
+                scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max)
+                scales = torch.clamp(
+                    scales, max=FP8_E4M3_DATA.max, min=FP8_E4M3_DATA.min
+                )
+                scales = scales.to(FP8_E4M3_DATA.dtype)
+            else:
+
+                scales = torch.iinfo(torch.uint8).max * (max_val_pos)  # / FP4_E2M1_DATA.max)
+                scales = torch.clamp(
+                    scales,
+                    max=torch.iinfo(torch.uint8).max,
+                    min=torch.iinfo(torch.uint8).min,
+                )
+                scales = scales.to(torch.uint8)
+                scales = get_power_of_two(scales)
         else:
             scales = max_val_pos / (float(bit_range) / 2)

         # TODO: in the case of MoEs, the global_scale may also be 0/need to be clamped
-        if scales.dtype == FP8_E4M3_DATA.dtype:
-            # torch.clamp not supported for FP8
-            # use the next largest fp8 value from 0
-            scales = torch.where(
-                scales == 0,
-                torch.tensor(0.125, dtype=FP8_E4M3_DATA.dtype, device=device),
-                scales,
-            )
-        else:
-            scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
+        # if scales.dtype == FP8_E4M3_DATA.dtype:
+        #     # torch.clamp not supported for FP8
+        #     # use the next largest fp8 value from 0
+        #     scales = torch.where(
+        #         scales == 0,
+        #         torch.tensor(0.125, dtype=FP8_E4M3_DATA.dtype, device=device),
+        #         scales,
+        #     )
+        # else:
+        #     scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)

         zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
     else:
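
Note: a minimal standalone sketch of the nearest power-of-two rounding introduced in helpers.py above. It assumes the patch is applied so that get_power_of_two is importable; the example values are illustrative only and not taken from the patch.

    import torch
    from compressed_tensors.quantization.utils.helpers import get_power_of_two

    # uint8 scales produced by calculate_qparams are snapped to the nearest
    # representable power of two (or zero) before being stored
    scales = torch.tensor([5, 20, 100, 250], dtype=torch.uint8)
    print(get_power_of_two(scales))  # expected: tensor([4, 16, 128, 128], dtype=torch.uint8)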