Commit eaad33b

Remove custom scaled bmm op on cpu and fix fp8 test
Signed-off-by: Andrea Fasoli <[email protected]>
1 parent 47b8716 commit eaad33b

File tree: 2 files changed (+3, -61 lines)

  fms_mo/aiu_addons/fp8/fp8_spyre_op.py
  tests/aiu_addons/test_fp8_addon.py

fms_mo/aiu_addons/fp8/fp8_spyre_op.py

Lines changed: 0 additions & 58 deletions
@@ -13,9 +13,6 @@
 # limitations under the License.
 """Torch registration of FP8xFP8 operation for attention BMMs."""

-# Standard
-from typing import Optional
-
 # Third Party
 from torch import Tensor
 import torch
@@ -29,61 +26,6 @@
 # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482


-def _scaled_mm_cpu_out(
-    mat1: Tensor,
-    mat2: Tensor,
-    scale1: Tensor,
-    scale2: Tensor,
-    bias: Optional[Tensor] = None,
-    scale_result: Optional[Tensor] = None,
-    out_dtype: Optional[torch.dtype] = None,
-    use_fast_accum: bool = False,
-    *,
-    out: Optional[Tensor] = None,
-) -> Tensor:
-    if out_dtype is None:
-        out_dtype = torch.float32
-    mat1 = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype)
-    mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype)
-
-    if bias is not None:
-        ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype)
-    else:
-        ret = torch.mm(mat1, mat2).to(dtype=out_dtype)
-
-    if out is not None:
-        out.copy_(ret)
-        return out
-    return ret
-
-
-torch.library.register_kernel(torch.ops.aten._scaled_mm.out, "cpu", _scaled_mm_cpu_out)
-
-
-@torch.library.register_kernel("aten::_scaled_mm", "cpu")
-def _scaled_mm_cpu(
-    mat1: Tensor,
-    mat2: Tensor,
-    scale1: Tensor,
-    scale2: Tensor,
-    bias: Optional[Tensor] = None,
-    scale_result: Optional[Tensor] = None,
-    out_dtype: Optional[torch.dtype] = None,
-    use_fast_accum: bool = False,
-) -> Tensor:
-    return _scaled_mm_cpu_out(
-        mat1,
-        mat2,
-        scale1,
-        scale2,
-        bias,
-        scale_result,
-        out_dtype,
-        use_fast_accum,
-        out=None,
-    )
-
-
 @torch.library.custom_op("spyre::scaled_bmm", mutates_args=())
 def spyre_scaled_bmm(
     mat1: Tensor,
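
For reference, a minimal standalone sketch of the emulation the deleted CPU kernel above performed: dequantize each operand by its per-tensor scale, then run a plain matmul (addmm when a bias is given). The function name and example shapes below are illustrative only and are not part of the repository.

# Standalone sketch of the removed CPU emulation; illustrative only.
from typing import Optional

import torch
from torch import Tensor


def emulated_scaled_mm(
    mat1: Tensor,
    mat2: Tensor,
    scale1: Tensor,
    scale2: Tensor,
    bias: Optional[Tensor] = None,
    out_dtype: torch.dtype = torch.float32,
) -> Tensor:
    # Dequantize both operands by their scales, then matmul in out_dtype.
    a = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype)
    b = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype)
    if bias is not None:
        return torch.addmm(bias, a, b).to(dtype=out_dtype)
    return torch.mm(a, b).to(dtype=out_dtype)


# Usage with per-tensor scales on plain float inputs (illustrative shapes).
x = torch.randn(4, 8)
w = torch.randn(8, 16)
out = emulated_scaled_mm(x, w, torch.tensor(0.5), torch.tensor(2.0))
print(out.shape)  # torch.Size([4, 16])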

tests/aiu_addons/test_fp8_addon.py

Lines changed: 3 additions & 3 deletions
@@ -51,9 +51,9 @@ def test_fp8_op() -> None:
     # Local
     from fms_mo.aiu_addons.fp8.fp8_attn import _math_fp8_compute_op

-    query = torch.randn((1, 32, 64, 128), dtype=torch.bfloat16, device="cuda")
-    key = torch.randn((1, 32, 64, 128), dtype=torch.bfloat16, device="cuda")
-    value = torch.randn((1, 32, 64, 128), dtype=torch.bfloat16, device="cuda")
+    query = torch.randn((1, 64, 32, 128), dtype=torch.bfloat16, device="cuda")
+    key = torch.randn((1, 64, 32, 128), dtype=torch.bfloat16, device="cuda")
+    value = torch.randn((1, 64, 32, 128), dtype=torch.bfloat16, device="cuda")

     out = _math_fp8_compute_op(query, key, value, 32, 32, 0.0, None)
     assert out.size() == query.size()
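
A hedged reading of the test fix, as a standalone sketch: the query/key/value tensors now appear to follow a (batch, seq_len, num_heads, head_dim) layout, so the head axis (32) lines up with the nheads/kvheads arguments passed to _math_fp8_compute_op. Only the numeric shapes come from the diff; the axis names are an assumption.

import torch

# Shapes copied from the updated test; axis naming is an assumption,
# not something stated in the repository.
batch, seq_len, num_heads, head_dim = 1, 64, 32, 128
query = torch.randn((batch, seq_len, num_heads, head_dim), dtype=torch.bfloat16)
# The head axis matches the nheads/kvheads arguments (32, 32) used in the test call.
assert query.shape[2] == num_heads == 32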
