4 changes: 3 additions & 1 deletion src/compressed_tensors/quantization/lifecycle/forward.py
@@ -468,6 +468,7 @@ def _quantize(
     if global_scale is not None:
         scale = scale.to(global_scale.dtype) / global_scale
 
+    scale = scale.to(x.dtype) / torch.iinfo(torch.uint8).max
     scaled = x / scale
 
     if zero_point is not None:
@@ -501,6 +502,8 @@ def _dequantize(
     if global_scale is not None:
         scale = scale.to(global_scale.dtype) / global_scale
 
+    scale = scale.to(torch.float16) / torch.iinfo(torch.uint8).max
+
     dequant_value = x_q.to(scale.dtype)
 
     if zero_point is not None:
@@ -510,5 +513,4 @@

     if dtype is not None:
         dequant_value = dequant_value.to(dtype)
-
     return dequant_value
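Taken together, the two added lines convert the stored uint8 scale back into a fractional step size at runtime. Below is a minimal round-trip sketch, not the PR's code: `stored_scale` is assumed to be a power-of-two uint8 scale as produced by `calculate_qparams` in this PR, and the tensor shapes and values are illustrative only.

import torch

# Illustrative inputs: a half-precision block and a stored power-of-two
# uint8 scale (as produced by get_power_of_two below).
x = torch.randn(4, 32, dtype=torch.float16)
stored_scale = torch.tensor([8], dtype=torch.uint8)

# _quantize path: divide the stored scale by uint8 max (255) to recover a
# small effective step, then scale the input.
eff_scale = stored_scale.to(x.dtype) / torch.iinfo(torch.uint8).max
scaled = x / eff_scale

# _dequantize path: rebuild the same effective step and multiply back.
eff_scale_dq = stored_scale.to(torch.float16) / torch.iinfo(torch.uint8).max
x_hat = scaled.to(eff_scale_dq.dtype) * eff_scale_dq  # ~= x before FP4 rounding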
6 changes: 5 additions & 1 deletion src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -216,7 +216,11 @@ def _initialize_scale_zero_point(
     scale_dtype = scale_dtype if scale_dtype is not None else module.weight.dtype
 
     if is_fp4(quantization_args=quantization_args):
-        scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype
+        if quantization_args.group_size == 16:
+            scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype
+        else:
+            # group_size 32
+            scale_dtype = zp_dtype = torch.uint8
     else:
         # TODO: consider erroring out in the future, as if the dtype is not one
         # of these there is likely a bug
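The net effect of this branch: FP8-E4M3 scale/zero-point storage for group_size 16, uint8 storage for group_size 32 (consistent with MXFP4-style power-of-two scales over 32-element groups). A minimal sketch of the dispatch, where pick_fp4_scale_dtype is a hypothetical helper name and FP8_E4M3_DATA.dtype is assumed to resolve to torch.float8_e4m3fn:

import torch

def pick_fp4_scale_dtype(group_size: int) -> torch.dtype:
    # group_size 16 keeps FP8-E4M3 scales; group_size 32 (the new path in
    # this PR) stores uint8 power-of-two scales instead.
    if group_size == 16:
        return torch.float8_e4m3fn  # assumed value of FP8_E4M3_DATA.dtype
    return torch.uint8

assert pick_fp4_scale_dtype(16) == torch.float8_e4m3fn
assert pick_fp4_scale_dtype(32) == torch.uint8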
61 changes: 45 additions & 16 deletions src/compressed_tensors/quantization/utils/helpers.py
@@ -63,6 +63,18 @@ def is_fp4(quantization_args: QuantizationArgs):
         and quantization_args.type == QuantizationType.FLOAT
     )
 
+
+def get_power_of_two(x):
+    powers = torch.tensor([0, 1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8).to(x.device)
+
+    # Expand and compute distances
+    diff = (x.unsqueeze(-1).to(torch.int16) - powers.to(torch.int16)).abs()
+
+    # Find nearest index
+    nearest_idx = diff.argmin(dim=-1)
+
+    return powers[nearest_idx]
+
+
 def calculate_qparams(
     min_vals: Tensor,
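A quick usage sketch for the new helper (reproduced standalone so it runs on its own): every uint8 input snaps to the nearest entry of {0, 1, 2, 4, 8, 16, 32, 64, 128}; on ties argmin returns the first index, so the smaller power wins.

import torch

# get_power_of_two as added above, reproduced for a standalone check
def get_power_of_two(x):
    powers = torch.tensor([0, 1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8).to(x.device)
    diff = (x.unsqueeze(-1).to(torch.int16) - powers.to(torch.int16)).abs()
    return powers[diff.argmin(dim=-1)]

x = torch.tensor([0, 3, 5, 12, 100, 200], dtype=torch.uint8)
print(get_power_of_two(x))
# tensor([  0,   2,   4,   8, 128, 128], dtype=torch.uint8)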
@@ -93,33 +105,50 @@ def calculate_qparams(
     bit_range = bit_max - bit_min
 
     if is_fp4(quantization_args=quantization_args):
-        zp_dtype = FP8_E4M3_DATA.dtype
+        if quantization_args.group_size == 16:
+            zp_dtype = FP8_E4M3_DATA.dtype
+        else:
+            # group_size 32
+            zp_dtype = torch.uint8
     else:
         zp_dtype = quantization_args.pytorch_dtype()
 
     if quantization_args.symmetric:
         max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
 
-        if is_fp4(quantization_args=quantization_args) and global_scale is not None:
-            # Conditionally scale the generated local scale by a global_scale
-            scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max)
-            scales = torch.clamp(scales, max=FP8_E4M3_DATA.max, min=FP8_E4M3_DATA.min)
-            scales = scales.to(FP8_E4M3_DATA.dtype)
+        if is_fp4(quantization_args=quantization_args):
+            if global_scale is not None:
+                # Conditionally scale the generated local scale by a global_scale
+                scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max)
+                scales = torch.clamp(
+                    scales, max=FP8_E4M3_DATA.max, min=FP8_E4M3_DATA.min
+                )
+                scales = scales.to(FP8_E4M3_DATA.dtype)
+            else:
+                scales = torch.iinfo(torch.uint8).max * (max_val_pos)  # / FP4_E2M1_DATA.max)
+                scales = torch.clamp(
+                    scales,
+                    max=torch.iinfo(torch.uint8).max,
+                    min=torch.iinfo(torch.uint8).min,
+                )
+                scales = scales.to(torch.uint8)
+                scales = get_power_of_two(scales)
+
         else:
             scales = max_val_pos / (float(bit_range) / 2)
 
         # TODO: in the case of MoEs, the global_scale may also be 0/need to be clamped
-        if scales.dtype == FP8_E4M3_DATA.dtype:
-            # torch.clamp not supported for FP8
-            # use the next largest fp8 value from 0
-            scales = torch.where(
-                scales == 0,
-                torch.tensor(0.125, dtype=FP8_E4M3_DATA.dtype, device=device),
-                scales,
-            )
-        else:
-            scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
+        # if scales.dtype == FP8_E4M3_DATA.dtype:
+        #     # torch.clamp not supported for FP8
+        #     # use the next largest fp8 value from 0
+        #     scales = torch.where(
+        #         scales == 0,
+        #         torch.tensor(0.125, dtype=FP8_E4M3_DATA.dtype, device=device),
+        #         scales,
+        #     )
+        # else:
+        #     scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
 
         zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
     else:
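Putting the new symmetric branch together: per-group max magnitudes are multiplied by 255, clamped into uint8 range, truncated to uint8, and snapped to a power of two. A minimal sketch under those assumptions: fp4_uint8_scales is a hypothetical wrapper, the commented-out / FP4_E2M1_DATA.max division is omitted here as in the PR, and get_power_of_two is the helper sketched above.

import torch

def fp4_uint8_scales(max_val_pos: torch.Tensor) -> torch.Tensor:
    # Mirror of the new else-branch: scale magnitudes into uint8 range...
    scales = torch.iinfo(torch.uint8).max * max_val_pos
    scales = torch.clamp(
        scales,
        max=torch.iinfo(torch.uint8).max,
        min=torch.iinfo(torch.uint8).min,
    )
    scales = scales.to(torch.uint8)  # truncates toward zero
    # ...then snap to the nearest power of two
    return get_power_of_two(scales)

# Per-group absolute maxima, e.g. torch.max(|min_vals|, |max_vals|)
max_val_pos = torch.tensor([0.01, 0.05, 0.4, 2.0])
print(fp4_uint8_scales(max_val_pos))
# tensor([  2,   8, 128, 128], dtype=torch.uint8)

Note that with the old zero-clamp block commented out, an all-zero group yields a scale of 0, which the division in _quantize would then hit; the surviving TODO suggests this clamping question is still open in this draft.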