File tree Expand file tree Collapse file tree 1 file changed +14
-5
lines changed
Expand file tree Collapse file tree 1 file changed +14
-5
lines changed Original file line number Diff line number Diff line change @@ -60,10 +60,6 @@ def _scaled_mm_cpu_out(
6060 return ret
6161
6262
63- torch .library .register_kernel (torch .ops .aten ._scaled_mm .out , "cpu" , _scaled_mm_cpu_out )
64-
65-
66- @torch .library .register_kernel ("aten::_scaled_mm" , "cpu" )
6763def _scaled_mm_cpu (
6864 mat1 : Tensor ,
6965 mat2 : Tensor ,
@@ -87,6 +83,19 @@ def _scaled_mm_cpu(
8783 )
8884
8985
86+ if torch .__version__ >= "2.8" :
87+ DispatchKey = torch ._C .DispatchKey # type: ignore[attr-defined]
88+ torch .ops .aten ._scaled_mm .out .py_kernels [DispatchKey .CPU ] = _scaled_mm_cpu_out
89+ torch .ops .aten ._scaled_mm .default .py_kernels [DispatchKey .CPU ] = _scaled_mm_cpu
90+ else :
91+ torch .library .register_kernel (
92+ torch .ops .aten ._scaled_mm .out , "cpu" , _scaled_mm_cpu_out
93+ )
94+ torch .library .register_kernel (
95+ torch .ops .aten ._scaled_mm .default , "cpu" , _scaled_mm_cpu
96+ )
97+
98+
9099@torch .library .custom_op ("spyre::scaled_bmm" , mutates_args = ())
91100def spyre_scaled_bmm (
92101 mat1 : Tensor ,
@@ -114,7 +123,7 @@ def spyre_scaled_bmm(
114123 device = mat1 .device ,
115124 )
116125 for b_idx in range (mat1 .shape [0 ]):
117- out [b_idx ] = torch . _scaled_mm (
126+ out [b_idx ] = _scaled_mm_cpu_out (
118127 mat1 [b_idx ],
119128 mat2 [b_idx ],
120129 scale1 ,
You can’t perform that action at this time.
0 commit comments