
Commit f607a7f

For delivery, pin torchao to 0.11; also improve error checking and change how matrices are scaled in the FP8 matmul
Signed-off-by: Antoni Viros i Martin <[email protected]>
Parent: 39db419

3 files changed: 12 additions, 3 deletions

fms_mo/aiu_addons/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -1,3 +1,7 @@
+# Local
+from fms_mo.prep import available_packages
+
+
 def _infer_quantization_config(quant_config: dict) -> dict | None:
     """Construct linear_config dictionary carrying FP8 configuration for FMS.
 
@@ -20,6 +24,10 @@ def _infer_quantization_config(quant_config: dict) -> dict | None:
         quant_config["config_groups"]["group_0"]["weights"]["type"] == "float"
         and quant_config["config_groups"]["group_0"]["weights"]["num_bits"] == 8
     ):
+        if not available_packages["torchao"]:
+            raise ImportError(
+                "You need torchao installed to load FP8 checkpoints in FMS"
+            )
         # First, import required FP8 linear classes from fms-mo
         # Local
         import fms_mo.aiu_addons.fp8.fp8_adapter  # pylint: disable=unused-import
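The new guard assumes fms_mo.prep.available_packages maps package names to booleans. A minimal sketch of how such a probe is commonly built, under that assumption (not necessarily fms_mo's actual implementation):

# Hypothetical sketch of a package-availability probe like
# fms_mo.prep.available_packages; the real fms_mo code may differ.
import importlib.util

available_packages = {
    name: importlib.util.find_spec(name) is not None
    for name in ("torchao", "llmcompressor")
}

# The hunk above then fails fast when FP8 loading is requested without torchao:
if not available_packages["torchao"]:
    raise ImportError("You need torchao installed to load FP8 checkpoints in FMS")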

fms_mo/aiu_addons/fp8/fp8_attn.py

Lines changed: 3 additions & 2 deletions
@@ -220,8 +220,9 @@ def _math_fp8_compute_op(
         .to(dtype=orig_dtype)
         .transpose(-2, -1)
     )
-    attn_weight = query @ key_t
-    attn_weight *= scale_factor
+    attn_weight = (query * math.sqrt(scale_factor)) @ (
+        key_t * math.sqrt(scale_factor)
+    )
     attn_weight += attn_bias
     attn_weight = torch.softmax(attn_weight, dim=-1)
     attn_weight = torch.dropout(attn_weight, p_dropout, train=True)
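The rescaling is algebraically neutral: since (sqrt(s)·Q)(sqrt(s)·Kᵀ) = s·(QKᵀ), pre-scaling both operands gives the same product as scaling afterwards, while keeping intermediate magnitudes smaller going into the FP8 matmul, where large values risk overflowing the narrow dynamic range. A toy sanity check in plain torch (assumed shapes, not the fms-mo code):

# Toy check (not fms-mo code): pre-scaling each operand by sqrt(scale_factor)
# matches post-scaling the matmul result.
import math

import torch

query = torch.randn(4, 8)
key_t = torch.randn(8, 4)
scale_factor = 1.0 / math.sqrt(8)  # the usual 1/sqrt(head_dim) attention scale

post_scaled = (query @ key_t) * scale_factor
pre_scaled = (query * math.sqrt(scale_factor)) @ (key_t * math.sqrt(scale_factor))
assert torch.allclose(post_scaled, pre_scaled, atol=1e-6)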

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ dependencies = [
 
 [project.optional-dependencies]
 examples = ["ninja>=1.11.1.1,<2.0", "evaluate", "huggingface_hub"]
-fp8 = ["llmcompressor", "torchao>=0.11,<=0.12"]
+fp8 = ["llmcompressor", "torchao==0.11"]
 gptq = ["Cython", "gptqmodel>=1.7.3"]
 mx = ["microxcaling>=1.1"]
 opt = ["fms-model-optimizer[fp8, gptq, mx]"]
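With the exact pin in place, installing the extra (for example, pip install "fms-model-optimizer[fp8]") should resolve torchao to 0.11 rather than floating up to 0.12; the opt extra above pulls in the same pinned set via fp8.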
