2 files changed: +2 −2

src/compressed_tensors/quantization

@@ -53,6 +53,7 @@ class FP4_E2M1_DATA(FloatArgs):
     min = -6.0
 
     @staticmethod
+    @torch.compile
     def cast_to_fp4(x):
         sign = torch.sign(x)
         x = torch.abs(x)
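For context, a minimal sketch of what a `torch.compile`-decorated E2M1 cast can look like. The value set {0, 0.5, 1, 1.5, 2, 3, 4, 6} follows from the class's `min = -6.0` bound; the bucketize-based rounding and the function name are illustrative, not the repository's code:

```python
import torch

# Representable FP4 E2M1 magnitudes and the midpoints between them.
_FP4_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
_MIDPOINTS = (_FP4_VALUES[:-1] + _FP4_VALUES[1:]) / 2  # rounding thresholds

@torch.compile  # the decorator added in this diff; fuses the elementwise ops
def cast_to_fp4_sketch(x: torch.Tensor) -> torch.Tensor:
    sign = torch.sign(x)
    mag = torch.abs(x).clamp(max=6.0)
    # Index of the nearest representable magnitude (ties round down).
    idx = torch.bucketize(mag, _MIDPOINTS.to(x.device))
    return sign * _FP4_VALUES.to(x.device)[idx]

print(cast_to_fp4_sketch(torch.tensor([-0.3, 0.9, 2.6, 7.0])))
# tensor([-0.5000,  1.0000,  3.0000,  6.0000])
```

Decorating the staticmethod with `torch.compile` lets PyTorch fuse the chain of elementwise rounding operations into a single kernel instead of launching one per op, which is the point of the one-line change above.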
@@ -81,7 +81,7 @@ def calculate_qparams(
         currently only applied/supported for Fp4
 
     :return: tuple of the calculated scale(s) and zero point(s). For FP4, the calculated
-        scale if of dtype FP8
+        scale is of dtype FP8
     """
     # based on the implementations for consuming quantized values,
     # 0.0 must always be representable within the quantized range
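To illustrate the corrected docstring, here is a hedged sketch of producing an FP4 scale stored in FP8, assuming a simple per-group absmax scheme. `fp4_scale_sketch` and its input are hypothetical and do not match the library's `calculate_qparams` signature:

```python
import torch

FP4_MAX = 6.0  # largest E2M1 magnitude

def fp4_scale_sketch(group_absmax: torch.Tensor) -> torch.Tensor:
    # Map each group's max magnitude onto the FP4 max, then store the
    # scale itself in FP8 (E4M3), as the docstring above describes.
    scale = group_absmax / FP4_MAX
    return scale.to(torch.float8_e4m3fn)

print(fp4_scale_sketch(torch.tensor([0.12, 3.4, 9.8])).dtype)
# torch.float8_e4m3fn
```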
@@ -490,7 +490,6 @@ def generate_global_scale(
     attempts to use the entire FP8 dtype range while mapping a per-group max
     to the FP4 max.
     """
-    scale_dtype = scale_data.dtype
     tensor_amax = torch.abs(input_tensor.data).max().to(dtype)
     global_scale = scale_data.max * quant_data.max / tensor_amax
     return global_scale.to(dtype)
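The surviving three lines of `generate_global_scale` are self-contained enough to sketch. Assuming `scale_data.max` is the FP8 E4M3 maximum (448.0) and `quant_data.max` the FP4 E2M1 maximum (6.0), the formula reduces to:

```python
import torch

# Assumed constants: FP8 E4M3 max and FP4 E2M1 max.
FP8_MAX, FP4_MAX = 448.0, 6.0

def generate_global_scale_sketch(input_tensor: torch.Tensor,
                                 dtype: torch.dtype = torch.float32) -> torch.Tensor:
    tensor_amax = torch.abs(input_tensor).max().to(dtype)
    # Chosen so per-group scales stretch toward the full FP8 range while
    # the per-group max still maps onto the FP4 max, per the docstring.
    global_scale = FP8_MAX * FP4_MAX / tensor_amax
    return global_scale.to(dtype)

print(generate_global_scale_sketch(torch.randn(128, 64)))
```

The deleted `scale_dtype = scale_data.dtype` assignment was never read afterward, so removing it changes no behavior.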