from ..._ops import register_kernel
from ..utils import ipex_xpu, triton_available

-# With default torch, error:
-# NotImplementedError: The operator 'aten::_int_mm' for XPU
+# _int_mm is available in torch starting from version 2.7,
+# but it currently doesn't have an XPU implementation.
if ipex_xpu and torch.__version__ >= (2, 7):

    @register_kernel("bitsandbytes::int8_linear_matmul", "xpu")
@@ -18,6 +18,7 @@ def _(A: torch.Tensor, B: torch.Tensor):
    ).reshape(*A.shape[:-1], B.shape[0])


+# IPEX should be faster on XPU, so check first whether it is available.
if ipex_xpu:

    @register_kernel("bitsandbytes::dequantize_nf4_ipex", "xpu")

        return out.reshape(shape)
elif triton_available:
-    # IPEX should be faster for xpu, so at first checking if it is available.
    from ..triton import ops as triton_ops

    register_kernel("bitsandbytes::quantize_blockwise", "xpu")(triton_ops.quantize_blockwise)
    register_kernel("bitsandbytes::dequantize_4bit", "xpu")(triton_ops.dequantize_4bit)
    register_kernel("bitsandbytes::gemv_4bit", "xpu")(triton_ops.gemv_4bit)
else:
-    warnings.warn("XPU available, but nor ipex or trtion package is found.")
+    warnings.warn("XPU available but no ipex or triton packages found.")
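
For reference, these kernels are reached through torch.library dispatch rather than called directly. Below is a minimal usage sketch, not part of this change, assuming an XPU-enabled torch build and that importing bitsandbytes runs the registrations above so the op is reachable as torch.ops.bitsandbytes.int8_linear_matmul (the name taken from the registration string in the diff):

# Usage sketch (assumptions: XPU-enabled torch build; importing bitsandbytes
# registers the "xpu" kernels shown above).
import torch
import bitsandbytes  # noqa: F401  -- import triggers kernel registration

A = torch.randint(-128, 127, (8, 32), dtype=torch.int8, device="xpu")
B = torch.randint(-128, 127, (64, 32), dtype=torch.int8, device="xpu")

# Dispatch picks the "xpu" implementation registered above; per the kernel body,
# the result is A @ B.t() reshaped to (*A.shape[:-1], B.shape[0]).
out = torch.ops.bitsandbytes.int8_linear_matmul(A, B)
print(out.shape)  # torch.Size([8, 64])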