# limitations under the License.
"""Torch registration of FP8xFP8 operation for attention BMMs."""

+# Standard
+from typing import Optional
+
# Third Party
+from packaging.version import Version
from torch import Tensor
import torch
import torch.nn.functional as F

# open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482


+if Version(torch.__version__) < Version("2.8"):
+    # PyTorch 2.8 adds a scaled_mm_out op for CPU to the ATen set;
+    # for earlier versions we need a custom definition.
+    def _scaled_mm_cpu_out(
+        mat1: Tensor,
+        mat2: Tensor,
+        scale1: Tensor,
+        scale2: Tensor,
+        bias: Optional[Tensor] = None,
+        scale_result: Optional[Tensor] = None,  # unused; kept to match the ATen schema
+        out_dtype: Optional[torch.dtype] = None,
+        use_fast_accum: bool = False,  # unused; kept to match the ATen schema
+        *,
+        out: Optional[Tensor] = None,
+    ) -> Tensor:
+        if out_dtype is None:
+            out_dtype = torch.float32
+        # Dequantize: upcast each operand and apply its per-tensor scale.
+        mat1 = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype)
+        mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype)
+
+        if bias is not None:
+            ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype)
+        else:
+            ret = torch.mm(mat1, mat2).to(dtype=out_dtype)
+
+        if out is not None:
+            out.copy_(ret)
+            return out
+        return ret
+
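+    # Register the Python implementation above for both overloads:
+    # the explicit out-variant here and the functional form below.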
+    torch.library.register_kernel(
+        torch.ops.aten._scaled_mm.out, "cpu", _scaled_mm_cpu_out
+    )
+
+    @torch.library.register_kernel("aten::_scaled_mm", "cpu")
+    def _scaled_mm_cpu(
+        mat1: Tensor,
+        mat2: Tensor,
+        scale1: Tensor,
+        scale2: Tensor,
+        bias: Optional[Tensor] = None,
+        scale_result: Optional[Tensor] = None,
+        out_dtype: Optional[torch.dtype] = None,
+        use_fast_accum: bool = False,
+    ) -> Tensor:
+        return _scaled_mm_cpu_out(
+            mat1,
+            mat2,
+            scale1,
+            scale2,
+            bias,
+            scale_result,
+            out_dtype,
+            use_fast_accum,
+            out=None,
+        )
+
+
@torch.library.custom_op("spyre::scaled_bmm", mutates_args=())
def spyre_scaled_bmm(
    mat1: Tensor,
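Below is a minimal smoke test for the CPU shim. It is an illustrative sketch, not part of the change: it assumes a torch build older than 2.8 running on CPU, that this module has been imported so the kernels are registered, and that the private torch._scaled_mm entry point is used directly; shapes, scale values, and tolerances are made up.

import torch

# FP8 operands with per-tensor scales (shapes and values illustrative)
a = torch.randn(16, 32).to(torch.float8_e4m3fn)
b = torch.randn(32, 8).to(torch.float8_e4m3fn)
scale_a = torch.tensor(0.5)
scale_b = torch.tensor(2.0)

# On CPU with torch < 2.8 this dispatches to the _scaled_mm_cpu shim
out = torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.float32)

# Reference result computed by explicit dequantization in float32
ref = (a.to(torch.float32) * scale_a) @ (b.to(torch.float32) * scale_b)
torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-3)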