@@ -13,9 +13,6 @@
 # limitations under the License.
 """Torch registration of FP8xFP8 operation for attention BMMs."""
 
-# Standard
-from typing import Optional
-
 # Third Party
 from torch import Tensor
 import torch
@@ -29,61 +26,6 @@
 # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482
 
 
-def _scaled_mm_cpu_out(
-    mat1: Tensor,
-    mat2: Tensor,
-    scale1: Tensor,
-    scale2: Tensor,
-    bias: Optional[Tensor] = None,
-    scale_result: Optional[Tensor] = None,
-    out_dtype: Optional[torch.dtype] = None,
-    use_fast_accum: bool = False,
-    *,
-    out: Optional[Tensor] = None,
-) -> Tensor:
-    if out_dtype is None:
-        out_dtype = torch.float32
-    mat1 = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype)
-    mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype)
-
-    if bias is not None:
-        ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype)
-    else:
-        ret = torch.mm(mat1, mat2).to(dtype=out_dtype)
-
-    if out is not None:
-        out.copy_(ret)
-        return out
-    return ret
-
-
-torch.library.register_kernel(torch.ops.aten._scaled_mm.out, "cpu", _scaled_mm_cpu_out)
-
-
-@torch.library.register_kernel("aten::_scaled_mm", "cpu")
-def _scaled_mm_cpu(
-    mat1: Tensor,
-    mat2: Tensor,
-    scale1: Tensor,
-    scale2: Tensor,
-    bias: Optional[Tensor] = None,
-    scale_result: Optional[Tensor] = None,
-    out_dtype: Optional[torch.dtype] = None,
-    use_fast_accum: bool = False,
-) -> Tensor:
-    return _scaled_mm_cpu_out(
-        mat1,
-        mat2,
-        scale1,
-        scale2,
-        bias,
-        scale_result,
-        out_dtype,
-        use_fast_accum,
-        out=None,
-    )
-
-
 @torch.library.custom_op("spyre::scaled_bmm", mutates_args=())
 def spyre_scaled_bmm(
     mat1: Tensor,
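
Note on the removed block: it registered a CPU fallback for aten::_scaled_mm (both the functional and .out variants) that emulates the scaled matmul by upcasting each fp8 operand, applying its per-tensor scale, and running a plain torch.mm/torch.addmm. A minimal sketch of what a call looked like while that registration was in place, assuming a PyTorch build with fp8 dtypes and the positional-scale _scaled_mm schema matching the removed signature; shapes and scale values here are illustrative:

import torch

# Illustrative fp8 operands with per-tensor scales (0-dim float32 tensors,
# matching the scale1/scale2 parameters of the removed kernel).
a = torch.randn(16, 32).to(torch.float8_e4m3fn)
b = torch.randn(32, 8).to(torch.float8_e4m3fn)
scale_a = torch.tensor(1.0)
scale_b = torch.tensor(1.0)

# With the removed registration, this dispatched to _scaled_mm_cpu on CPU;
# without it, the call relies on whatever kernel PyTorch itself provides
# for the device.
out = torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.float32)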
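For context on the surviving registration, torch.library.custom_op creates a new operator under the torch.ops namespace, with the decorated function as its eager implementation. A minimal sketch of the same pattern, assuming PyTorch >= 2.4 (where custom_op and register_fake are available) and using a hypothetical example::scaled_bmm op; the real spyre::scaled_bmm signature continues beyond what this diff shows:

import torch
from torch import Tensor

@torch.library.custom_op("example::scaled_bmm", mutates_args=())
def example_scaled_bmm(mat1: Tensor, mat2: Tensor) -> Tensor:
    # Eager reference implementation: plain batched matmul in float32.
    return torch.bmm(mat1.float(), mat2.float())

@example_scaled_bmm.register_fake
def _(mat1: Tensor, mat2: Tensor) -> Tensor:
    # Shape/dtype propagation so the op composes with torch.compile and
    # meta tensors.
    return mat1.new_empty(
        (mat1.shape[0], mat1.shape[1], mat2.shape[2]), dtype=torch.float32
    )

# Registered ops are callable through the torch.ops namespace:
# out = torch.ops.example.scaled_bmm(a, b)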