Commit c0b1a62

jiqing-feng, pnunna93, akx, Titus-von-Koeller, matthewdouglas
authored and committed
Enable XPU and optimize cpu/xpu op (bitsandbytes-foundation#1418)
* enable new ipex API ipex weight is 4D so we cannot transpose fix dequant check require grad
* use ipex op in backward
* enable backward
* Multi backend refactor (#8)
* AMD: Clarify diagnostic messages; free up disk space for CI build
* Add build job for rocm
* Add rocm build script
* Copy shared obj file into output_dir
* upload build artifacts and enable wheels build
* Remove cuda build temporarily
* Add ROCm version to .so filename
* Add rocm_version to whls build
* Revert "Remove cuda build temporarily" This reverts commit 1413c5f.
* Add rocm_version env var
* Remove thrush header files
* Print node info
* print cuda node info
* Revert "print cuda node info" This reverts commit cdb209a.
* Revert "Print node info" This reverts commit 7e9a65c.
* Add rocm arch to compile command
* Rename .so files to rocm
* Update default gpu arch
* Skip cpu based igemmlt int tests on ROCm
* Update Documentation
* Update upstream repo name
* Update docs
* Update string format
  Co-authored-by: Aarni Koskela <[email protected]>
* Remove pre-release option for torch install
* Update pytorch install path
  Co-authored-by: Titus <[email protected]>
* Add messages for Heuristics error
* Remove toolcache for disk space
* print disk usage
* Clean disk space for linux
* Fix for ubuntu
* Add sudo for apt clean
* Update clean up disk list
* remove disk usage print
* Add BNB_BACKEND variable
* Update diagnostic functions for ROCm
* Fix tuple error
* Fix library detection bug for recursive and symlink cases
* fix pre-commit errors
* Remove recursive path lib search
* Create function for runtime lib patterns
* Update logger format
  Co-authored-by: Aarni Koskela <[email protected]>
* Update error reporting
  Co-authored-by: Aarni Koskela <[email protected]>
* Remove commented code
  Co-authored-by: Aarni Koskela <[email protected]>
* Update error reporting
  Co-authored-by: Aarni Koskela <[email protected]>
* Update error reporting
* Create hip diagnostics functions
* Fix Typo
* Fix pre-commit checks
---------
Co-authored-by: Aarni Koskela <[email protected]>
Co-authored-by: Titus <[email protected]>
* check grad before using ipex (bitsandbytes-foundation#1358)
* Enable packaging for ROCm 6.2 (bitsandbytes-foundation#1367)
* Enable 6.2 build
* Update documentation for 6.2.0 pip install
* Update for VS2022 17.11 compatibility with CUDA < 12.4 (bitsandbytes-foundation#1341)
* Update for VS2022 17.11 compatibility with CUDA < 12.4
* Try again
* Enable continuous releases for multi-backend-refactor branch
* Update release workflow
* Publish continuous release for multi-backend
* continuous release: revert wheel renaming due to install err
* Revert "continuous release: revert wheel renaming due to install err" This reverts commit 0a2b539.
* add dynamic tag-based versioning + git hash for dev vers
* docs: update w/ changes from `main`
* get tags for dynamic versioning
* fine-tune continuous release params
* reduce the pkg size + build times for the preview release
* refine docs for multi-backend alpha release (bitsandbytes-foundation#1380)
* refine docs for multi-backend alpha release
* docs: further tweaks to multi-backend alpha docs
* docs: further tweaks to multi-backend alpha docs
* docs: further tweaks to multi-backend alpha docs
* docs: add multi-backend feedback links
* docs: add request for contributions
* docs: small fixes
* docs: small fixes
* docs: add info about `main` continuous build
* docs: further tweaks to multi-backend alpha docs
* docs: further tweaks to multi-backend alpha docs
* docs: remove 2 obsolete lines
---------
Co-authored-by: pnunna93 <[email protected]>
Co-authored-by: Aarni Koskela <[email protected]>
Co-authored-by: Titus <[email protected]>
Co-authored-by: Matthew Douglas <[email protected]>
* Revert "enable backward" This reverts commit cd7bf21.
* Revert "use ipex op in backward" This reverts commit b8df1aa.
* fix finetune
* check training
* fix gemv check
* reformat
* avoid double quant in backward if not needed
* Zh/xpu support (#9)
* Add xpu support
* Add xpu support for int8
* Add xpu dequant kernel support
* update code
* remove debug comments
* remove redundant comments
* Add xpu integration for woqlinear
* correct the comments
* Update cpu_xpu_common.py
---------
Co-authored-by: zhuhong61 <[email protected]>
Co-authored-by: zhuhong61 <[email protected]>
* avoid import triton if CPU and XPU backend
* fix setup in docker without git config
* xpu do not support compile for now
  Signed-off-by: jiqing-feng <[email protected]>
* update xpu
  Signed-off-by: jiqing-feng <[email protected]>
* update 4bit compute dtype
* fix xpu int8 path
  Signed-off-by: jiqing-feng <[email protected]>
* optimize 4bit dequant
  Signed-off-by: jiqing-feng <[email protected]>
* fix xpu dequant
  Signed-off-by: jiqing-feng <[email protected]>
* add empty cache in each xpu op
* add nf4 dequant ipex kernel
* fix dequant 4bit op
* empty cache has negative effect on 4bit gemv
* fix xpu save
* fix save
* xpu use float16 default
  Signed-off-by: jiqing-feng <[email protected]>
* rm empty cache as it cause slower perf
  Signed-off-by: jiqing-feng <[email protected]>
* fix xpu save
  Signed-off-by: jiqing-feng <[email protected]>
* fix 8bit int8 param device
  Signed-off-by: jiqing-feng <[email protected]>
* fix 8bit int8 param device
  Signed-off-by: jiqing-feng <[email protected]>
* fix 8bit int8 param device
  Signed-off-by: jiqing-feng <[email protected]>
* fix 8bit int8 param device
  Signed-off-by: jiqing-feng <[email protected]>
* fix format
* update readme for Intel CPU and XPU do not need make csrc codes
* fix format
* fix import
---------
Signed-off-by: jiqing-feng <[email protected]>
Co-authored-by: pnunna93 <[email protected]>
Co-authored-by: Aarni Koskela <[email protected]>
Co-authored-by: Titus <[email protected]>
Co-authored-by: Matthew Douglas <[email protected]>
Co-authored-by: zhuhong61 <[email protected]>
Co-authored-by: zhuhong61 <[email protected]>
1 parent 116e1cf commit c0b1a62

File tree

10 files changed

+248
-94
lines changed

bitsandbytes/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -17,11 +17,10 @@
1717
matmul_cublas,
1818
mm_cublas,
1919
)
20-
from .backends import register_backend
20+
from .backends import backends, register_backend
2121
from .backends.cpu import CPUBackend
2222
from .backends.npu import NPUBackend
2323
from .cextension import lib
24-
from .nn import modules
2524

2625
features = {"multi_backend"}
2726
supported_torch_devices = {
@@ -76,6 +75,11 @@
7675
if hasattr(torch, "npu") and torch.npu.is_available():
7776
register_backend("npu", NPUBackend())
7877

78+
79+
# import module after decided backends
80+
if backends:
81+
from .nn import modules
82+
7983
# TODO: Other potential backends:
8084
# XLA - Google TPU / PJRT runtime
8185
# HPU - Habana / Intel Gaudi

bitsandbytes/autograd/_functions.py

Lines changed: 13 additions & 4 deletions
@@ -221,7 +221,7 @@ def backward(ctx, grad_output):
221221

222222
def supports_igemmlt(device: torch.device) -> bool:
223223
"""check if this device supports the optimized int8 kernel"""
224-
if device == torch.device("cpu"):
224+
if device == torch.device("cpu") or torch.device("xpu"):
225225
return True
226226
if torch.version.hip:
227227
return False if BNB_HIP_VERSION < 601 else True
@@ -463,7 +463,9 @@ def backward(ctx, grad_output):
463463
if len(grad_output.shape) == 3:
464464
grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
465465

466-
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
466+
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = None, None, None, None, None
467+
if req_gradB or (req_gradA and state.CBt):
468+
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
467469
if req_gradB:
468470
CxAt, SAt = F.transform(CAt, formatB, transpose=True)
469471
C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
@@ -575,8 +577,15 @@ def matmul_4bit(
575577
bias=None,
576578
):
577579
assert quant_state is not None
578-
if (A.numel() == A.shape[-1] or A.device.type == "cpu") and A.requires_grad == False:
579-
# CPU backend does not require A to be a vector
580+
if A.device.type in ("cpu", "xpu") and A.requires_grad == False:
581+
if getattr(quant_state, "ipex", False):
582+
out = F.gemv_4bit(A, B.t(), out, state=quant_state)
583+
if bias is not None:
584+
out += bias
585+
return out
586+
else:
587+
return MatMul4Bit.apply(A, B, out, bias, quant_state)
588+
elif A.numel() == A.shape[-1] and A.requires_grad == False:
580589
if A.shape[-1] % quant_state.blocksize != 0:
581590
warn(
582591
f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",

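A note on the `supports_igemmlt` hunk above: the added condition `device == torch.device("cpu") or torch.device("xpu")` parses as `(device == cpu) or xpu`, so the second operand is a device object rather than a comparison. Assuming `torch.device` instances are truthy like ordinary Python objects, the check would return `True` for every device, including CUDA and ROCm ones. The sketch below reproduces the pattern with a hypothetical `Device` stand-in (not a torch class) and contrasts it with a membership test on the device type, which is the form this commit already uses in `matmul_4bit`.

```python
class Device:
    """Hypothetical stand-in for torch.device: compares by type, truthy by default."""

    def __init__(self, type_):
        self.type = type_

    def __eq__(self, other):
        return isinstance(other, Device) and self.type == other.type

    def __hash__(self):
        return hash(self.type)


def check_as_written(device):
    # Same shape as the added line: `device == cpu or xpu`.
    # The right-hand operand is an object, not a comparison.
    return bool(device == Device("cpu") or Device("xpu"))


def check_by_type(device):
    # Membership test on the device type, as used in matmul_4bit.
    return device.type in ("cpu", "xpu")


if __name__ == "__main__":
    for name in ("cpu", "xpu", "cuda"):
        d = Device(name)
        print(name, check_as_written(d), check_by_type(d))
```

With the membership form, a `cuda` device falls through to the `torch.version.hip` and capability checks that follow in the function instead of returning early.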
bitsandbytes/backends/cpu_xpu_common.py

Lines changed: 44 additions & 21 deletions
@@ -16,6 +16,7 @@
1616

1717
ipex_cpu = ipex if ipex._C._has_cpu() else None
1818
ipex_xpu = ipex if ipex._C._has_xpu() else None
19+
ipex_cpu_only = ipex._C._has_cpu() and (not ipex._C._has_xpu())
1920
except BaseException:
2021
ipex_cpu = None
2122
ipex_xpu = None
@@ -56,7 +57,7 @@ def _ipex_xpu_version_prereq(major, minor):
5657

5758
def _maybe_torch_compile(func):
5859
# torch.compile requires g++ and pytorch >= 2.0
59-
if gxx_available and _torch_version_prereq(2, 0) and os.getenv('PT_HPU_LAZY_MODE',1)==0:
60+
if gxx_available and _torch_version_prereq(2, 0) and not ipex_xpu and os.getenv('PT_HPU_LAZY_MODE',1)==0:
6061
options = {}
6162
# fx_graph_cache requires pytorch >= 2.2
6263
if _torch_version_prereq(2, 2):
@@ -182,7 +183,7 @@ def igemmlt_impl(A, B, SA=None, SB=None, out=None, Sout=None, dtype=torch.int32)
182183
A_reshaped = A.reshape(m, k)
183184

184185
# torch._int_mm is available on CPU since torch 2.4
185-
if _torch_version_prereq(2, 4):
186+
if _torch_version_prereq(2, 4) and A.device.type == "cpu":
186187
C = torch._int_mm(A_reshaped, B.T).to(dtype)
187188
else:
188189
C = torch.matmul(A_reshaped.float(), B.t().float()).to(dtype)
@@ -234,8 +235,10 @@ def mm_dequant_impl(
234235
out_shape = (out_shape[0] * out_shape[1], out_shape[2])
235236

236237
if compute_dtype not in [torch.float32, torch.bfloat16]:
237-
warnings.warn(f"mm_dequant_{A.device}: compute_dtype {compute_dtype} is not supported, will use float instead")
238-
compute_dtype = torch.float32
238+
warnings.warn(
239+
f"mm_dequant_{A.device}: compute_dtype {compute_dtype} is not supported, will use bfloat16 instead"
240+
)
241+
compute_dtype = torch.bfloat16
239242
A_reshaped = A.reshape(out_shape).to(compute_dtype)
240243
row_stats = row_stats.reshape(-1).unsqueeze(-1).to(compute_dtype)
241244
col_stats = col_stats.reshape(-1).unsqueeze(0).to(compute_dtype)
@@ -408,7 +411,6 @@ def dequantize_4bit_impl(
408411
torch.Tensor:
409412
Dequantized tensor.
410413
"""
411-
412414
if A.shape[0] == 1:
413415
transpose = False
414416
A = A.squeeze(0)
@@ -438,23 +440,27 @@ def dequantize_4bit_impl(
438440
if quant_state.nested:
439441
raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU")
440442

441-
if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and hasattr(quant_state, "op_context"):
442-
assert quant_state.op_context is not None
443-
A = quant_state.op_context.to_public(quant_state.op_context.get_weight())
444-
A = A.reshape(-1)
445-
absmax = quant_state.op_context.get_scales().reshape(-1)
446-
447-
if out is None:
448-
out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)
443+
if ipex_cpu_only and _ipex_cpu_version_prereq(2, 5) and getattr(quant_state, "ipex", False):
444+
A = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", quant_state.shape, 2)
445+
quant_state.ipex = False
449446

450-
n = out.numel()
451447
# Map nf4 to [-1, 1]
448+
<<<<<<< HEAD
452449
out_uint8 = torch.empty(A.size(0) * 2, dtype=torch.uint8, device=A.device)
453450
out_uint8[::2] = A.bitwise_and(0xF)
454451
out_uint8[1::2] = A.bitwise_right_shift(4)
455452
out_dq = torch.empty(out_uint8.shape, dtype=quant_state.code.dtype, device= quant_state.code.device)
456453
for i in range(len(quant_state.code)):
457454
out_dq[out_uint8 == i] = quant_state.code[i]
455+
=======
456+
out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device)
457+
n = out_dq.numel()
458+
out_dq[::2] = A & 0xF
459+
out_dq[1::2] = A >> 4
460+
# quant_state.code is fp32, cast to quant_state dtype to avoid the mismatch issue
461+
quant_state.code = quant_state.code.to(quant_state.dtype)
462+
out_dq = quant_state.code[out_dq]
463+
>>>>>>> b2ac423 (Enable XPU and optimize cpu/xpu op (#1418))
458464

459465
# Apply scales
460466
if out_dq.numel() != n:
@@ -464,12 +470,17 @@ def dequantize_4bit_impl(
464470
blocks += 1 if n % blocksize > 0 else 0
465471
rem = n % blocksize
466472
has_rem = rem > 0
467-
out_reshaped = out.reshape(-1)
468-
out_reshaped[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(
469-
-1
470-
)
473+
471474
if has_rem:
475+
if out is None:
476+
out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)
477+
out_reshaped = out.reshape(-1)
478+
out_reshaped[: n - rem] = (
479+
out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)
480+
).reshape(-1)
472481
out_reshaped[n - rem :] = out_dq[n - rem :] * absmax[-1]
482+
else:
483+
out = (out_dq.view(-1, blocksize) * absmax.view(-1, 1)).reshape(quant_state.shape).to(quant_state.dtype)
473484

474485
# take transpose here because weight is transposed (again) for computation
475486
if transpose:
@@ -510,9 +521,21 @@ def gemm_4bit_impl(
510521
torch.Tensor:
511522
GEMM output tensor.
512523
"""
513-
if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and hasattr(state, "op_context"):
514-
assert state.op_context is not None
515-
output = torch.ops.torch_ipex.ipex_woq_linear(A, state.op_context.get_data_handle())
524+
if getattr(state, "ipex", False):
525+
output = torch.ops.torch_ipex.woq_linear(
526+
A,
527+
B,
528+
"nf4",
529+
state.shape,
530+
state.new_scales,
531+
state.new_zeros,
532+
None,
533+
None,
534+
state.blocksize,
535+
ipex_cpu.quantization.WoqLowpMode.BF16,
536+
1,
537+
state.compensation,
538+
)
516539
else:
517540
dqB = dequantize_4bit_impl(B, state, blocksize=state.blocksize).t()
518541
output = torch.matmul(A, dqB.to(A.dtype))
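A note on the `dequantize_4bit_impl` hunk above: it still carries merge-conflict markers (`<<<<<<< HEAD`, `=======`, `>>>>>>> b2ac423`) as added lines. The side after `=======` replaces the per-code masking loop with a single table lookup, then applies one `absmax` scale per block. The following is a minimal pure-Python sketch of that idea, not the library's implementation: the function names are hypothetical, the 16-entry `code` table is made up for illustration (it is not the NF4 table), and the nibble order (low nibble first) follows the hunk.

```python
def unpack_nibbles(packed):
    """Split each byte into two 4-bit indices: low nibble first, then high nibble."""
    out = [0] * (len(packed) * 2)
    out[::2] = [b & 0xF for b in packed]
    out[1::2] = [b >> 4 for b in packed]
    return out


def dequantize_4bit_sketch(packed, code, absmax, blocksize):
    """Look up each 4-bit index in `code`, then scale by the block's absmax.

    A trailing partial block is handled by the integer division below,
    so `absmax` needs ceil(n / blocksize) entries.
    """
    indices = unpack_nibbles(packed)
    values = [code[i] for i in indices]  # one lookup instead of a loop over all 16 codes
    return [v * absmax[pos // blocksize] for pos, v in enumerate(values)]


if __name__ == "__main__":
    code = [float(i - 8) for i in range(16)]  # hypothetical table, not NF4
    packed = [0x21, 0xF0]                     # indices 1, 2, 0, 15
    print(dequantize_4bit_sketch(packed, code, absmax=[2.0, 0.5], blocksize=2))
```

In the tensor version the lookup is the indexing expression `quant_state.code[out_dq]`, which avoids sixteen masked assignments over the whole output.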

bitsandbytes/backends/xpu.py

Lines changed: 87 additions & 8 deletions
@@ -5,9 +5,36 @@
55
from bitsandbytes.utils import QuantState
66

77
from .base import Backend
8+
from .cpu_xpu_common import (
9+
dequantize_4bit_impl,
10+
double_quant_impl,
11+
gemm_4bit_impl,
12+
igemmlt_impl,
13+
mm_dequant_impl,
14+
quantize_4bit_impl,
15+
)
16+
17+
Tensor = torch.Tensor
18+
19+
20+
def assert_on_xpu(tensors):
21+
on_xpu = True
22+
for t in tensors:
23+
if t is None:
24+
continue # NULL pointers are fine
25+
on_xpu &= t.device.type == "xpu"
26+
if not on_xpu:
27+
raise TypeError(
28+
"All input tensors need to be on XPU, but found some tensors to not be on XPU:\n"
29+
f" {[(t.shape, t.device) if isinstance(t, Tensor) else None for t in tensors]}"
30+
)
31+
return on_xpu
832

933

1034
class XPUBackend(Backend):
35+
mm_dequant_compute_dtype = torch.bfloat16
36+
mm_dequant_output_dtype = torch.bfloat16
37+
1138
def double_quant(
1239
self,
1340
A: torch.Tensor,
@@ -17,7 +44,9 @@ def double_quant(
1744
out_row: Optional[torch.Tensor] = None,
1845
threshold=0.0,
1946
):
20-
raise NotImplementedError
47+
assert_on_xpu([A, col_stats, row_stats, out_col, out_row])
48+
output = double_quant_impl(A, col_stats, row_stats, out_col, out_row, threshold)
49+
return output
2150

2251
def transform(
2352
self,
@@ -29,7 +58,23 @@ def transform(
2958
state: Optional[Tuple[torch.Size, str]] = None,
3059
ld=None,
3160
):
32-
raise NotImplementedError
61+
"""
62+
Transform tensor A to to_order. It is originally designed for CUDA.
63+
For XPU, it returns the original tensor if transpose=False.
64+
Otherwise, it returns the transpose of A
65+
"""
66+
assert_on_xpu([A, out])
67+
if transpose:
68+
if out is not None:
69+
out.copy_(A.T)
70+
else:
71+
out = A.T
72+
else:
73+
if out is not None:
74+
out.copy_(A)
75+
else:
76+
out = A
77+
return out, state
3378

3479
def igemmlt(
3580
self,
@@ -41,7 +86,9 @@ def igemmlt(
4186
Sout: Optional[Tuple[torch.Size, str]] = None,
4287
dtype=torch.int32,
4388
) -> Union[torch.Tensor, Tuple[Optional[Tuple[torch.Tensor, Tuple[torch.Size, str]]]]]:
44-
raise NotImplementedError
89+
assert_on_xpu([A, B])
90+
output = igemmlt_impl(A, B, SA, SB, out, Sout, dtype)
91+
return output
4592

4693
def mm_dequant(
4794
self,
@@ -54,15 +101,30 @@ def mm_dequant(
54101
new_col_stats: Optional[torch.Tensor] = None,
55102
bias: Optional[torch.Tensor] = None,
56103
) -> torch.Tensor:
57-
raise NotImplementedError
104+
assert_on_xpu([A, row_stats, col_stats, out, bias])
105+
output = mm_dequant_impl(
106+
A,
107+
quant_state,
108+
row_stats,
109+
col_stats,
110+
out,
111+
new_row_stats,
112+
new_col_stats,
113+
bias,
114+
self.mm_dequant_compute_dtype,
115+
self.mm_dequant_output_dtype,
116+
)
117+
return output
58118

59119
def extract_outliers(
60120
self,
61121
A: torch.Tensor,
62122
SA: Tuple[torch.Size, str],
63123
idx: torch.Tensor,
64124
) -> torch.Tensor:
65-
raise NotImplementedError
125+
assert_on_xpu([A])
126+
output = A[:, idx].contiguous()
127+
return output
66128

67129
def quantize_4bit(
68130
self,
@@ -74,7 +136,12 @@ def quantize_4bit(
74136
quant_type: Literal["fp4", "nf4"] = "fp4",
75137
quant_storage=torch.uint8,
76138
) -> Tuple[torch.Tensor, QuantState]:
77-
raise NotImplementedError
139+
if blocksize is None:
140+
blocksize = 64
141+
assert_on_xpu([A, absmax, out])
142+
assert quant_storage == torch.uint8, "XPU backend only supports uint8 quant_storage"
143+
output = quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type)
144+
return output
78145

79146
def dequantize_4bit(
80147
self,
@@ -85,7 +152,15 @@ def dequantize_4bit(
85152
blocksize: int = 64,
86153
quant_type: Literal["fp4", "nf4"] = "fp4",
87154
) -> torch.Tensor:
88-
raise NotImplementedError
155+
if blocksize is None:
156+
blocksize = 64
157+
assert_on_xpu([A, absmax, out])
158+
if quant_type == "nf4":
159+
output = torch.ops.torch_ipex.dequantize_4bit(A, "nf4", quant_state.shape, absmax, None, blocksize).t()
160+
else:
161+
output = dequantize_4bit_impl(A, quant_state, absmax, out, blocksize, quant_type)
162+
163+
return output
89164

90165
def gemv_4bit(
91166
self,
@@ -96,7 +171,11 @@ def gemv_4bit(
96171
transposed_B=False,
97172
state: QuantState = None,
98173
) -> torch.Tensor:
99-
raise NotImplementedError
174+
assert_on_xpu([A, B, out])
175+
if state is None:
176+
raise ValueError("state cannot be None. gemv_4bit() requires the state from quantize_4bit()")
177+
output = gemm_4bit_impl(A, B, out, transposed_A, transposed_B, state)
178+
return output
100179

101180
def dequantize_blockwise(
102181
self,
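Every `XPUBackend` method in the diff above starts with `assert_on_xpu([...])`, which skips `None` entries and raises `TypeError` if any remaining tensor lives on another device. A minimal sketch of that guard pattern follows, using a hypothetical `FakeTensor` dataclass in place of `torch.Tensor` so it runs without torch or an XPU; the helper name `assert_on_device` is also an assumption made for illustration.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor carrying only shape and device type."""
    shape: tuple
    device_type: str


def assert_on_device(tensors: Sequence[Optional[FakeTensor]], expected: str = "xpu") -> bool:
    """Raise TypeError unless every non-None tensor is on `expected`."""
    offenders = [t for t in tensors if t is not None and t.device_type != expected]
    if offenders:
        summary = [(t.shape, t.device_type) if t is not None else None for t in tensors]
        raise TypeError(f"All input tensors need to be on {expected}, but found: {summary}")
    return True


if __name__ == "__main__":
    a = FakeTensor((2, 4), "xpu")
    print(assert_on_device([a, None]))  # optional arguments may be None
    try:
        assert_on_device([a, FakeTensor((4,), "cpu")])
    except TypeError as exc:
        print("rejected:", exc)
```

Collecting the offenders first keeps the error message in one place and lets optional arguments such as `out` or `bias` pass through as `None`.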
