Commit c5a55fc

Add CPU fallback for scaled_mm
Signed-off-by: Antoni Viros i Martin <[email protected]>
1 parent: 43372e4

File tree: 1 file changed, 28 insertions(+), 0 deletions(-)

fms_mo/aiu_addons/fp8/fp8_spyre_op.py

Lines changed: 28 additions & 0 deletions

@@ -13,6 +13,9 @@
 # limitations under the License.
 """Torch registration of FP8xFP8 operation for attention BMMs."""
 
+# Standard
+from typing import Optional
+
 # Third Party
 from torch import Tensor
 import torch
@@ -26,6 +29,31 @@
 # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482
 
 
+aten = torch.ops.aten
+DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined]
+
+
+@torch.library.register_kernel("aten::_scaled_mm", "cpu")
+def _scaled_mm_cpu(
+    mat1: Tensor,
+    mat2: Tensor,
+    scale1: Tensor,
+    scale2: Tensor,
+    bias: Optional[Tensor] = None,
+    scale_result: Optional[Tensor] = None,
+    out_dtype: Optional[torch.dtype] = None,
+    use_fast_accum: bool = False,
+) -> Tensor:
+    if out_dtype is None:
+        out_dtype = torch.float32
+    mat1 = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype)
+    mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype)
+
+    if bias is not None:
+        return torch.addmm(bias, mat1, mat2).to(dtype=out_dtype)
+    return torch.mm(mat1, mat2).to(dtype=out_dtype)
+
+
 @torch.library.custom_op("spyre::scaled_bmm", mutates_args=())
 def spyre_scaled_bmm(
     mat1: Tensor,
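For reference, a minimal usage sketch (not part of the commit) of the fallback registered above. It assumes a PyTorch build that exposes the float8 dtypes and the Tensor-scale _scaled_mm schema this kernel matches; the shapes and scale values are illustrative only.

# Usage sketch (illustrative, not part of the commit).
import torch

a = torch.randn(16, 32).to(torch.float8_e4m3fn)   # FP8 "activations" (made-up shape)
b = torch.randn(32, 8).to(torch.float8_e4m3fn)    # FP8 "weights" (made-up shape)
scale_a = torch.tensor(1.0)                       # per-tensor dequantization scales
scale_b = torch.tensor(1.0)

# On builds without a native CPU kernel, aten::_scaled_mm would otherwise
# raise a NotImplementedError here; with the fallback registered for the
# "cpu" dispatch key, the matmul runs in out_dtype instead.
out = torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
print(out.shape, out.dtype)   # torch.Size([16, 8]) torch.bfloat16

Note the design choice: the fallback dequantizes both operands to out_dtype and delegates to torch.mm / torch.addmm, so it is a numerical reference path rather than a performance path; scale_result and use_fast_accum are accepted only for schema compatibility and are ignored by the CPU implementation.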
