
Commit 5d3a178

Add version limits for torchao, ensure compat with 0.12 + AIU
Signed-off-by: Antoni Viros i Martin <[email protected]>
Parent: 9ee258b

File tree

3 files changed: +42 -2 lines changed

fms_mo/aiu_addons/fp8/fp8_linear.py

Lines changed: 40 additions & 0 deletions

@@ -38,6 +38,46 @@
     shard_base_linear,
 )
 from fms.modules.tp import ShardType, TPModule
+
+# Register decomps for torchao >= 0.12 + AIU.
+# This import only succeeds if torchao is 0.12 or higher
+try:
+    # Third Party
+    from torchao.quantization.quant_primitives import _expand_scale_to_tensor_shape
+
+    # This function is copied from _quantize_affine_float8
+    # in torchao.quantization.quant_primitives, but removing
+    # the wrapping that turns it into a custom pytorch op
+    def _quantize_affine_float8_custom(
+        tensor: torch.Tensor,
+        scale: torch.Tensor,
+        float8_dtype: torch.dtype = torch.float8_e4m3fn,
+    ) -> torch.Tensor:
+        """
+        Quantizes the high precision floating point tensor
+        to a float8 tensor, using the given scaling factor.
+        """
+        tensor_fp32 = tensor.to(torch.float32)
+
+        # Expand scale to match tensor dimensions for block-wise quantization
+        scale_expanded = _expand_scale_to_tensor_shape(scale, tensor.shape)
+
+        tensor_scaled = tensor_fp32 / scale_expanded
+        max_value = torch.finfo(float8_dtype).max
+        tensor_clamped = tensor_scaled.clamp(min=-max_value, max=max_value)
+        fp8_tensor = tensor_clamped.to(float8_dtype)
+        return fp8_tensor
+
+    quant_lib = torch.library.Library("torchao", "FRAGMENT")
+    quant_lib.impl(
+        "quantize_affine_float8",
+        _quantize_affine_float8_custom,
+        "CompositeImplicitAutograd",
+    )
+except ImportError:
+    pass
+
+# Third Party
 from torchao.dtypes.affine_quantized_tensor import (
     AffineQuantizedTensor,
     to_affine_quantized_floatx,
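In stock torchao, quantize_affine_float8 is wrapped as a custom PyTorch op; re-registering the unwrapped Python body under the CompositeImplicitAutograd key lets torch.compile/export decompose it into plain aten ops, which is presumably what the AIU compile path needs to see. The arithmetic itself is just scale, clamp, and cast. Below is a minimal, self-contained sketch of that round trip; it is not fms-mo or torchao code, and it assumes a recent PyTorch with float8 dtypes plus a single per-tensor scale, so the torchao-internal _expand_scale_to_tensor_shape reduces to ordinary broadcasting:

```python
# Minimal sketch of the quantize/dequantize math above (illustrative
# names only; assumes PyTorch >= 2.1 for float8 dtypes).
import torch

def quantize_fp8_sketch(
    tensor: torch.Tensor,
    scale: torch.Tensor,
    float8_dtype: torch.dtype = torch.float8_e4m3fn,
) -> torch.Tensor:
    tensor_fp32 = tensor.to(torch.float32)
    # A per-tensor scale broadcasts directly; block-wise scales would
    # first be expanded to match tensor.shape.
    tensor_scaled = tensor_fp32 / scale
    # Saturate to the representable range (448.0 for float8_e4m3fn)
    # before casting, as _quantize_affine_float8_custom does.
    max_value = torch.finfo(float8_dtype).max
    return tensor_scaled.clamp(min=-max_value, max=max_value).to(float8_dtype)

x = torch.randn(4, 8)
scale = x.abs().max() / torch.finfo(torch.float8_e4m3fn).max
x_fp8 = quantize_fp8_sketch(x, scale)
x_back = x_fp8.to(torch.float32) * scale  # dequantize
print((x - x_back).abs().max())  # small round-trip error
```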

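The registration itself is the standard torch.library pattern: open an op namespace and attach an implementation for a dispatch key. A toy sketch of the same pattern in a scratch namespace follows; the "demo_ns" namespace and "double" op are hypothetical, whereas the commit fragments the real "torchao" namespace and overrides its already-defined quantize_affine_float8:

```python
import torch

# Create a toy op namespace ("DEF"); the commit instead uses "FRAGMENT"
# to extend the existing "torchao" namespace.
lib = torch.library.Library("demo_ns", "DEF")
lib.define("double(Tensor x) -> Tensor")

def _double_impl(x: torch.Tensor) -> torch.Tensor:
    # A CompositeImplicitAutograd impl is plain Python over aten ops,
    # so torch.compile/export can trace through and decompose it.
    return x * 2

lib.impl("double", _double_impl, "CompositeImplicitAutograd")

print(torch.ops.demo_ns.double(torch.ones(3)))  # tensor([2., 2., 2.])
```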
pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -35,7 +35,7 @@ dependencies = [
 
 [project.optional-dependencies]
 examples = ["ninja>=1.11.1.1,<2.0", "evaluate", "huggingface_hub"]
-fp8 = ["llmcompressor", "torchao"]
+fp8 = ["llmcompressor", "torchao>=0.11,<=0.12"]
 gptq = ["Cython", "gptqmodel>=1.7.3"]
 mx = ["microxcaling>=1.1"]
 opt = ["fms-model-optimizer[fp8, gptq, mx]"]
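The new bound keeps the resolver on torchao releases this addon has been validated against, while the try/except ImportError in fp8_linear.py handles the runtime side: _expand_scale_to_tensor_shape only exists from torchao 0.12, so older installs silently skip the override. For comparison, a hedged sketch of an explicit version gate (assuming the common packaging helper is installed):

```python
# Sketch of an explicit version gate equivalent to the ImportError guard
# in fp8_linear.py. The commit's guard is arguably better: it keys on the
# exact symbol needed rather than on a version string.
from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version  # assumed available

try:
    torchao_version = Version(version("torchao"))
except PackageNotFoundError:
    torchao_version = None

if torchao_version is not None and torchao_version >= Version("0.12"):
    ...  # register the quantize_affine_float8 override here
```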

tox.ini

Lines changed: 1 addition & 1 deletion

@@ -34,7 +34,7 @@ deps =
     pylint>=2.16.2,<4.0
     pylint-pydantic
     ibm-fms
-    torchao
+    torchao>=0.11,<=0.12
 commands =
     {basepython} -m pylint --load-plugins pylint_pydantic fms_mo/ tests/
4040
