
Commit b2ac423

jiqing-feng, pnunna93, akx, Titus-von-Koeller, and matthewdouglas authored
Enable XPU and optimize cpu/xpu op (#1418)
* enable new ipex API ipex weight is 4D so we cannot transpose fix dequant check require grad
* use ipex op in backward
* enable backward
* Multi backend refactor (#8)
* AMD: Clarify diagnostic messages; free up disk space for CI build
* Add build job for rocm
* Add rocm build script
* Copy shared obj file into output_dir
* upload build artifacts and enable wheels build
* Remove cuda build temporarily
* Add ROCm version to .so filename
* Add rocm_version to whls build
* Revert "Remove cuda build temporarily"
  This reverts commit 1413c5f.
* Add rocm_version env var
* Remove thrush header files
* Print node info
* print cuda node info
* Revert "print cuda node info"
  This reverts commit cdb209a.
* Revert "Print node info"
  This reverts commit 7e9a65c.
* Add rocm arch to compile command
* Rename .so files to rocm
* Update default gpu arch
* Skip cpu based igemmlt int tests on ROCm
* Update Documentation
* Update upstream repo name
* Update docs
* Update string format
  Co-authored-by: Aarni Koskela <[email protected]>
* Remove pre-release option for torch install
* Update pytorch install path
  Co-authored-by: Titus <[email protected]>
* Add messages for Heuristics error
* Remove toolcache for disk space
* print disk usage
* Clean disk space for linux
* Fix for ubuntu
* Add sudo for apt clean
* Update clean up disk list
* remove disk usage print
* Add BNB_BACKEND variable
* Update diagnostic functions for ROCm
* Fix tuple error
* Fix library detection bug for recursive and symlink cases
* fix pre-commit errors
* Remove recursive path lib search
* Create function for runtime lib patterns
* Update logger format
  Co-authored-by: Aarni Koskela <[email protected]>
* Update error reporting
  Co-authored-by: Aarni Koskela <[email protected]>
* Remove commented code
  Co-authored-by: Aarni Koskela <[email protected]>
* Update error reporting
  Co-authored-by: Aarni Koskela <[email protected]>
* Update error reporting
* Create hip diagnostics functions
* Fix Typo
* Fix pre-commit checks

---------
Co-authored-by: Aarni Koskela <[email protected]>
Co-authored-by: Titus <[email protected]>

* check grad before using ipex (#1358)
* Enable packaging for ROCm 6.2 (#1367)
* Enable 6.2 build
* Update documentation for 6.2.0 pip install
* Update for VS2022 17.11 compatibility with CUDA < 12.4 (#1341)
* Update for VS2022 17.11 compatibility with CUDA < 12.4
* Try again
* Enable continuous releases for multi-backend-refactor branch
* Update release workflow
* Publish continuous release for multi-backend
* continuous release: revert wheel renaming due to install err
* Revert "continuous release: revert wheel renaming due to install err"
  This reverts commit 0a2b539.
* add dynamic tag-based versioning + git hash for dev vers
* docs: update w/ changes from `main`
* get tags for dynamic versioning
* fine-tune continuous release params
* reduce the pkg size + build times for the preview release
* refine docs for multi-backend alpha release (#1380)
* refine docs for multi-backend alpha release
* docs: further tweaks to multi-backend alpha docs
* docs: further tweaks to multi-backend alpha docs
* docs: further tweaks to multi-backend alpha docs
* docs: add multi-backend feedback links
* docs: add request for contributions
* docs: small fixes
* docs: small fixes
* docs: add info about `main` continuous build
* docs: further tweaks to multi-backend alpha docs
* docs: further tweaks to multi-backend alpha docs
* docs: remove 2 obsolete lines

---------
Co-authored-by: pnunna93 <[email protected]>
Co-authored-by: Aarni Koskela <[email protected]>
Co-authored-by: Titus <[email protected]>
Co-authored-by: Matthew Douglas <[email protected]>

* Revert "enable backward"
  This reverts commit cd7bf21.
* Revert "use ipex op in backward"
  This reverts commit b8df1aa.
* fix finetune
* check training
* fix gemv check
* reformat
* avoid double quant in backward if not needed
* Zh/xpu support (#9)
* Add xpu support
* Add xpu support for int8
* Add xpu dequant kernel support
* update code
* remove debug comments
* remove redundant comments
* Add xpu integration for woqlinear
* correct the comments
* Update cpu_xpu_common.py

---------
Co-authored-by: zhuhong61 <[email protected]>
Co-authored-by: zhuhong61 <[email protected]>

* avoid import triton if CPU and XPU backend
* fix setup in docker without git config
* xpu do not support compile for now
  Signed-off-by: jiqing-feng <[email protected]>
* update xpu
  Signed-off-by: jiqing-feng <[email protected]>
* update 4bit compute dtype
* fix xpu int8 path
  Signed-off-by: jiqing-feng <[email protected]>
* optimize 4bit dequant
  Signed-off-by: jiqing-feng <[email protected]>
* fix xpu dequant
  Signed-off-by: jiqing-feng <[email protected]>
* add empty cache in each xpu op
* add nf4 dequant ipex kernel
* fix dequant 4bit op
* empty cache has negative effect on 4bit gemv
* fix xpu save
* fix save
* xpu use float16 default
  Signed-off-by: jiqing-feng <[email protected]>
* rm empty cache as it cause slower perf
  Signed-off-by: jiqing-feng <[email protected]>
* fix xpu save
  Signed-off-by: jiqing-feng <[email protected]>
* fix 8bit int8 param device
  Signed-off-by: jiqing-feng <[email protected]>
* fix 8bit int8 param device
  Signed-off-by: jiqing-feng <[email protected]>
* fix 8bit int8 param device
  Signed-off-by: jiqing-feng <[email protected]>
* fix 8bit int8 param device
  Signed-off-by: jiqing-feng <[email protected]>
* fix format
* update readme for Intel CPU and XPU do not need make csrc codes
* fix format
* fix import

---------
Signed-off-by: jiqing-feng <[email protected]>
Co-authored-by: pnunna93 <[email protected]>
Co-authored-by: Aarni Koskela <[email protected]>
Co-authored-by: Titus <[email protected]>
Co-authored-by: Matthew Douglas <[email protected]>
Co-authored-by: zhuhong61 <[email protected]>
Co-authored-by: zhuhong61 <[email protected]>
1 parent cd73601 commit b2ac423


10 files changed: +246 -101 lines changed


bitsandbytes/__init__.py

Lines changed: 6 additions & 2 deletions
@@ -17,11 +17,10 @@
     matmul_cublas,
     mm_cublas,
 )
-from .backends import register_backend
+from .backends import backends, register_backend
 from .backends.cpu import CPUBackend
 from .backends.npu import NPUBackend
 from .cextension import lib
-from .nn import modules
 
 features = {"multi_backend"}
 supported_torch_devices = {
@@ -64,6 +63,11 @@
 if hasattr(torch, "npu") and torch.npu.is_available():
     register_backend("npu", NPUBackend())
 
+
+# import module after decided backends
+if backends:
+    from .nn import modules
+
 # TODO: Other potential backends:
 # XLA - Google TPU / PJRT runtime
 # HPU - Habana / Intel Gaudi
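
Review note: the change above defers `from .nn import modules` until at least one backend has been registered. The sketch below shows that gating pattern in isolation. The names `backends` and `register_backend` come from the diff, but their bodies here are assumptions (a plain module-level dict), and `load_modules_if_ready` is a hypothetical stand-in for the guarded import.

```python
from typing import Dict

# Assumed shape of the registry: a module-level dict keyed by device type.
backends: Dict[str, object] = {}


def register_backend(device_type: str, backend: object) -> None:
    """Record the backend instance that handles the given device type."""
    backends[device_type] = backend


def load_modules_if_ready() -> bool:
    """Hypothetical stand-in for the guarded `from .nn import modules`.

    Returns True when the import would run, i.e. when the registry is non-empty.
    """
    if backends:
        return True  # the real package would import its nn modules here
    return False
```

With an empty registry the guarded import is skipped; after the first `register_backend(...)` call it runs.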

bitsandbytes/autograd/_functions.py

Lines changed: 13 additions & 4 deletions
@@ -221,7 +221,7 @@ def backward(ctx, grad_output):
 
 def supports_igemmlt(device: torch.device) -> bool:
     """check if this device supports the optimized int8 kernel"""
-    if device == torch.device("cpu"):
+    if device == torch.device("cpu") or torch.device("xpu"):
         return True
     if torch.version.hip:
         return False if BNB_HIP_VERSION < 601 else True
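
Review note: Python parses the added condition as `(device == torch.device("cpu")) or torch.device("xpu")`. Assuming `torch.device` instances are truthy, as ordinary Python objects are by default, the check would pass for every device, not only CPU and XPU. The sketch below reproduces the pattern with a hypothetical `Device` stand-in so it runs without PyTorch, alongside one possible membership test on the device type; it is an illustration, not the project's fix.

```python
class Device:
    """Hypothetical stand-in for torch.device: equality compares the type string."""

    def __init__(self, type_: str) -> None:
        self.type = type_

    def __eq__(self, other: object) -> bool:
        return isinstance(other, Device) and self.type == other.type

    def __hash__(self) -> int:
        return hash(self.type)


def supports_as_written(device: Device) -> bool:
    # Parsed as (device == Device("cpu")) or Device("xpu").
    # The right operand is an object, and objects are truthy by default.
    return bool(device == Device("cpu") or Device("xpu"))


def supports_by_membership(device: Device) -> bool:
    # One possible alternative: compare the device type against an explicit set.
    return device.type in ("cpu", "xpu")
```

Under that assumption, `supports_as_written(Device("cuda"))` evaluates to `True`, while the membership form returns `False` for the same input.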
@@ -463,7 +463,9 @@ def backward(ctx, grad_output):
         if len(grad_output.shape) == 3:
             grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
 
-        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
+        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = None, None, None, None, None
+        if req_gradB or (req_gradA and state.CBt):
+            Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
         if req_gradB:
             CxAt, SAt = F.transform(CAt, formatB, transpose=True)
             C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
@@ -575,8 +577,15 @@ def matmul_4bit(
     bias=None,
 ):
     assert quant_state is not None
-    if (A.numel() == A.shape[-1] or A.device.type == "cpu") and A.requires_grad == False:
-        # CPU backend does not require A to be a vector
+    if A.device.type in ("cpu", "xpu") and A.requires_grad == False:
+        if getattr(quant_state, "ipex", False):
+            out = F.gemv_4bit(A, B.t(), out, state=quant_state)
+            if bias is not None:
+                out += bias
+            return out
+        else:
+            return MatMul4Bit.apply(A, B, out, bias, quant_state)
+    elif A.numel() == A.shape[-1] and A.requires_grad == False:
         if A.shape[-1] % quant_state.blocksize != 0:
             warn(
                 f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",

bitsandbytes/backends/cpu_xpu_common.py

Lines changed: 42 additions & 28 deletions
@@ -15,6 +15,7 @@
 
     ipex_cpu = ipex if ipex._C._has_cpu() else None
     ipex_xpu = ipex if ipex._C._has_xpu() else None
+    ipex_cpu_only = ipex._C._has_cpu() and (not ipex._C._has_xpu())
 except BaseException:
     ipex_cpu = None
     ipex_xpu = None
@@ -55,7 +56,7 @@ def _ipex_xpu_version_prereq(major, minor):
 
 def _maybe_torch_compile(func):
     # torch.compile requires g++ and pytorch >= 2.0
-    if gxx_available and _torch_version_prereq(2, 0):
+    if gxx_available and _torch_version_prereq(2, 0) and not ipex_xpu:
         options = {}
         # fx_graph_cache requires pytorch >= 2.2
         if _torch_version_prereq(2, 2):
@@ -181,7 +182,7 @@ def igemmlt_impl(A, B, SA=None, SB=None, out=None, Sout=None, dtype=torch.int32)
     A_reshaped = A.reshape(m, k)
 
     # torch._int_mm is available on CPU since torch 2.4
-    if _torch_version_prereq(2, 4):
+    if _torch_version_prereq(2, 4) and A.device.type == "cpu":
         C = torch._int_mm(A_reshaped, B.T).to(dtype)
     else:
         C = torch.matmul(A_reshaped.float(), B.t().float()).to(dtype)
@@ -233,8 +234,10 @@ def mm_dequant_impl(
         out_shape = (out_shape[0] * out_shape[1], out_shape[2])
 
     if compute_dtype not in [torch.float32, torch.bfloat16]:
-        warnings.warn(f"mm_dequant_{A.device}: compute_dtype {compute_dtype} is not supported, will use float instead")
-        compute_dtype = torch.float32
+        warnings.warn(
+            f"mm_dequant_{A.device}: compute_dtype {compute_dtype} is not supported, will use bfloat16 instead"
+        )
+        compute_dtype = torch.bfloat16
     A_reshaped = A.reshape(out_shape).to(compute_dtype)
     row_stats = row_stats.reshape(-1).unsqueeze(-1).to(compute_dtype)
     col_stats = col_stats.reshape(-1).unsqueeze(0).to(compute_dtype)
@@ -342,7 +345,7 @@ def quantize_4bit_impl(
         scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
         scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)
     # map [-1, 1] to nf4/fp4
-    out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8)
+    out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8, device=A.device)
     if quant_type == "nf4":
         for i in range(len(NF4_QUANT_TABLE)):
             out_uint8[scaled_A > NF4_QUANT_TABLE[i]] = i
@@ -408,7 +411,6 @@ def dequantize_4bit_impl(
     torch.Tensor:
         Dequantized tensor.
     """
-
     if A.shape[0] == 1:
         transpose = False
         A = A.squeeze(0)
@@ -438,23 +440,18 @@ def dequantize_4bit_impl(
     if quant_state.nested:
         raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU")
 
-    if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and hasattr(quant_state, "op_context"):
-        assert quant_state.op_context is not None
-        A = quant_state.op_context.to_public(quant_state.op_context.get_weight())
-        A = A.reshape(-1)
-        absmax = quant_state.op_context.get_scales().reshape(-1)
-
-    if out is None:
-        out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)
+    if ipex_cpu_only and _ipex_cpu_version_prereq(2, 5) and getattr(quant_state, "ipex", False):
+        A = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", quant_state.shape, 2)
+        quant_state.ipex = False
 
-    n = out.numel()
     # Map nf4 to [-1, 1]
-    out_uint8 = torch.empty(A.size(0) * 2, dtype=torch.uint8, device=A.device)
-    out_uint8[::2] = A.bitwise_and(0xF)
-    out_uint8[1::2] = A.bitwise_right_shift(4)
-    out_dq = torch.empty(out_uint8.shape).to(quant_state.dtype)
-    for i in range(len(quant_state.code)):
-        out_dq[out_uint8 == i] = quant_state.code[i]
+    out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device)
+    n = out_dq.numel()
+    out_dq[::2] = A & 0xF
+    out_dq[1::2] = A >> 4
+    # quant_state.code is fp32, cast to quant_state dtype to avoid the mismatch issue
+    quant_state.code = quant_state.code.to(quant_state.dtype)
+    out_dq = quant_state.code[out_dq]
 
     # Apply scales
     if out_dq.numel() != n:
@@ -464,12 +461,17 @@ def dequantize_4bit_impl(
     blocks += 1 if n % blocksize > 0 else 0
     rem = n % blocksize
     has_rem = rem > 0
-    out_reshaped = out.reshape(-1)
-    out_reshaped[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(
-        -1
-    )
+
     if has_rem:
+        if out is None:
+            out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)
+        out_reshaped = out.reshape(-1)
+        out_reshaped[: n - rem] = (
+            out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)
+        ).reshape(-1)
         out_reshaped[n - rem :] = out_dq[n - rem :] * absmax[-1]
+    else:
+        out = (out_dq.view(-1, blocksize) * absmax.view(-1, 1)).reshape(quant_state.shape).to(quant_state.dtype)
 
     # take transpose here because weight is transposed (again) for computation
     if transpose:
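
Review note: the rewritten `dequantize_4bit_impl` path unpacks two 4-bit indices per byte (low nibble first), looks each index up in `quant_state.code`, and scales every block of `blocksize` values by its `absmax` entry. The sketch below mirrors those three steps in pure Python so the data flow can be checked by hand. It is a simplified illustration, not the library code: `dequantize_4bit_sketch` is a hypothetical name, and the code table in the usage example is a toy table rather than the real NF4 table.

```python
from typing import List, Sequence


def dequantize_4bit_sketch(
    packed: Sequence[int],
    code: Sequence[float],
    absmax: Sequence[float],
    blocksize: int,
) -> List[float]:
    """Unpack 4-bit indices, map them through `code`, then apply per-block scales."""
    indices: List[int] = []
    for byte in packed:
        indices.append(byte & 0xF)  # even positions take the low nibble
        indices.append(byte >> 4)   # odd positions take the high nibble

    values = [code[i] for i in indices]  # table lookup, like code[out_dq]

    # Each run of `blocksize` values shares one scale. A shorter final block
    # simply uses the last absmax entry, which covers the remainder case.
    return [v * absmax[pos // blocksize] for pos, v in enumerate(values)]


# Toy example: 16-entry integer table (NOT the NF4 table), two blocks of two values.
toy_code = [float(i - 8) for i in range(16)]
result = dequantize_4bit_sketch([0x10, 0xF2], toy_code, absmax=[0.5, 2.0], blocksize=2)
```

Here the bytes `0x10` and `0xF2` unpack to the indices 0, 1, 2, 15, which the toy table maps to -8, -7, -6, 7 before the block scales 0.5 and 2.0 are applied.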
@@ -510,9 +512,21 @@ def gemm_4bit_impl(
     torch.Tensor:
         GEMM output tensor.
     """
-    if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and hasattr(state, "op_context"):
-        assert state.op_context is not None
-        output = torch.ops.torch_ipex.ipex_woq_linear(A, state.op_context.get_data_handle())
+    if getattr(state, "ipex", False):
+        output = torch.ops.torch_ipex.woq_linear(
+            A,
+            B,
+            "nf4",
+            state.shape,
+            state.new_scales,
+            state.new_zeros,
+            None,
+            None,
+            state.blocksize,
+            ipex_cpu.quantization.WoqLowpMode.BF16,
+            1,
+            state.compensation,
+        )
     else:
         dqB = dequantize_4bit_impl(B, state, blocksize=state.blocksize).t()
         output = torch.matmul(A, dqB.to(A.dtype))

bitsandbytes/backends/xpu.py

Lines changed: 87 additions & 8 deletions
@@ -5,9 +5,36 @@
 from bitsandbytes.utils import QuantState
 
 from .base import Backend
+from .cpu_xpu_common import (
+    dequantize_4bit_impl,
+    double_quant_impl,
+    gemm_4bit_impl,
+    igemmlt_impl,
+    mm_dequant_impl,
+    quantize_4bit_impl,
+)
+
+Tensor = torch.Tensor
+
+
+def assert_on_xpu(tensors):
+    on_xpu = True
+    for t in tensors:
+        if t is None:
+            continue  # NULL pointers are fine
+        on_xpu &= t.device.type == "xpu"
+    if not on_xpu:
+        raise TypeError(
+            "All input tensors need to be on XPU, but found some tensors to not be on XPU:\n"
+            f" {[(t.shape, t.device) if isinstance(t, Tensor) else None for t in tensors]}"
+        )
+    return on_xpu
 
 
 class XPUBackend(Backend):
+    mm_dequant_compute_dtype = torch.bfloat16
+    mm_dequant_output_dtype = torch.bfloat16
+
     def double_quant(
         self,
         A: torch.Tensor,
@@ -17,7 +44,9 @@ def double_quant(
         out_row: Optional[torch.Tensor] = None,
         threshold=0.0,
     ):
-        raise NotImplementedError
+        assert_on_xpu([A, col_stats, row_stats, out_col, out_row])
+        output = double_quant_impl(A, col_stats, row_stats, out_col, out_row, threshold)
+        return output
 
     def transform(
         self,
@@ -29,7 +58,23 @@ def transform(
         state: Optional[Tuple[torch.Size, str]] = None,
         ld=None,
     ):
-        raise NotImplementedError
+        """
+        Transform tensor A to to_order. It is originally designed for CUDA.
+        For XPU, it returns the original tensor if transpose=False.
+        Otherwise, it returns the transpose of A
+        """
+        assert_on_xpu([A, out])
+        if transpose:
+            if out is not None:
+                out.copy_(A.T)
+            else:
+                out = A.T
+        else:
+            if out is not None:
+                out.copy_(A)
+            else:
+                out = A
+        return out, state
 
     def igemmlt(
         self,
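
Review note: on XPU, `transform` reduces to "return `A` or its transpose, copying into `out` when a buffer is supplied". The sketch below restates that contract with nested lists so it runs without PyTorch. `transform_sketch` is a hypothetical name, and the slice assignment stands in for `out.copy_(...)`.

```python
from typing import List, Optional

Matrix = List[List[int]]


def transform_sketch(a: Matrix, out: Optional[Matrix] = None, transpose: bool = False) -> Matrix:
    """Return `a` (or its transpose); fill `out` in place when it is provided."""
    src = [list(col) for col in zip(*a)] if transpose else a
    if out is not None:
        out[:] = [list(row) for row in src]  # in-place copy into the caller's buffer
        return out
    return src


a = [[1, 2], [3, 4]]
buf: Matrix = []
same = transform_sketch(a)
flipped = transform_sketch(a, transpose=True)
filled = transform_sketch(a, out=buf, transpose=True)
```

Without `out` and without `transpose`, the input object itself is returned, matching the `out = A` branch in the diff.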
@@ -41,7 +86,9 @@ def igemmlt(
         Sout: Optional[Tuple[torch.Size, str]] = None,
         dtype=torch.int32,
     ) -> Union[torch.Tensor, Tuple[Optional[Tuple[torch.Tensor, Tuple[torch.Size, str]]]]]:
-        raise NotImplementedError
+        assert_on_xpu([A, B])
+        output = igemmlt_impl(A, B, SA, SB, out, Sout, dtype)
+        return output
 
     def mm_dequant(
         self,
@@ -54,15 +101,30 @@ def mm_dequant(
         new_col_stats: Optional[torch.Tensor] = None,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        raise NotImplementedError
+        assert_on_xpu([A, row_stats, col_stats, out, bias])
+        output = mm_dequant_impl(
+            A,
+            quant_state,
+            row_stats,
+            col_stats,
+            out,
+            new_row_stats,
+            new_col_stats,
+            bias,
+            self.mm_dequant_compute_dtype,
+            self.mm_dequant_output_dtype,
+        )
+        return output
 
     def extract_outliers(
         self,
         A: torch.Tensor,
         SA: Tuple[torch.Size, str],
         idx: torch.Tensor,
     ) -> torch.Tensor:
-        raise NotImplementedError
+        assert_on_xpu([A])
+        output = A[:, idx].contiguous()
+        return output
 
     def quantize_4bit(
         self,
@@ -74,7 +136,12 @@ def quantize_4bit(
         quant_type: Literal["fp4", "nf4"] = "fp4",
         quant_storage=torch.uint8,
     ) -> Tuple[torch.Tensor, QuantState]:
-        raise NotImplementedError
+        if blocksize is None:
+            blocksize = 64
+        assert_on_xpu([A, absmax, out])
+        assert quant_storage == torch.uint8, "XPU backend only supports uint8 quant_storage"
+        output = quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type)
+        return output
 
     def dequantize_4bit(
         self,
@@ -85,7 +152,15 @@ def dequantize_4bit(
         blocksize: int = 64,
         quant_type: Literal["fp4", "nf4"] = "fp4",
     ) -> torch.Tensor:
-        raise NotImplementedError
+        if blocksize is None:
+            blocksize = 64
+        assert_on_xpu([A, absmax, out])
+        if quant_type == "nf4":
+            output = torch.ops.torch_ipex.dequantize_4bit(A, "nf4", quant_state.shape, absmax, None, blocksize).t()
+        else:
+            output = dequantize_4bit_impl(A, quant_state, absmax, out, blocksize, quant_type)
+
+        return output
 
     def gemv_4bit(
         self,
@@ -96,7 +171,11 @@ def gemv_4bit(
         transposed_B=False,
         state: QuantState = None,
     ) -> torch.Tensor:
-        raise NotImplementedError
+        assert_on_xpu([A, B, out])
+        if state is None:
+            raise ValueError("state cannot be None. gemv_4bit() requires the state from quantize_4bit()")
+        output = gemm_4bit_impl(A, B, out, transposed_A, transposed_B, state)
+        return output
 
     def dequantize_blockwise(
         self,
