Commit 4788675

wrap fp4 rounding in torch.compile (#331)
1 parent 859fd90 commit 4788675

File tree

2 files changed: +2 −2 lines changed

src/compressed_tensors/quantization/quant_args.py

Lines changed: 1 addition & 0 deletions
@@ -53,6 +53,7 @@ class FP4_E2M1_DATA(FloatArgs):
     min = -6.0
 
     @staticmethod
+    @torch.compile
     def cast_to_fp4(x):
         sign = torch.sign(x)
         x = torch.abs(x)
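
For reference, here is a minimal runnable sketch of the compiled rounding path. The full body of cast_to_fp4 is elided from this diff, so the rounding logic below is an assumption (round-to-nearest over the representable E2M1 magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6}), not the library's exact implementation:

    import torch

    @torch.compile
    def cast_to_fp4_sketch(x: torch.Tensor) -> torch.Tensor:
        # Representable FP4 E2M1 magnitudes; the sign is handled separately.
        levels = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0],
                              device=x.device)
        sign = torch.sign(x)
        x = torch.abs(x)
        # Snap each magnitude to the nearest representable level.
        idx = torch.argmin(torch.abs(x.unsqueeze(-1) - levels), dim=-1)
        return sign * levels[idx]

    # The first call triggers compilation; later calls reuse the compiled graph.
    print(cast_to_fp4_sketch(torch.tensor([0.7, -2.4, 5.1])))
    # tensor([ 0.5000, -2.0000,  6.0000])

Decorating the staticmethod with @torch.compile lets TorchInductor fuse the chain of elementwise rounding ops rather than dispatching each one eagerly.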

src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 1 addition & 2 deletions
@@ -81,7 +81,7 @@ def calculate_qparams(
         currently only applied/supported for Fp4
 
     :return: tuple of the calculated scale(s) and zero point(s). For FP4, the calculated
-        scale if of dtype FP8
+        scale is of dtype FP8
     """
     # based on the implementations for consuming quantized values,
     # 0.0 must always be representable within the quantized range
@@ -490,7 +490,6 @@ def generate_global_scale(
     attempts to use the entire FP8 dtype range while mapping a per-group max
     to the FP4 max.
     """
-    scale_dtype = scale_data.dtype
     tensor_amax = torch.abs(input_tensor.data).max().to(dtype)
     global_scale = scale_data.max * quant_data.max / tensor_amax
     return global_scale.to(dtype)
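
The removed scale_dtype assignment was unused, so this change is behavior-preserving. For the formula itself, a hedged standalone sketch, assuming scale_data.max and quant_data.max are the FP8 E4M3 and FP4 E2M1 dtype maxima (448.0 and 6.0; the function name and defaults here are illustrative, not the library's API):

    import torch

    def generate_global_scale_sketch(
        input_tensor: torch.Tensor,
        fp8_max: float = 448.0,   # assumed FP8 E4M3 max (scale_data.max)
        fp4_max: float = 6.0,     # assumed FP4 E2M1 max (quant_data.max)
        dtype: torch.dtype = torch.float32,
    ) -> torch.Tensor:
        tensor_amax = torch.abs(input_tensor).max().to(dtype)
        # Scale so per-group values can use the full FP8 range while the
        # per-group max still maps onto the FP4 max.
        return (fp8_max * fp4_max / tensor_amax).to(dtype)

    weight = torch.randn(256, 256)
    print(generate_global_scale_sketch(weight))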
