Skip to content

Commit b9f3c40

Browse files
committed
check ipex before MatMul8bitFp
Signed-off-by: jiqing-feng <[email protected]>
1 parent f44d4a2 commit b9f3c40

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

bitsandbytes/autograd/_functions.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing_extensions import deprecated
99

1010
import bitsandbytes.functional as F
11+
from bitsandbytes.functional import ipex_cpu, ipex_xpu
1112

1213
# The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov:
1314
# https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py
@@ -425,8 +426,9 @@ def matmul(
425426
if threshold > 0.0:
426427
state.threshold = threshold
427428
# MatMul8bitLt is slower because no fast kernel for quant/dequant 8bit in CPU/XPU
428-
if A.device.type in ("cpu", "xpu") and state.is_training:
429-
return MatMul8bitFp.apply(A, B, out, bias, state)
429+
if state.is_training:
430+
if (A.device.type == "cpu" and ipex_cpu) or (A.device.type == "xpu" and ipex_xpu):
431+
return MatMul8bitFp.apply(A, B, out, bias, state)
430432
return MatMul8bitLt.apply(A, B, out, bias, state)
431433

432434

0 commit comments

Comments
 (0)