
Commit 207eb06

Merge pull request #176 from foundation-model-stack/fp8-tp-fixes
fix: FP8 TP fixes
2 parents: e83f55c + 3a006dc

File tree

2 files changed (+4, -3 lines)


fms_mo/aiu_addons/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ def _infer_quantization_config(quant_config: dict) -> dict | None:
     # First, import required FP8 linear classes from fms-mo
     # Local
     import fms_mo.aiu_addons.fp8.fp8_adapter  # pylint: disable=unused-import
+    import fms_mo.aiu_addons.fp8.fp8_attn  # pylint: disable=unused-import
     import fms_mo.aiu_addons.fp8.fp8_linear  # pylint: disable=unused-import

     # This is used by get_linear to decide whether a linear layer
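The added fp8_attn line follows the same pattern as its neighbors: the pylint unused-import pragmas indicate these modules are imported purely for their import-time side effects, so FP8 attention support only becomes available once fp8_attn appears in this block. Below is a minimal, self-contained sketch of that registration-on-import pattern; every name in it is hypothetical and not part of the actual fms-mo internals.

```python
# Hypothetical sketch of registration-on-import: a module records its op in a
# shared registry when it is imported, so the "unused" import is the activation step.
OP_REGISTRY: dict[str, object] = {}

def register_op(name: str):
    """Decorator that stores a callable in OP_REGISTRY as a module-import side effect."""
    def deco(fn):
        OP_REGISTRY[name] = fn
        return fn
    return deco

# --- imagine this lives in a module like fp8_attn.py ---
@register_op("fp8_attention")
def fp8_attention_stub(*args, **kwargs):
    raise NotImplementedError("placeholder for a real FP8 attention kernel")

# --- caller side: importing the module is enough; nothing from it is referenced,
# which is exactly why pylint needs the disable=unused-import pragma.
assert "fp8_attention" in OP_REGISTRY
```

If a module like this is left out of the import block, its op simply never reaches the registry, which is the failure mode the one-line addition fixes.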

fms_mo/aiu_addons/fp8/fp8_linear.py

Lines changed: 3 additions & 3 deletions
@@ -321,12 +321,12 @@ def shard_fp8_linear(
     sharding  | param          | shard | dim |
     ----------+----------------+-------+-----|
     colwise   | weight         | Y     | 0   |
-              | weight_scale   | N     | -   |
+              | weight_scale   | Y/N   | 0/- |
               | input_scale    | N     | -   |
               | bias           | Y     | 0   |
     ----------+----------------+-------+-----|
     rowwise   | weight         | Y     | 1   |
-              | weight_scale   | Y/N   | 0/- |
+              | weight_scale   | N     | -   |
               | input_scale    | Y/N   | 0/- |
               | bias           | 0     | -   |
     """
@@ -339,7 +339,7 @@ def shard_fp8_linear(
     ]
     # Scales are per-row or per-tensor
     # Only sharding needed when row parallel and per-row
-    shard_scales = weight_strategy != "tensor" and module_info.sharding_dim == 1
+    shard_scales = weight_strategy != "tensor" and module_info.sharding_dim == 0
     params: dict[str, LinearParameterShardingInfo] = {
         "weight": LinearParameterShardingInfo(
             module_info.sharding_dim, ShardType.SHARD
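The one-character change mirrors the table above: per-row scales are sharded only when the weight itself is sharded along dim 0 (the colwise case), never in the rowwise case. A standalone sketch of the corrected predicate, using a hypothetical stand-in for module_info and a hypothetical "row" strategy value (the diff itself only names "tensor").

```python
from dataclasses import dataclass

@dataclass
class ModuleInfo:
    """Hypothetical stand-in for module_info; only the field the predicate reads."""
    sharding_dim: int

def needs_scale_shard(weight_strategy: str, module_info: ModuleInfo) -> bool:
    # Mirrors the fixed line: shard scales only for non-per-tensor strategies
    # when the weight is sharded along dim 0 (colwise in the table above).
    return weight_strategy != "tensor" and module_info.sharding_dim == 0

assert needs_scale_shard("row", ModuleInfo(sharding_dim=0)) is True      # colwise, per-row scale
assert needs_scale_shard("row", ModuleInfo(sharding_dim=1)) is False     # rowwise, per-row scale
assert needs_scale_shard("tensor", ModuleInfo(sharding_dim=0)) is False  # per-tensor scale
```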

0 commit comments
