Commit 5fe4ebd

Misc fixes to prepare for adding Float8Tensor (#2603)
Summary:
* Moved some float8-related util functions to torchao.float8.inference
* Renamed _choose_qparams_affine_float8 to _choose_scale_float8
* Added hp_value_lb and hp_value_ub to _choose_scale_float8
* Added `__all__` to torchao/core/config.py

Test Plan: pytest test/dtypes/test_affine_quantized_float.py -k test_choose_scale_float8_bounds
1 parent bdf4598 commit 5fe4ebd
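
For quick orientation, a minimal usage sketch of the renamed helper and the new bounds (not part of the commit; the bound values are arbitrary, and the call shapes follow the hunks in torchao/quantization/quant_primitives.py below):

    import torch
    from torchao.quantization.quant_primitives import (
        _choose_scale_float8,
        _quantize_affine_float8,
    )

    x = torch.randn(8, 64, dtype=torch.float32)

    # An empty block_size selects the tensorwise branch; the observed abs-max is
    # clamped into [hp_value_lb, hp_value_ub] before the scale is derived
    # (both bounds are optional keyword arguments).
    scale = _choose_scale_float8(
        x,
        block_size=(),
        float8_dtype=torch.float8_e4m3fn,
        hp_value_lb=1e-12,
        hp_value_ub=1200.0,
    )
    x_fp8 = _quantize_affine_float8(x, scale, torch.float8_e4m3fn)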

File tree: 7 files changed (+155, -82 lines)

test/dtypes/test_affine_quantized_float.py

Lines changed: 46 additions & 3 deletions
@@ -42,7 +42,7 @@
 )
 from torchao.quantization.quant_primitives import (
     MappingType,
-    _choose_qparams_affine_float8,
+    _choose_scale_float8,
     _dequantize_affine_float8,
     _quantize_affine_float8,
     choose_qparams_affine,
@@ -350,6 +350,49 @@ def test_mm_float8dq_per_row(
         error = compute_error(ref_output, quant_output)
         assert error > 20, f"Quantization error is too high got a SQNR of {error}"

+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+    @common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16])
+    def test_choose_scale_float8_bounds(self, float8_dtype, output_dtype):
+        block_size = ()
+        device = "cuda"
+        input_tensor = torch.randn(8, 64, device=device, dtype=torch.float32)
+
+        # testing upper bounds
+        input_tensor[0][0] = 2000
+        scale_ref = _choose_scale_float8(
+            input_tensor, float8_dtype=float8_dtype, block_size=block_size
+        )
+
+        hp_value_ub = 1200
+        scale_with_ub = _choose_scale_float8(
+            input_tensor,
+            float8_dtype=float8_dtype,
+            block_size=block_size,
+            hp_value_ub=hp_value_ub,
+        )
+        # since scale = abs_max / quant_max, larger abs_max means scale is larger
+        self.assertTrue(scale_ref > scale_with_ub)
+
+        # testing lower bound settings
+        # making sure that abs is on the scale of 1e-20, so hp_value_lb can take effect
+        input_tensor = torch.randn(8, 64, device=device, dtype=torch.float32) * 1e-20
+        scale_ref = _choose_scale_float8(
+            input_tensor, float8_dtype=float8_dtype, block_size=block_size
+        )
+        hp_value_lb = 1e-12
+        scale_with_lb = _choose_scale_float8(
+            input_tensor,
+            float8_dtype=float8_dtype,
+            block_size=block_size,
+            hp_value_lb=hp_value_lb,
+        )
+        # since scale = abs_max / quant_max, larger abs_max means scale is larger
+        self.assertTrue(scale_ref < scale_with_lb)
+
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
@@ -364,7 +407,7 @@ def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
         input_tensor = torch.randn(8, 64, device=device, dtype=torch.float32)

         # Choose quantization parameters
-        scale = _choose_qparams_affine_float8(
+        scale = _choose_scale_float8(
             input_tensor, float8_dtype=float8_dtype, block_size=block_size
         )

@@ -395,7 +438,7 @@ def test_dequantize_affine_float8_scale_broadcasting(self):
         block_size = (2, 16) # 2x2 blocks in first dim, 2x16 blocks in second dim

         # Choose quantization parameters
-        scale = _choose_qparams_affine_float8(
+        scale = _choose_scale_float8(
             input_tensor, float8_dtype=torch.float8_e4m3fn, block_size=block_size
         )

test/integration/test_integration.py

Lines changed: 1 addition & 1 deletion
@@ -2102,7 +2102,7 @@ def forward(self, x):
         ep = torch.export.export(model, (inp,))
         print(ep)
         FileCheck().check_count(
-            "torch.ops.torchao.choose_qparams_affine_float8.default", 1, exactly=True
+            "torch.ops.torchao.choose_scale_float8.default", 1, exactly=True
         ).run(str(ep.graph))

torchao/core/config.py

Lines changed: 8 additions & 0 deletions
@@ -12,6 +12,14 @@

 import torch

+__all__ = [
+    "AOBaseConfig",
+    "VersionMismatchError",
+    "config_from_dict",
+    "config_to_dict",
+    "ALLOWED_AO_MODULES",
+]
+

 class AOBaseConfig(abc.ABC):
     """

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 2 additions & 2 deletions
@@ -19,10 +19,10 @@
     MappingType,
     ZeroPointDomain,
     _choose_qparams_affine_dont_preserve_zero,
-    _choose_qparams_affine_float8,
     _choose_qparams_affine_floatx,
     _choose_qparams_affine_tinygemm,
     _choose_qparams_and_quantize_affine_hqq,
+    _choose_scale_float8,
     _dequantize_affine_float8,
     _dequantize_affine_floatx,
     _dequantize_affine_no_zero_point,
@@ -462,7 +462,7 @@ def from_hp_to_floatx(
         if target_dtype in FP8_TYPES:
             original_shape = input_float.shape
             input_float = _layout.pre_process(input_float)
-            scale = _choose_qparams_affine_float8(
+            scale = _choose_scale_float8(
                 input_float, float8_dtype=target_dtype, block_size=block_size
             )
             data = _quantize_affine_float8(input_float, scale, target_dtype)

torchao/dtypes/floatx/float8_layout.py

Lines changed: 2 additions & 68 deletions
@@ -20,8 +20,10 @@
 from torchao.float8.inference import (
     Float8MMConfig,
     _is_rowwise_scaled,
+    _slice_scale_for_dimension,
     addmm_float8_unwrapped_inference,
     preprocess_data,
+    preprocess_scale,
 )
 from torchao.utils import _is_float8_type, fill_defaults

@@ -299,56 +301,6 @@ def _(func, types, args, kwargs):
     )


-def _slice_scale_for_dimension(
-    scale: torch.Tensor,
-    data_shape: List[int],
-    dim: int,
-    start: int,
-    end: int,
-    step: int,
-) -> torch.Tensor:
-    """
-    Slice the scale tensor appropriately based on the data tensor slicing.
-
-    This function calculates how the scale should be sliced when the data tensor
-    is sliced along a given dimension, taking into account the block structure.
-    """
-    # Unsupported case for now, this would be 1 scale per data element
-    if scale.shape == data_shape:
-        return aten.slice.Tensor(scale, dim, start, end, step)
-
-    # Reconstruct block sizes based on data shape and scale shape
-    block_sizes = tuple(data_shape[i] // scale.shape[i] for i in range(len(data_shape)))
-
-    if dim >= len(block_sizes):
-        # Slicing beyond the dimensions we care about
-        return scale
-
-    block_size_for_dim = block_sizes[dim]
-
-    if block_size_for_dim == 1:
-        # Scale is per-element along this dimension
-        # Slice away as normal
-        return aten.slice.Tensor(scale, dim, start, end, step)
-    else:
-        # There is blocking in this dimension
-        # Calculate which scale elements correspond to the sliced data
-        scale_start = start // block_size_for_dim if start is not None else None
-        scale_end = (
-            (end + block_size_for_dim - 1) // block_size_for_dim
-            if end is not None
-            else None
-        )
-
-        # Error on Step > 1
-        if step > 1:
-            raise NotImplementedError(
-                "Slicing with step > 1 is not implemented for scale tensors."
-            )
-
-        return aten.slice.Tensor(scale, dim, scale_start, scale_end, 1)
-
-
 ##########################
 # Float8 Dispatch Kernels
 ##########################
@@ -370,24 +322,6 @@ def check_aqt(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool:
     return check_aqt(input_tensor) and check_aqt(weight_tensor)


-def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int, ...]):
-    """Ensures input tensor is correctly formatted for _scaled_mm"""
-
-    # For PerTensor quantization, scale should be a scalar or have shape [1]
-    if input_scale.numel() == 1:
-        # Already a scalar, ensure it has the right shape for _scaled_mm
-        return input_scale.reshape(1, 1)
-
-    # For per-row/block quantization, we need to handle the reshaping
-    input_scale = input_scale.unsqueeze(-1)
-
-    # Match: #input_data.reshape(-1, input_data.shape[-1])
-    if input_scale.dim() > 2:
-        input_scale = input_scale.reshape(-1, input_scale.shape[-1])
-
-    return input_scale
-
-
 def _linear_fp8_act_fp8_weight_impl(
     input_tensor: "AffineQuantizedTensor",
     weight_tensor: "AffineQuantizedTensor",
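
Note that the two helpers removed above are relocated, not deleted: they now live in torchao.float8.inference (next file), which is why this file's import block gains them. A sketch of the updated import path:

    from torchao.float8.inference import (
        _slice_scale_for_dimension,
        preprocess_scale,
    )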

torchao/float8/inference.py

Lines changed: 86 additions & 5 deletions
@@ -7,7 +7,7 @@
 Defines an nn module designed to be used during inference
 """

-from typing import NamedTuple, Optional, Tuple, Union
+from typing import List, NamedTuple, Optional, Tuple, Union

 import torch

@@ -67,6 +67,24 @@ def preprocess_data(
     return a_data, b_data


+def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int, ...]):
+    """Ensures input tensor is correctly formatted for _scaled_mm"""
+
+    # For PerTensor quantization, scale should be a scalar or have shape [1]
+    if input_scale.numel() == 1:
+        # Already a scalar, ensure it has the right shape for _scaled_mm
+        return input_scale.reshape(1, 1)
+
+    # For per-row/block quantization, we need to handle the reshaping
+    input_scale = input_scale.unsqueeze(-1)
+
+    # Match: #input_data.reshape(-1, input_data.shape[-1])
+    if input_scale.dim() > 2:
+        input_scale = input_scale.reshape(-1, input_scale.shape[-1])
+
+    return input_scale
+
+
 def addmm_float8_unwrapped_inference(
     a_data: Tensor,
     a_scale: Tensor,
@@ -107,12 +125,75 @@ def addmm_float8_unwrapped_inference(
     )


-def _is_rowwise_scaled(x) -> bool:
-    """Checks if an AQT tensor is rowwise scaled
+def _slice_scale_for_dimension(
+    scale: torch.Tensor,
+    data_shape: List[int],
+    dim: int,
+    start: int,
+    end: int,
+    step: int,
+) -> torch.Tensor:
+    """
+    Slice the scale tensor appropriately based on the data tensor slicing.
+    This function calculates how the scale should be sliced when the data tensor
+    is sliced along a given dimension, taking into account the block structure.
+    """
+    aten = torch.ops.aten
+
+    # Unsupported case for now, this would be 1 scale per data element
+    if scale.shape == data_shape:
+        return aten.slice.Tensor(scale, dim, start, end, step)
+
+    # Reconstruct block sizes based on data shape and scale shape
+    block_sizes = tuple(data_shape[i] // scale.shape[i] for i in range(len(data_shape)))
+
+    if dim >= len(block_sizes):
+        # Slicing beyond the dimensions we care about
+        return scale
+
+    block_size_for_dim = block_sizes[dim]
+
+    if block_size_for_dim == 1:
+        # Scale is per-element along this dimension
+        # Slice away as normal
+        return aten.slice.Tensor(scale, dim, start, end, step)
+    else:
+        # There is blocking in this dimension
+        # Calculate which scale elements correspond to the sliced data
+        scale_start = start // block_size_for_dim if start is not None else None
+        scale_end = (
+            (end + block_size_for_dim - 1) // block_size_for_dim
+            if end is not None
+            else None
+        )
+
+        # Error on Step > 1
+        if step > 1:
+            raise NotImplementedError(
+                "Slicing with step > 1 is not implemented for scale tensors."
+            )
+
+        return aten.slice.Tensor(scale, dim, scale_start, scale_end, 1)
+
+
+def _is_rowwise_scaled(x: torch.Tensor) -> bool:
+    """Checks if a quantized tensor is rowwise scaled
+    Args:
+        x: quantized tensor (should have `block_size` attribute)
+    """
+    assert hasattr(x, "block_size"), "Expecting input to have `block_size` attribute"
+    return tuple(x.block_size) == (1,) * (x.dim() - 1) + (x.shape[-1],)
+
+
+def _is_tensorwise_scaled(x: torch.Tensor) -> bool:
+    """Checks if a quantized tensor is tensorwise scaled
     Args:
-        x: AffineQuantizedTensor tensor
+        x: quantized tensor (should have `block_size` attribute)
     """
-    return x.block_size == (1,) * (x.dim() - 1) + (x.shape[-1],)
+    assert hasattr(x, "block_size"), "Expecting input to have `block_size` attribute"
+    return all(
+        x.block_size[i] == -1 or x.block_size[i] == x.shape[i] for i in range(x.ndim)
+    )


 def _normalize_granularity(
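
A small self-contained sketch of how the relocated helpers behave (not from the commit; it assumes the definitions added above are importable from torchao.float8.inference):

    import torch

    from torchao.float8.inference import (
        _slice_scale_for_dimension,
        preprocess_scale,
    )

    # Per-tensor scale: numel() == 1, so it is reshaped to (1, 1) for _scaled_mm.
    print(preprocess_scale(torch.tensor(0.5), (16, 32)).shape)   # torch.Size([1, 1])

    # Per-row scale for a (4, 8, 32) activation: unsqueeze(-1), then flatten leading dims.
    print(preprocess_scale(torch.rand(4, 8), (4, 8, 32)).shape)  # torch.Size([32, 1])

    # Block-wise scale: data (8, 64) with scale (4, 64) implies a block size of 2 along
    # dim 0, so slicing data rows 2:6 maps to scale rows 1:3.
    scale = torch.rand(4, 64)
    sliced = _slice_scale_for_dimension(scale, [8, 64], dim=0, start=2, end=6, step=1)
    print(sliced.shape)  # torch.Size([2, 64])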

torchao/quantization/quant_primitives.py

Lines changed: 10 additions & 3 deletions
@@ -36,7 +36,7 @@
     "_choose_qparams_affine_floatx",
     "_choose_qparams_and_quantize_affine_hqq",
     "_choose_qparams_and_quantize_affine_qqq",
-    "_choose_qparams_affine_float8",
+    "_choose_scale_float8",
     "_choose_qparams_gguf",
     "_quantize_affine_no_zero_point",
     "_quantize_affine_tinygemm",
@@ -2180,11 +2180,13 @@ def _dequantize_affine_floatx(


 @register_custom_op
-def _choose_qparams_affine_float8(
+def _choose_scale_float8(
     tensor: torch.Tensor,
     block_size: List[int],
     float8_dtype: torch.dtype = torch.float8_e4m3fn,
     scale_dtype: torch.dtype = torch.float32,
+    hp_value_lb: Optional[float] = None,
+    hp_value_ub: Optional[float] = None,
 ) -> torch.Tensor:
     """
     Calculates float8 scaling factor for the given high precision tensor, using tensorwise granularity.
@@ -2194,19 +2196,24 @@ def _choose_qparams_affine_float8(
         float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
         scale_dtype (torch.dtype): Data type of the scaling factor (e.g., torch.float32).
         block_size (Optional[Tuple[int, ...]]): Block size for block-wise quantization. If None, tensorwise quantization is used.
+        hp_value_lb (Optional[float]): the lower bound for the high precision value used to calculate the scale
+        hp_value_ub (Optional[float]): the upper bound for the high precision value used to calculate the scale
     """
     quant_max = torch.finfo(float8_dtype).max
     # only tensorwise scaling is supported for now:
     if len(block_size) == 0:
         max_abs = tensor.abs().max()
+        if hp_value_lb is not None or hp_value_ub is not None:
+            max_abs = torch.clamp(max_abs, min=hp_value_lb, max=hp_value_ub)
         scale = max_abs / quant_max
     else:
         shape_for_reduction, reduction_dims = _get_reduction_params(
             block_size, tensor.shape
         )
         tensor_reshaped = tensor.view(shape_for_reduction)
         max_abs = tensor_reshaped.abs().amax(dim=reduction_dims, keepdim=True)
-
+        if hp_value_lb is not None or hp_value_ub is not None:
+            max_abs = torch.clamp(max_abs, min=hp_value_lb, max=hp_value_ub)
         scale = max_abs / quant_max
         # Reshape scale back to match the expected output shape
         # The scale tensor should have the same shape as the input divided by block_size
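
For intuition about the new bounds (illustrative numbers, not from the commit): torch.float8_e4m3fn has quant_max = 448, so an outlier-driven abs-max of 2000 gives scale ~4.46, while clamping with hp_value_ub = 1200 gives scale ~2.68, which is the inequality asserted by test_choose_scale_float8_bounds above.

    import torch

    quant_max = torch.finfo(torch.float8_e4m3fn).max               # 448.0
    abs_max = torch.tensor(2000.0)
    scale_unclamped = abs_max / quant_max                          # ~4.46
    scale_clamped = torch.clamp(abs_max, max=1200.0) / quant_max   # ~2.68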
