Commit f61c561

Re-enable custom op for pt<=2.7
Signed-off-by: Andrea Fasoli <[email protected]>
1 parent: eaad33b

File tree

1 file changed: +62 −0 lines


fms_mo/aiu_addons/fp8/fp8_spyre_op.py

Lines changed: 62 additions & 0 deletions
@@ -13,7 +13,11 @@
 # limitations under the License.
 """Torch registration of FP8xFP8 operation for attention BMMs."""
 
+# Standard
+from typing import Optional
+
 # Third Party
+from packaging.version import Version
 from torch import Tensor
 import torch
 import torch.nn.functional as F
@@ -26,6 +30,64 @@
 # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482
 
 
+if Version(torch.__version__) <= Version("2.7"):
+    # PyTorch 2.8 adds scaled_mm_out op for CPU in the ATen set,
+    # while for earlier versions we need a custom definition
+    def _scaled_mm_cpu_out(
+        mat1: Tensor,
+        mat2: Tensor,
+        scale1: Tensor,
+        scale2: Tensor,
+        bias: Optional[Tensor] = None,
+        scale_result: Optional[Tensor] = None,
+        out_dtype: Optional[torch.dtype] = None,
+        use_fast_accum: bool = False,
+        *,
+        out: Optional[Tensor] = None,
+    ) -> Tensor:
+        if out_dtype is None:
+            out_dtype = torch.float32
+        mat1 = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype)
+        mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype)
+
+        if bias is not None:
+            ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype)
+        else:
+            ret = torch.mm(mat1, mat2).to(dtype=out_dtype)
+
+        if out is not None:
+            out.copy_(ret)
+            return out
+        return ret
+
+    torch.library.register_kernel(
+        torch.ops.aten._scaled_mm.out, "cpu", _scaled_mm_cpu_out
+    )
+
+    @torch.library.register_kernel("aten::_scaled_mm", "cpu")
+    def _scaled_mm_cpu(
+        mat1: Tensor,
+        mat2: Tensor,
+        scale1: Tensor,
+        scale2: Tensor,
+        bias: Optional[Tensor] = None,
+        scale_result: Optional[Tensor] = None,
+        out_dtype: Optional[torch.dtype] = None,
+        use_fast_accum: bool = False,
+    ) -> Tensor:
+        return _scaled_mm_cpu_out(
+            mat1,
+            mat2,
+            scale1,
+            scale2,
+            bias,
+            scale_result,
+            out_dtype,
+            use_fast_accum,
+            out=None,
+        )
+
+
 @torch.library.custom_op("spyre::scaled_bmm", mutates_args=())
 def spyre_scaled_bmm(
     mat1: Tensor,
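As a usage illustration (not part of the commit): the sketch below shows how this fallback could be exercised on a CPU-only build of torch <= 2.7. The import path comes from the file above; the shapes and unit scales are illustrative assumptions.

# Minimal sketch, assuming torch <= 2.7 on a CPU-only machine with fms_mo
# installed; shapes and unit scales are illustrative, not from the commit.
import torch

# Importing the module runs the version-gated kernel registration above.
import fms_mo.aiu_addons.fp8.fp8_spyre_op  # noqa: F401

# Per-tensor-scaled FP8 inputs, as aten::_scaled_mm expects.
a = torch.randn(16, 32).to(torch.float8_e4m3fn)
b = torch.randn(32, 8).to(torch.float8_e4m3fn)
scale_a = torch.tensor(1.0)
scale_b = torch.tensor(1.0)

# Without the registration above, this call fails with NotImplementedError on
# CPU for torch <= 2.7; with it, it dispatches to _scaled_mm_cpu.
out = torch.ops.aten._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.float32)
print(out.shape, out.dtype)  # torch.Size([16, 8]) torch.float32

One subtlety of the gate: packaging orders patch releases above the bare minor version, so Version("2.7.1") <= Version("2.7") is False and the branch fires only for 2.7.0 and earlier; "pt<=2.7" in the commit title should be read accordingly.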