Skip to content

Commit b0982fe

Browse files
committed
ipex version check
1 parent 886213b commit b0982fe

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

bitsandbytes/backends/xpu.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
int8_linear_matmul_impl,
1313
int8_mm_dequant_impl,
1414
quantize_4bit_impl,
15+
_ipex_xpu_version_prereq
1516
)
1617
try:
1718
import intel_extension_for_pytorch as ipex
@@ -23,7 +24,7 @@
2324

2425

2526
str2optimizer8bit_blockwise = {}
26-
if ipex_xpu is not None:
27+
if ipex_xpu is not None and _ipex_xpu_version_prereq(2, 7):
2728
str2optimizer8bit_blockwise = {
2829
"adam": (
2930
ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_fp32,
@@ -205,8 +206,8 @@ def dequantize_blockwise(
205206
blocksize: int = 4096,
206207
nested=False,
207208
) -> torch.Tensor:
208-
if ipex_xpu is None:
209-
raise RuntimeError("Please install intel_extension_for_ipex for 8bit optimizer backend on XPU device.")
209+
if ipex_xpu is None or not _ipex_xpu_version_prereq(2, 7):
210+
raise RuntimeError("Please install intel_extension_for_ipex >= 2.7 for 8bit optimizer backend on XPU device.")
210211

211212
# void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream)
212213
if out.dtype == torch.float16:
@@ -253,8 +254,8 @@ def optimizer_update_8bit_blockwise(
253254
skip_zeros=False,
254255
) -> None:
255256
optim_func = None
256-
if ipex_xpu is None:
257-
raise RuntimeError("Please install intel_extension_for_ipex for 8bit optimizer backend on XPU device.")
257+
if ipex_xpu is None or not _ipex_xpu_version_prereq(2, 7):
258+
raise RuntimeError("Please install intel_extension_for_ipex >= 2.7 for 8bit optimizer backend on XPU device.")
258259

259260
assert_on_xpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2])
260261

0 commit comments

Comments
 (0)