Skip to content

Commit 1991d09

Browse files
committed
Test fix for FP8 matmul
Signed-off-by: Antoni Viros i Martin <[email protected]>
1 parent acdb545 commit 1991d09

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

fms_mo/aiu_addons/fp8/fp8_spyre_op.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,6 @@ def _scaled_mm_cpu_out(
6060
return ret
6161

6262

63-
torch.library.register_kernel(torch.ops.aten._scaled_mm.out, "cpu", _scaled_mm_cpu_out)
64-
65-
66-
@torch.library.register_kernel("aten::_scaled_mm", "cpu")
6763
def _scaled_mm_cpu(
6864
mat1: Tensor,
6965
mat2: Tensor,
@@ -87,6 +83,19 @@ def _scaled_mm_cpu(
8783
)
8884

8985

86+
if torch.__version__ >= "2.8":
87+
DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined]
88+
torch.ops.aten._scaled_mm.out.py_kernels[DispatchKey.CPU] = _scaled_mm_cpu_out
89+
torch.ops.aten._scaled_mm.default.py_kernels[DispatchKey.CPU] = _scaled_mm_cpu
90+
else:
91+
torch.library.register_kernel(
92+
torch.ops.aten._scaled_mm.out, "cpu", _scaled_mm_cpu_out
93+
)
94+
torch.library.register_kernel(
95+
torch.ops.aten._scaled_mm.default, "cpu", _scaled_mm_cpu
96+
)
97+
98+
9099
@torch.library.custom_op("spyre::scaled_bmm", mutates_args=())
91100
def spyre_scaled_bmm(
92101
mat1: Tensor,
@@ -114,7 +123,7 @@ def spyre_scaled_bmm(
114123
device=mat1.device,
115124
)
116125
for b_idx in range(mat1.shape[0]):
117-
out[b_idx] = torch._scaled_mm(
126+
out[b_idx] = _scaled_mm_cpu_out(
118127
mat1[b_idx],
119128
mat2[b_idx],
120129
scale1,

0 commit comments

Comments
 (0)