44import torch
55
66from ..._ops import register_kernel
7- from ..utils import ipex_xpu
7+ from ..utils import ipex_xpu , triton_available
88
9- # With default torch, error:
10- # NotImplementedError: The operator 'aten::_int_mm' for XPU
9+ # _int_mm is available in torch starting from version 2.7,
10+ # but it currently doesn't have an XPU implementation.
1111if ipex_xpu and torch .__version__ >= (2 , 7 ):
1212
1313 @register_kernel ("bitsandbytes::int8_linear_matmul" , "xpu" )
@@ -18,6 +18,7 @@ def _(A: torch.Tensor, B: torch.Tensor):
1818 ).reshape (* A .shape [:- 1 ], B .shape [0 ])
1919
2020
21+ # IPEX should be faster on XPU, so check whether it is available first.
2122if ipex_xpu :
2223
2324 @register_kernel ("bitsandbytes::dequantize_nf4_ipex" , "xpu" )
@@ -52,23 +53,15 @@ def _(
5253 raise ValueError (f"Blockwise quantization only supports 16/32-bit floats, but got { out .dtype } " )
5354
5455 return out .reshape (shape )
55- else :
56- # IPEX should be faster for xpu, so at first checking if it is available.
57- try :
58- from ..triton import ops as triton_ops
59-
60- triton_available = True
61- except ImportError as e :
62- print ("Import error:" , e )
63- triton_available = False
56+ elif triton_available :
57+ from ..triton import ops as triton_ops
6458
65- if triton_available :
66- register_kernel ("bitsandbytes::quantize_blockwise" , "xpu" )(triton_ops .quantize_blockwise )
67- register_kernel ("bitsandbytes::dequantize_blockwise.out" , "xpu" )(triton_ops .dequantize_blockwise_inplace )
68- register_kernel ("bitsandbytes::dequantize_blockwise" , "xpu" )(triton_ops .dequantize_blockwise )
69- register_kernel ("bitsandbytes::quantize_4bit" , "xpu" )(triton_ops .quantize_4bit )
70- register_kernel ("bitsandbytes::dequantize_4bit.out" , "xpu" )(triton_ops .dequantize_4bit_inplace )
71- register_kernel ("bitsandbytes::dequantize_4bit" , "xpu" )(triton_ops .dequantize_4bit )
72- register_kernel ("bitsandbytes::gemv_4bit" , "xpu" )(triton_ops .gemv_4bit )
73- else :
74- warnings .warn ("XPU available, but trtion package is missing." )
59+ register_kernel ("bitsandbytes::quantize_blockwise" , "xpu" )(triton_ops .quantize_blockwise )
60+ register_kernel ("bitsandbytes::dequantize_blockwise.out" , "xpu" )(triton_ops .dequantize_blockwise_inplace )
61+ register_kernel ("bitsandbytes::dequantize_blockwise" , "xpu" )(triton_ops .dequantize_blockwise )
62+ register_kernel ("bitsandbytes::quantize_4bit" , "xpu" )(triton_ops .quantize_4bit )
63+ register_kernel ("bitsandbytes::dequantize_4bit.out" , "xpu" )(triton_ops .dequantize_4bit_inplace )
64+ register_kernel ("bitsandbytes::dequantize_4bit" , "xpu" )(triton_ops .dequantize_4bit )
65+ register_kernel ("bitsandbytes::gemv_4bit" , "xpu" )(triton_ops .gemv_4bit )
66+ else :
67+ warnings .warn ("XPU available but no ipex or triton packages found." )
0 commit comments