File tree: 1 file changed, +28 insertions, 0 deletions.
# limitations under the License.
1414"""Torch registration of FP8xFP8 operation for attention BMMs."""
1515
16+ # Standard
17+ from typing import Optional
18+
1619# Third Party
1720from torch import Tensor
1821import torch
2629# open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482
2730
2831
# Module-level aliases used by the kernel registrations in this module.
# NOTE(review): neither alias is referenced in the visible portion of this
# file — presumably used by registrations further down; confirm before removal.
aten = torch.ops.aten
# DispatchKey lives on the private torch._C extension module, which has no
# type stubs — hence the targeted ignore.
DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined]
35+
36+ @torch .library .register_kernel ("aten::_scaled_mm" , "cpu" )
37+ def _scaled_mm_cpu (
38+ mat1 : Tensor ,
39+ mat2 : Tensor ,
40+ scale1 : Tensor ,
41+ scale2 : Tensor ,
42+ bias : Optional [Tensor ] = None ,
43+ scale_result : Optional [Tensor ] = None ,
44+ out_dtype : Optional [torch .dtype ] = None ,
45+ use_fast_accum : bool = False ,
46+ ) -> Tensor :
47+ if out_dtype is None :
48+ out_dtype = torch .float32
49+ mat1 = (mat1 .to (dtype = out_dtype ) * scale1 ).to (dtype = out_dtype )
50+ mat2 = (mat2 .to (dtype = out_dtype ) * scale2 ).to (dtype = out_dtype )
51+
52+ if bias is not None :
53+ return torch .addmm (bias , mat1 , mat2 ).to (dtype = out_dtype )
54+ return torch .mm (mat1 , mat2 ).to (dtype = out_dtype )
55+
56+
2957@torch .library .custom_op ("spyre::scaled_bmm" , mutates_args = ())
3058def spyre_scaled_bmm (
3159 mat1 : Tensor ,
0 commit comments.