
Commit c931ad7

rename fp8 attention
Signed-off-by: Andrea Fasoli <[email protected]>
1 parent 496bf44

3 files changed, +12 -10 lines changed

fms_mo/aiu_addons/fp8/fp8_adapter.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@
 from fms.utils.config import ModelConfig
 
 # pylint: disable=unused-argument
-# Retaining kwargs input arguments for consistency.
+# Retaining kwargs input arguments for consistency with other adapter steps.
 
 
 # NOTE: this adapter step must be registered before the adapter that uses it (such as
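
Note: the comment touched above refers to the shared signature of checkpoint adapter steps. A minimal sketch of that pattern, with a hypothetical step name rather than the repository's actual code: even a step that ignores extra arguments keeps **kwargs so it stays call-compatible with the other registered steps, which is why the pylint disable sits next to it.

from typing import Any, Mapping


# pylint: disable=unused-argument
def example_fp8_adapter_step(  # hypothetical name, for illustration only
    input_sd: Mapping[str, Any],
    **kwargs,  # unused here; retained for consistency with other adapter steps
) -> Mapping[str, Any]:
    """Pass-through placeholder showing the shared adapter-step signature."""
    return dict(input_sd)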

fms_mo/aiu_addons/fp8/fp8_linear.py

Lines changed: 11 additions & 9 deletions
@@ -33,31 +33,31 @@
 # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482
 
 
-### FP8 linear layers
+# Gated torchao imports for FP8 implementation
 if find_spec("torchao"):
     TORCHAO_INSTALLED = True
 
     # Third Party
-    from torchao.dtypes.affine_quantized_tensor import (  # type: ignore
+    from torchao.dtypes.affine_quantized_tensor import (
         AffineQuantizedTensor,
         to_affine_quantized_floatx,
         to_affine_quantized_floatx_static,
     )
-    from torchao.dtypes.floatx.float8_layout import (  # type: ignore
+    from torchao.dtypes.floatx.float8_layout import (
         Float8AQTTensorImpl,
         Float8Layout,
         Float8MMConfig,
         preprocess_data,
         preprocess_scale,
     )
-    from torchao.dtypes.utils import get_out_shape  # type: ignore
-    from torchao.float8.inference import (  # type: ignore
+    from torchao.dtypes.utils import get_out_shape
+    from torchao.float8.inference import (
         _is_rowwise_scaled,
         addmm_float8_unwrapped_inference,
     )
-    from torchao.quantization.granularity import PerRow, PerTensor  # type: ignore
-    from torchao.quantization.observer import get_block_size  # type: ignore
-    from torchao.quantization.quant_primitives import ZeroPointDomain  # type: ignore
+    from torchao.quantization.granularity import PerRow, PerTensor
+    from torchao.quantization.observer import get_block_size
+    from torchao.quantization.quant_primitives import ZeroPointDomain
 else:
     TORCHAO_INSTALLED = False
 
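Note: the renamed comment above labels a standard optional-dependency guard. A self-contained sketch of the same find_spec pattern, assuming only that torchao may or may not be installed (the helper function is illustrative and not part of the file):

from importlib.util import find_spec

# Import torchao symbols only when the package is present, so the module still
# imports cleanly on systems without torchao; FP8 code paths check the flag first.
if find_spec("torchao"):
    TORCHAO_INSTALLED = True

    # Third Party
    from torchao.quantization.granularity import PerRow, PerTensor
else:
    TORCHAO_INSTALLED = False


def require_torchao() -> None:  # illustrative helper, not in the diff
    """Fail fast with a clear message when an FP8 feature needs torchao."""
    if not TORCHAO_INSTALLED:
        raise ImportError("torchao must be installed to use FP8 linear layers")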

@@ -177,7 +177,8 @@ def _construct_qweight_structure(self) -> "AffineQuantizedTensor":
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """If input quantization is active, compute FP8xFP8 addmm."""
+        """If input quantization is active, compute FP8xFP8 addmm leveraging torchao
+        functionalities. Otherwise compute non-quantized addmm."""
 
         # fp8 weight tensor for torchao
         qweight: AffineQuantizedTensor = self._construct_qweight_structure()
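
Note: the expanded docstring describes two code paths. The sketch below mirrors that branching only; it simulates the FP8xFP8 product with a quantize-dequantize round trip, whereas the actual layer builds an AffineQuantizedTensor and dispatches to torchao kernels. Function and argument names are illustrative.

from typing import Optional

import torch
import torch.nn.functional as F


def fp8_linear_forward_sketch(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    quantize_input: bool = True,  # stands in for the layer's input-quantization flag
) -> torch.Tensor:
    if not quantize_input:
        # Non-quantized fallback: plain addmm via F.linear.
        return F.linear(x, weight, bias)

    # Simulated FP8xFP8 path: per-tensor scales into the float8_e4m3fn range,
    # cast to FP8, then matmul in bfloat16 and rescale (torchao instead calls
    # a fused scaled-matmul kernel on the FP8 data).
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    x_scale = x.abs().amax().clamp(min=1e-12) / fp8_max
    w_scale = weight.abs().amax().clamp(min=1e-12) / fp8_max
    x_fp8 = (x / x_scale).to(torch.float8_e4m3fn)
    w_fp8 = (weight / w_scale).to(torch.float8_e4m3fn)
    out = (x_fp8.to(torch.bfloat16) @ w_fp8.to(torch.bfloat16).t()) * (x_scale * w_scale)
    if bias is not None:
        out = out + bias
    return out
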
@@ -282,6 +283,7 @@ def shard_fp8_linear(
     | input_scale | Y/N | 0/- |
     | bias | 0 | - |
     """
+
     param_sharding_info: dict[str, dict[str, LinearParameterShardingInfo]] = {}
     for module_name, module_info in module_sharding_info.items():
         linear_mod: torch.nn.Module = module_info.linear_module
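
Note: for readers following the sharding loop started above, the structure being filled is a nested mapping from module name to per-parameter sharding info. An illustrative shape only, using a stand-in dataclass rather than fms' LinearParameterShardingInfo and a made-up module name:

from dataclasses import dataclass
from typing import Optional


@dataclass
class ShardingInfoSketch:  # stand-in for LinearParameterShardingInfo
    shard_dim: Optional[int]  # tensor dim split across ranks, None if replicated


# Outer key: module name; inner key: parameter name (weight, scales, bias, ...).
param_sharding_info_example: dict[str, dict[str, ShardingInfoSketch]] = {
    "layers.0.attn.dense": {  # hypothetical module name
        "weight": ShardingInfoSketch(shard_dim=0),
        "bias": ShardingInfoSketch(shard_dim=0),
    },
}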
