Commit 1ccd3d7

supports hpu backend in main branch

1 parent aaa71d7 commit 1ccd3d7

File tree

6 files changed, +82 -3 lines changed

bitsandbytes/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -26,7 +26,7 @@
     "cpu",
     "cuda",  # NVIDIA/AMD GPU
     "xpu",  # Intel GPU
-    "hpu",  # Gaudi
+    "hpu",  # Intel Gaudi
     "npu",  # Ascend NPU
     "mps",  # Apple Silicon
 }
@@ -37,6 +37,9 @@
 if torch.xpu.is_available():
     from .backends.xpu import ops as xpu_ops
 
+if hasattr(torch, "hpu") and torch.hpu.is_available():
+    from .backends.hpu import ops as hpu_ops
+
 
 def _import_backends():
     """

bitsandbytes/autograd/_functions.py

Lines changed: 1 addition & 1 deletion
@@ -451,7 +451,7 @@ def matmul_4bit(
         else:
             return MatMul4Bit.apply(A, B, out, bias, quant_state)
 
-    if A.numel() == A.shape[-1] and A.requires_grad == False:
+    if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu":
         if A.shape[-1] % quant_state.blocksize != 0:
             warn(
                 f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",

bitsandbytes/backends/hpu/__init__.py

Whitespace-only changes.

bitsandbytes/backends/hpu/ops.py

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+from collections.abc import Sequence
+import math
+
+import torch
+
+from bitsandbytes.utils import _reverse_4bit_compress_format
+
+from ..._ops import register_kernel
+from ..utils import GAUDI_SW_VER
+
+
+@register_kernel("bitsandbytes::dequantize_4bit", "hpu")
+def _(
+    A: torch.Tensor,
+    absmax: torch.Tensor,
+    blocksize: int,
+    quant_type: str,
+    shape: Sequence[int],
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    torch._check_is_size(blocksize)
+    torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4, got {quant_type}")
+    torch._check(
+        dtype in (torch.bfloat16, torch.float32), lambda: f"4bit dequantization only supports bf16/fp32, but got {dtype}"
+    )
+    torch._check(A.dtype in [torch.bfloat16, torch.uint8], lambda: f"quant_storage supports uint8/bf16, but got {A.dtype}")
+
+    # Reinterpret non-uint8 quant_storage as raw bytes
+    if A.dtype != torch.uint8:
+        A = A.view(torch.uint8)
+
+    transpose = False if len(A.shape) == 2 and A.shape[0] == 1 else True
+
+    A = A.reshape(-1)
+
+    if GAUDI_SW_VER and (GAUDI_SW_VER.major < 1 or GAUDI_SW_VER.minor < 22):
+        A = _reverse_4bit_compress_format(A)
+
+    # HPU dequantization function for NF4 quantized tensors.
+    out_dq = torch.ops.hpu.dequantize_nf4(
+        A,
+        absmax.to(dtype),
+        blocksize,
+        out_shape=(math.prod(shape),),
+        out_dtype=dtype,
+    )
+
+    output = out_dq.reshape(shape)
+
+    if transpose:
+        output = output.t()
+
+    return output
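
A sketch of how this kernel is reached through the public API. It assumes a Gaudi device with `habana_frameworks` loaded and quantizes on CPU first, since this commit registers only the dequantize side for `hpu`; the shapes and blocksize are illustrative:

    import torch
    import bitsandbytes.functional as F

    # Quantize on CPU, then move the packed weight and its state to the HPU.
    w = torch.randn(4096, 4096, dtype=torch.bfloat16)
    packed, state = F.quantize_4bit(w, blocksize=64, quant_type="nf4")
    packed = packed.to("hpu")
    state.to("hpu")  # moves absmax (and nested stats) to the device

    # Dispatches to the hpu kernel registered above via register_kernel.
    w_dq = F.dequantize_4bit(packed, state)
    assert w_dq.shape == (4096, 4096)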

bitsandbytes/backends/utils.py

Lines changed: 23 additions & 0 deletions
@@ -1,3 +1,6 @@
+import subprocess
+
+from packaging import version
 import torch
 
 try:
@@ -55,3 +58,23 @@
     device="xpu" if torch.xpu.is_available() else "cpu",  # Only cpu/xpu use this table for now.
 )
 CODE = {"nf4": _NF4_QUANT_TABLE, "fp4": _FP4_QUANT_TABLE}
+
+
+def get_gaudi_sw_version():
+    """
+    Returns the installed version of the Gaudi SW stack (habana-torch-plugin).
+    """
+    output = subprocess.run(
+        "pip list | grep habana-torch-plugin",
+        shell=True,
+        text=True,
+        capture_output=True,
+    )
+    # If grep returned nothing, the plugin is not installed.
+    if not output.stdout.strip():
+        return None
+
+    return version.parse(output.stdout.split("\n")[0].split()[-1])
+
+
+GAUDI_SW_VER = get_gaudi_sw_version()
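
Shelling out to `pip list | grep` ties the lookup to pip and a Unix shell; for comparison, a sketch of the same contract (a parsed version, or None when absent) using only the standard library:

    from importlib import metadata

    from packaging import version

    def get_gaudi_sw_version_stdlib():
        # Same contract as get_gaudi_sw_version: a parsed Version, or None.
        try:
            return version.parse(metadata.version("habana-torch-plugin"))
        except metadata.PackageNotFoundError:
            return None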

bitsandbytes/nn/modules.py

Lines changed: 1 addition & 1 deletion
@@ -442,7 +442,7 @@ def __init__(
         )
         # self.persistent_buffers = [] # TODO consider as way to save quant state
         self.compute_dtype = compute_dtype
-        self.compute_type_is_set = False
+        self.compute_type_is_set = False if compute_dtype is None else True
         self.quant_state = None
         self.quant_storage = quant_storage
         self.ipex_linear_is_set = False
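
The effect of the one-liner: an explicitly passed `compute_dtype` now counts as user-chosen, so the lazy heuristic that picks a compute type on the first forward pass will not overwrite it. Illustrative construction (sizes arbitrary):

    import torch
    from bitsandbytes.nn import Linear4bit

    layer = Linear4bit(4096, 4096, compute_dtype=torch.bfloat16)
    print(layer.compute_type_is_set)  # True with this commit; False before it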
