
Commit 797c31a

supports quantization on HPU
1 parent cd73601 commit 797c31a

File tree: 4 files changed, +243 -5 lines changed


bitsandbytes/__init__.py

Lines changed: 12 additions & 0 deletions
@@ -28,11 +28,23 @@
     "cuda",  # includes ROCm
     "xpu",  # Intel GPU
     "cpu",
+    "hpu",
 }

 # Always register the CPU backend.
 register_backend("cpu", CPUBackend())

+# Register HPU Backend, if available
+try:
+    import habana_frameworks.torch
+
+    if hasattr(torch, "hpu") and torch.hpu.is_available():
+        from .backends.hpu import HPUBackend
+
+        register_backend("hpu", HPUBackend())
+except ImportError:
+    print("Unable to register HPU")
+
 # Register either CUDA or ROCm backend, if available.
 # Only one of these backends can be used at a time, since the torch.device semantics are
 # the same for both torch+rocm and torch+cuda (e.g. device name is "cuda")
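A quick way to confirm the effect of this registration guard on a given machine is to mirror it outside the package (a minimal sketch, not part of the commit; it only repeats the checks that the try/except block above performs):

import torch

# Mirrors the guard added above: the HPU backend is only registered when
# habana_frameworks.torch imports cleanly and torch reports an available HPU.
try:
    import habana_frameworks.torch  # noqa: F401

    hpu_ready = hasattr(torch, "hpu") and torch.hpu.is_available()
except ImportError:
    hpu_ready = False

print("HPU backend would be registered:", hpu_ready)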

bitsandbytes/backends/cpu_xpu_common.py

Lines changed: 4 additions & 4 deletions
@@ -277,7 +277,7 @@ def mm_dequant_impl(
     }


-@_maybe_torch_compile
+# @_maybe_torch_compile
 def quantize_4bit_impl(
     A: Tensor,
     absmax: Tensor = None,
@@ -342,7 +342,7 @@ def quantize_4bit_impl(
         scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
         scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)
     # map [-1, 1] to nf4/fp4
-    out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8)
+    out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8, device=scaled_A.device)
     if quant_type == "nf4":
         for i in range(len(NF4_QUANT_TABLE)):
             out_uint8[scaled_A > NF4_QUANT_TABLE[i]] = i
@@ -373,7 +373,7 @@ def quantize_4bit_impl(
     return out.unsqueeze(0), state


-@_maybe_torch_compile
+# @_maybe_torch_compile
 def dequantize_4bit_impl(
     A: Tensor,
     quant_state=None,
@@ -452,7 +452,7 @@ def dequantize_4bit_impl(
     out_uint8 = torch.empty(A.size(0) * 2, dtype=torch.uint8, device=A.device)
     out_uint8[::2] = A.bitwise_and(0xF)
     out_uint8[1::2] = A.bitwise_right_shift(4)
-    out_dq = torch.empty(out_uint8.shape).to(quant_state.dtype)
+    out_dq = torch.empty(out_uint8.shape, dtype=quant_state.code.dtype, device=quant_state.code.device)
     for i in range(len(quant_state.code)):
         out_dq[out_uint8 == i] = quant_state.code[i]
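Both torch.empty changes above follow the same rule: scratch buffers have to live on the same device as the tensors whose values are written into them, otherwise the masked assignments in the quantization loops mix a CPU buffer with HPU data. A standalone illustration of that rule (not taken from the repo; it falls back to CPU when no HPU is visible, so it runs anywhere):

import torch

# Pick the HPU if one is visible, otherwise stay on CPU so the snippet still runs.
device = "hpu" if hasattr(torch, "hpu") and torch.hpu.is_available() else "cpu"
data = torch.randn(8, device=device)

# Without an explicit device, the scratch tensor defaults to the CPU even though
# `data` may live on the HPU -- exactly the mismatch the diff above fixes.
scratch_default = torch.empty(data.shape, dtype=torch.uint8)
scratch_matched = torch.empty(data.shape, dtype=torch.uint8, device=data.device)

print(scratch_default.device, scratch_matched.device)  # cpu vs. data's device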

bitsandbytes/backends/hpu.py

Lines changed: 226 additions & 0 deletions
@@ -0,0 +1,226 @@ (new file)
from typing import Literal, Optional, Tuple, Union

import torch

from bitsandbytes.utils import QuantState

from .base import Backend
from .cpu_xpu_common import (
    dequantize_4bit_impl,
    double_quant_impl,
    gemm_4bit_impl,
    igemmlt_impl,
    mm_dequant_impl,
    quantize_4bit_impl,
)

Tensor = torch.Tensor


def assert_on_hpu(tensors):
    on_hpu = True
    for t in tensors:
        if t is None:
            continue  # NULL pointers are fine
        on_hpu &= t.device.type == "hpu"
    if not on_hpu:
        raise TypeError(
            "All input tensors need to be on HPU, but found some tensors to not be on HPU:\n"
            f" {[(t.shape, t.device) if isinstance(t, Tensor) else None for t in tensors]}"
        )
    return on_hpu


class HPUBackend(Backend):
    mm_dequant_compute_dtype = torch.bfloat16
    mm_dequant_output_dtype = torch.bfloat16

    def double_quant(
        self,
        A: torch.Tensor,
        col_stats: Optional[torch.Tensor] = None,
        row_stats: Optional[torch.Tensor] = None,
        out_col: Optional[torch.Tensor] = None,
        out_row: Optional[torch.Tensor] = None,
        threshold=0.0,
    ):
        raise NotImplementedError("Not yet implemented for HPU backend")

    def transform(
        self,
        A: torch.Tensor,
        to_order: str,
        from_order="row",
        out: Optional[torch.Tensor] = None,
        transpose=False,
        state: Optional[Tuple[torch.Size, str]] = None,
        ld=None,
    ):
        """
        Transform tensor A to to_order. It is originally designed for CUDA.
        For CPU, it returns the original tensor if transpose=False.
        Otherwise, it returns the transpose of A
        """
        raise NotImplementedError("Not yet implemented for HPU backend")

    def igemmlt(
        self,
        A: torch.Tensor,
        B: torch.Tensor,
        SA: Tuple[torch.Size, str],
        SB: Tuple[torch.Size, str],
        out: Optional[torch.Tensor] = None,
        Sout: Optional[Tuple[torch.Size, str]] = None,
        dtype=torch.int32,
    ) -> Union[torch.Tensor, Tuple[Optional[Tuple[torch.Tensor, Tuple[torch.Size, str]]]]]:
        assert_on_hpu([A, B])
        return igemmlt_impl(A, B, SA, SB, out, Sout, dtype)

    def mm_dequant(
        self,
        A: torch.Tensor,
        quant_state: Tuple[torch.Size, str],
        row_stats: torch.Tensor,
        col_stats: torch.Tensor,
        out: Optional[torch.Tensor] = None,
        new_row_stats: Optional[torch.Tensor] = None,
        new_col_stats: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        assert_on_hpu([A, row_stats, col_stats, out, bias])
        return mm_dequant_impl(
            A,
            quant_state,
            row_stats,
            col_stats,
            out,
            new_row_stats,
            new_col_stats,
            bias,
            self.mm_dequant_compute_dtype,
            self.mm_dequant_output_dtype,
        )

    def extract_outliers(
        self,
        A: torch.Tensor,
        SA: Tuple[torch.Size, str],
        idx: torch.Tensor,
    ) -> torch.Tensor:
        """
        Extract columns of A by idx
        """
        assert_on_hpu([A])
        return A[:, idx].contiguous()

    def quantize_4bit(
        self,
        A: torch.Tensor,
        absmax: Optional[torch.Tensor] = None,
        out: Optional[torch.Tensor] = None,
        blocksize=64,
        compress_statistics=False,
        quant_type: Literal["fp4", "nf4"] = "fp4",
        quant_storage=torch.uint8,
    ) -> Tuple[torch.Tensor, QuantState]:
        if blocksize is None:
            blocksize = 64

        assert_on_hpu([A, absmax, out])
        assert quant_storage == torch.uint8, "HPU backend only supports uint8 quant_storage"
        return quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type)

    def dequantize_4bit(
        self,
        A: torch.Tensor,
        quant_state: Optional[QuantState] = None,
        absmax: Optional[torch.Tensor] = None,
        out: Optional[torch.Tensor] = None,
        blocksize: int = 64,
        quant_type: Literal["fp4", "nf4"] = "fp4",
    ) -> torch.Tensor:
        if blocksize is None:
            blocksize = 64

        assert_on_hpu([A, absmax, out])
        return dequantize_4bit_impl(A, quant_state, absmax, out, blocksize, quant_type)

    def gemv_4bit(
        self,
        A: torch.Tensor,
        B: torch.Tensor,
        out: Optional[torch.Tensor] = None,
        transposed_A=False,
        transposed_B=False,
        state: QuantState = None,
    ) -> torch.Tensor:
        assert_on_hpu([A, B, out])
        if state is None:
            raise ValueError("state cannot be None. gemv_4bit() requires the state from quantize_4bit()")

        return gemm_4bit_impl(A, B, out, transposed_A, transposed_B, state)

    def dequantize_blockwise(
        self,
        A: torch.Tensor,
        quant_state: Optional[QuantState] = None,
        absmax: Optional[torch.Tensor] = None,
        code: Optional[torch.Tensor] = None,
        out: Optional[torch.Tensor] = None,
        blocksize: int = 4096,
        nested=False,
    ) -> torch.Tensor:
        raise NotImplementedError("Not yet implemented for HPU backend")

    def quantize_blockwise(
        self,
        A: torch.Tensor,
        code: Optional[torch.Tensor] = None,
        absmax: Optional[torch.Tensor] = None,
        out: Optional[torch.Tensor] = None,
        blocksize=4096,
        nested=False,
    ) -> Tuple[torch.Tensor, QuantState]:
        raise NotImplementedError("Not yet implemented for HPU backend")

    def optimizer_update_8bit_blockwise(
        self,
        optimizer_name: str,
        g: torch.Tensor,
        p: torch.Tensor,
        state1: torch.Tensor,
        state2: Optional[torch.Tensor],
        beta1: float,
        beta2: float,
        eps: float,
        step: int,
        lr: float,
        qmap1: torch.Tensor,
        qmap2: Optional[torch.Tensor],
        absmax1: torch.Tensor,
        absmax2: Optional[torch.Tensor],
        weight_decay: float = 0.0,
        gnorm_scale: float = 1.0,
        skip_zeros=False,
    ) -> None:
        raise NotImplementedError("Not yet implemented for HPU backend")

    def optimizer_update_32bit(
        self,
        optimizer_name: str,
        g: torch.Tensor,
        p: torch.Tensor,
        state1: torch.Tensor,
        beta1: float,
        eps: float,
        step: int,
        lr: float,
        state2: Optional[torch.Tensor] = None,
        beta2: float = 0.0,
        weight_decay: float = 0.0,
        gnorm_scale: float = 1.0,
        unorm_vec: Optional[torch.Tensor] = None,
        max_unorm: float = 0.0,
        skip_zeros=False,
    ) -> None:
        raise NotImplementedError("Not yet implemented for HPU backend")
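A short usage sketch for the new backend class (assumes a Gaudi machine with habana_frameworks.torch installed; NF4 with the default block size of 64 is used, matching the quantize_4bit_impl path the backend delegates to):

import torch
import habana_frameworks.torch  # noqa: F401  # exposes the "hpu" device to torch

from bitsandbytes.backends.hpu import HPUBackend

backend = HPUBackend()
weight = torch.randn(1024, 1024, dtype=torch.bfloat16, device="hpu")

# 4-bit NF4 round trip; assert_on_hpu() inside the backend rejects any tensor
# that is not already on the HPU device.
packed, quant_state = backend.quantize_4bit(weight, blocksize=64, quant_type="nf4")
restored = backend.dequantize_4bit(packed, quant_state=quant_state, quant_type="nf4")

print(packed.dtype, packed.device)     # expected: torch.uint8 on the HPU device
print(restored.shape, restored.dtype)  # expected: same shape and dtype as `weight`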

bitsandbytes/nn/modules.py

Lines changed: 1 addition & 1 deletion
@@ -331,7 +331,7 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
     def to(self, *args, **kwargs):
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

-        if device is not None and device.type in ["cuda", "cpu"] and not self.bnb_quantized:
+        if device is not None and device.type in ["cuda", "cpu", "hpu"] and not self.bnb_quantized:
             return self._quantize(device)
         else:
             if self.quant_state is not None:
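With "hpu" added to that device list, moving a 4-bit module to the Gaudi device takes the same quantize-on-transfer path as moving it to CUDA or CPU. A hedged end-to-end sketch (assumes this branch of bitsandbytes plus habana_frameworks.torch are installed):

import torch
import habana_frameworks.torch  # noqa: F401

import bitsandbytes as bnb

# Right after construction the weights are still held in full precision.
layer = bnb.nn.Linear4bit(4096, 4096, compute_dtype=torch.bfloat16, quant_type="nf4")

# .to("hpu") now matches the `device.type in ["cuda", "cpu", "hpu"]` branch above,
# so Params4bit._quantize() runs as part of the move.
layer = layer.to("hpu")
print(layer.weight.bnb_quantized)  # expected: True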
