Skip to content

Commit d66f93d

Browse files
authored
Merge pull request #9 from xiaolil1/jiqing
fix xpu log
2 parents 041b442 + b3db4bf commit d66f93d

File tree

2 files changed

+26
-16
lines changed

2 files changed

+26
-16
lines changed

bitsandbytes/backends/xpu/ops.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections.abc import Sequence
22
import ctypes as ct
3-
import warnings
3+
import logging
44

55
import torch
66

@@ -10,6 +10,8 @@
1010
from ...cextension import ErrorHandlerMockBNBNativeLibrary, lib
1111
from ..utils import triton_available
1212

13+
logger = logging.getLogger(__name__)
14+
1315

1416
def _dequantize_4bit_impl(
1517
A: torch.Tensor,
@@ -135,6 +137,7 @@ def _gemv_4bit_impl(
135137

136138
# SYCL should be faster for xpu, so at first checking if it is available.
137139
if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary):
140+
logger.info("Loading sycl bitsandbytes kernels for XPU")
138141

139142
@register_kernel("bitsandbytes::dequantize_4bit", "xpu")
140143
def _(
@@ -201,6 +204,7 @@ def _(
201204
torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")
202205
_gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out)
203206
elif triton_available:
207+
logger.info("Loading triton bitsandbytes kernels for XPU")
204208
from ..triton import ops as triton_ops
205209

206210
register_kernel("bitsandbytes::quantize_blockwise", "xpu")(triton_ops.quantize_blockwise)
@@ -211,6 +215,4 @@ def _(
211215
register_kernel("bitsandbytes::dequantize_4bit", "xpu")(triton_ops.dequantize_4bit)
212216
register_kernel("bitsandbytes::gemv_4bit", "xpu")(triton_ops.gemv_4bit)
213217
else:
214-
warnings.warn(
215-
"XPU available but no native library or triton packages found. Please follow the installation instructions in the documentation."
216-
)
218+
logger.warning("Loading pytorch bitsandbytes kernels for XPU because no native library or triton packages found.")

bitsandbytes/cextension.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -303,19 +303,27 @@ def get_native_library() -> BNBNativeLibrary:
303303

304304
ROCM_GPU_ARCH = get_rocm_gpu_arch()
305305

306-
try:
307-
if torch.version.hip:
308-
HIP_ENVIRONMENT, BNB_BACKEND = True, "ROCm"
309-
else:
310-
HIP_ENVIRONMENT, BNB_BACKEND = False, "CUDA"
306+
HIP_ENVIRONMENT = False
307+
BNB_BACKEND = "CPU"
308+
if torch.version.hip:
309+
HIP_ENVIRONMENT = True
310+
BNB_BACKEND = "ROCm"
311+
elif torch.cuda.is_available():
312+
BNB_BACKEND = "CUDA"
313+
elif torch._C._has_xpu:
314+
BNB_BACKEND = "XPU"
311315

316+
try:
312317
lib = get_native_library()
313318
except Exception as e:
314-
error_msg = str(e)
315-
logger.error(
316-
f"bitsandbytes library load error: {error_msg}",
317-
exc_info=True,
318-
)
319+
if BNB_BACKEND in ("CPU", "XPU"):
320+
lib = ErrorHandlerMockBNBNativeLibrary("XPU/CPU can run without native library.")
321+
else:
322+
error_msg = str(e)
323+
logger.error(
324+
f"bitsandbytes library load error: {error_msg}",
325+
exc_info=True,
326+
)
319327

320-
# create a mock with error messaging as fallback
321-
lib = ErrorHandlerMockBNBNativeLibrary(error_msg)
328+
# create a mock with error messaging as fallback
329+
lib = ErrorHandlerMockBNBNativeLibrary(error_msg)

0 commit comments

Comments
 (0)