
Commit e9874ef

Merge branch 'main' into fast_loading
2 parents b137e0c + 207eb06

File tree: 4 files changed (+20, -9 lines)

fms_mo/aiu_addons/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -1,3 +1,7 @@
+# Local
+from fms_mo.prep import available_packages
+
+
 def _infer_quantization_config(quant_config: dict) -> dict | None:
     """Construct linear_config dictionary carrying FP8 configuration for FMS.
 
@@ -20,9 +24,14 @@ def _infer_quantization_config(quant_config: dict) -> dict | None:
         quant_config["config_groups"]["group_0"]["weights"]["type"] == "float"
         and quant_config["config_groups"]["group_0"]["weights"]["num_bits"] == 8
     ):
+        if not available_packages["torchao"]:
+            raise ImportError(
+                "You need torchao installed to load FP8 checkpoints in FMS"
+            )
         # First, import required FP8 linear classes from fms-mo
         # Local
         import fms_mo.aiu_addons.fp8.fp8_adapter  # pylint: disable=unused-import
+        import fms_mo.aiu_addons.fp8.fp8_attn  # pylint: disable=unused-import
         import fms_mo.aiu_addons.fp8.fp8_linear  # pylint: disable=unused-import
 
         # This is used by get_linear to decide whether a linear layer
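The new guard fails fast with an actionable message when torchao is missing, rather than surfacing a confusing import error deep inside the fp8 modules. For context, `available_packages` (imported from `fms_mo.prep`) maps package names to installed/not-installed flags; a minimal stand-alone sketch of that pattern (the dict construction below is an illustrative assumption, not the fms-mo implementation):

    # Sketch only: approximates the availability check the hunk relies on.
    # The real `available_packages` lives in fms_mo.prep; this stand-in is an assumption.
    import importlib.util

    available_packages = {
        name: importlib.util.find_spec(name) is not None
        for name in ("torchao", "llmcompressor")
    }

    if not available_packages["torchao"]:
        raise ImportError("You need torchao installed to load FP8 checkpoints in FMS")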

fms_mo/aiu_addons/fp8/fp8_attn.py

Lines changed: 5 additions & 3 deletions
@@ -29,6 +29,7 @@
 # Third Party
 from fms.modules.attention import (
     AttentionKwargs,
+    _sdpa_compute_op,
     _sdpa_update_attn_kwargs,
     register_attention_op,
 )
@@ -219,8 +220,9 @@ def _math_fp8_compute_op(
         .to(dtype=orig_dtype)
         .transpose(-2, -1)
     )
-    attn_weight = query @ key_t
-    attn_weight *= scale_factor
+    attn_weight = (query * math.sqrt(scale_factor)) @ (
+        key_t * math.sqrt(scale_factor)
+    )
     attn_weight += attn_bias
     attn_weight = torch.softmax(attn_weight, dim=-1)
     attn_weight = torch.dropout(attn_weight, p_dropout, train=True)
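The rewritten line is algebraically the same product, since (query * sqrt(s)) @ (key_t * sqrt(s)) equals s * (query @ key_t); applying sqrt(scale_factor) to each operand before the matmul keeps the intermediate values in a smaller range, which is friendlier to reduced-precision accumulation. A stand-alone sketch of the equivalence (shapes and the scale value are illustrative assumptions, not taken from the fms code):

    # Sketch: both formulations yield the same attention logits up to float rounding.
    import math
    import torch

    query = torch.randn(2, 4, 8)        # (batch, seq, head_dim) -- illustrative shapes
    key_t = torch.randn(2, 8, 4)        # key already transposed
    scale_factor = 1.0 / math.sqrt(8)   # typical 1/sqrt(head_dim) scaling

    old = (query @ key_t) * scale_factor
    new = (query * math.sqrt(scale_factor)) @ (key_t * math.sqrt(scale_factor))
    torch.testing.assert_close(old, new)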
@@ -340,7 +342,7 @@ def __spyre_scaled_paged_validate_attn_kwargs_op(
 register_attention_op(
     "spyre_paged_attn_fp8",
     _spyre_scaled_paged_store_op,
-    compute_op=_math_fp8_compute_op,
+    compute_op=_sdpa_compute_op,
     is_prefill_op=lambda **attn_kwargs: attn_kwargs.get("block_table", None)
     is None,
     compute_decode_op=_spyre_scaled_paged_compute_op,

fms_mo/aiu_addons/fp8/fp8_linear.py

Lines changed: 3 additions & 3 deletions
@@ -321,12 +321,12 @@ def shard_fp8_linear(
     sharding  | param          | shard | dim |
     ----------+----------------+-------+-----|
     colwise   | weight         | Y     | 0   |
-              | weight_scale   | N     | -   |
+              | weight_scale   | Y/N   | 0/- |
               | input_scale    | N     | -   |
               | bias           | Y     | 0   |
     ----------+----------------+-------+-----|
     rowwise   | weight         | Y     | 1   |
-              | weight_scale   | Y/N   | 0/- |
+              | weight_scale   | N     | -   |
               | input_scale    | Y/N   | 0/- |
               | bias           | 0     | -   |
     """
@@ -339,7 +339,7 @@ def shard_fp8_linear(
     ]
     # Scales are per-row or per-tensor
     # Only sharding needed when row parallel and per-row
-    shard_scales = weight_strategy != "tensor" and module_info.sharding_dim == 1
+    shard_scales = weight_strategy != "tensor" and module_info.sharding_dim == 0
     params: dict[str, LinearParameterShardingInfo] = {
         "weight": LinearParameterShardingInfo(
             module_info.sharding_dim, ShardType.SHARD
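These two hunks bring the docstring table and the `shard_scales` condition into agreement: when the weight is split along dim 0 (output features), a per-row `weight_scale` carries one entry per output feature and must be split along the same dim, while a per-tensor scale is simply replicated. A minimal sketch of that relationship with made-up shapes (plain torch, not the fms sharding API):

    # Sketch: a per-row weight scale must follow a dim-0 weight shard (illustrative only).
    import torch

    out_features, in_features, tp_world = 8, 4, 2
    weight = torch.randn(out_features, in_features)
    weight_scale = torch.rand(out_features, 1)   # one scale per output row ("per-row")

    # Each rank keeps a contiguous slice of output rows, plus the matching scale slice.
    weight_shards = weight.chunk(tp_world, dim=0)
    scale_shards = weight_scale.chunk(tp_world, dim=0)
    assert weight_shards[0].shape[0] == scale_shards[0].shape[0]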

pyproject.toml

Lines changed: 3 additions & 3 deletions
@@ -24,7 +24,7 @@ dynamic = ["version"]
 dependencies = [
     "numpy>=1.26.4,<2.3.0",
     "accelerate>=0.20.3,!=0.34,<1.10",
-    "transformers>=4.45,<4.54",
+    "transformers>=4.45,<4.56",
     "torch>=2.2.0,<2.8",
     "tqdm>=4.66.2,<5.0",
     "datasets>=3.0.0,<5.0",
@@ -35,14 +35,14 @@ dependencies = [
 
 [project.optional-dependencies]
 examples = ["ninja>=1.11.1.1,<2.0", "evaluate", "huggingface_hub"]
-fp8 = ["llmcompressor", "torchao>=0.11,<=0.12"]
+fp8 = ["llmcompressor", "torchao==0.11"]
 gptq = ["Cython", "gptqmodel>=1.7.3"]
 mx = ["microxcaling>=1.1"]
 opt = ["fms-model-optimizer[fp8, gptq, mx]"]
 aiu = ["ibm-fms>=0.0.8"]
 torchvision = ["torchvision>=0.17"]
 flash-attn = ["flash-attn>=2.5.3,<3.0"]
-triton = ["triton>=3.0,<3.4"]
+triton = ["triton>=3.0,<3.5"]
 visualize = ["matplotlib", "graphviz", "pygraphviz", "tensorboard", "notebook"]
 dev = ["pre-commit>=3.0.4,<5.0"]
 test = ["pytest", "pillow"]
