
Commit b090d85

ckvermaAI and rsshaik1 authored

HPU support for bitsandbytes (#1592)
Authored by: Chetan Kumar Verma <[email protected]>
Co-authored-by: Ruheena Suhani Shaik <[email protected]>
Co-authored-by: Bhargav Eede <[email protected]>
Co-authored-by: Vivek Goel <[email protected]>
1 parent 5c48b33 commit b090d85

File tree

3 files changed: +323 -1 lines changed

bitsandbytes/__init__.py

Lines changed: 7 additions & 0 deletions
@@ -28,11 +28,18 @@
     "npu",  # Ascend NPU
     "xpu",  # Intel GPU
     "cpu",
+    "hpu",  # Intel Gaudi
 }

 # Always register the CPU backend.
 register_backend("cpu", CPUBackend())

+# Register HPU Backend, if available
+if hasattr(torch, "hpu") and torch.hpu.is_available():
+    from .backends.hpu import HPUBackend
+
+    register_backend("hpu", HPUBackend())
+
 # Register either CUDA or ROCm backend, if available.
 # Only one of these backends can be used at a time, since the torch.device semantics are
 # the same for both torch+rocm and torch+cuda (e.g. device name is "cuda")
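
The guard above means the HPU backend is registered only when the running torch build exposes an available `torch.hpu` device. As a rough illustration (not part of this commit), the check can be exercised from user code before relying on the HPU path; the `habana_frameworks` import is an assumption about the Gaudi PyTorch bridge, which is what normally makes the `hpu` device visible to torch:

# Illustrative only (assumes the Habana/Gaudi PyTorch bridge is installed;
# `habana_frameworks.torch.core` is not part of this diff).
import habana_frameworks.torch.core  # noqa: F401  # exposes the "hpu" device to torch
import torch

import bitsandbytes as bnb  # importing bitsandbytes registers HPUBackend when torch.hpu.is_available()

print("torch.hpu present:", hasattr(torch, "hpu"))
print("HPU available:", hasattr(torch, "hpu") and torch.hpu.is_available())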

bitsandbytes/backends/hpu.py

Lines changed: 315 additions & 0 deletions
@@ -0,0 +1,315 @@
import math
from typing import Literal, Optional, Tuple
import warnings
import torch

from bitsandbytes.utils import QuantState

from .base import Backend
from .cpu_xpu_common import (
    double_quant_impl,
    dequant_8bit,
    NF4_QUANT_TABLE,
    INT8_QUANT_TABLE,
)
from bitsandbytes.functional import (
    QuantState,
    get_4bit_type,
)

Tensor = torch.Tensor

def assert_on_hpu(tensors):
    on_hpu = True
    for t in tensors:
        if t is None:
            continue  # NULL pointers are fine
        on_hpu &= t.device.type == "hpu"
    if not on_hpu:
        raise TypeError(
            "All input tensors need to be on HPU, but found some tensors to not be on HPU:\n"
            f" {[(t.shape, t.device) if isinstance(t, Tensor) else None for t in tensors]}"
        )
    return on_hpu

class HPUBackend(Backend):

    def int8_double_quant(
        self,
        A: torch.Tensor,
        col_stats: Optional[torch.Tensor] = None,
        row_stats: Optional[torch.Tensor] = None,
        out_col: Optional[torch.Tensor] = None,
        out_row: Optional[torch.Tensor] = None,
        threshold=0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        assert_on_hpu([A, col_stats, row_stats, out_col, out_row])
        return double_quant_impl(A, col_stats, row_stats, out_col, out_row, threshold)

    def transform(
        self,
        A: torch.Tensor,
        to_order: str,
        from_order="row",
        out: Optional[torch.Tensor] = None,
        transpose=False,
        state: Optional[Tuple[torch.Size, str]] = None,
        ld=None,
    ):
        raise NotImplementedError("Not yet implemented for HPU backend")

    def int8_linear_matmul(
        self,
        A: torch.Tensor,
        B: torch.Tensor,
        out: Optional[torch.Tensor] = None,
        dtype=torch.int32,
    ) -> torch.Tensor:
        raise NotImplementedError("Not yet implemented for HPU backend")

    def int8_mm_dequant(
        self,
        A: torch.Tensor,
        row_stats: torch.Tensor,
        col_stats: torch.Tensor,
        out: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        raise NotImplementedError("Not yet implemented for HPU backend")

    def extract_outliers(
        self,
        A: torch.Tensor,
        SA: Tuple[torch.Size, str],
        idx: torch.Tensor,
    ) -> torch.Tensor:
        raise NotImplementedError("Not yet implemented for HPU backend")

    def quantize_4bit(
        self,
        A: torch.Tensor,
        absmax: Optional[torch.Tensor] = None,
        out: Optional[torch.Tensor] = None,
        blocksize=64,
        compress_statistics=False,
        quant_type: Literal["nf4"] = "nf4",
        quant_storage=torch.uint8,
    ) -> Tuple[torch.Tensor, QuantState]:
        if blocksize is None:
            blocksize = 64
        assert_on_hpu([A, absmax, out])
        assert quant_storage == torch.uint8, "HPU backend only supports uint8 quant_storage"
        return self.quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type)

    def quantize_4bit_impl(
        self,
        A: Tensor,
        absmax: Tensor = None,
        out: Tensor = None,
        blocksize=64,
        compress_statistics=False,
        quant_type="nf4",
    ) -> Tensor:
        if quant_type not in ["nf4", "int8"]:
            raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented for HPU.")
        assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
        n = A.numel()
        input_shape = A.shape
        blocks = n // blocksize
        blocks += 1 if n % blocksize > 0 else 0

        if absmax is None:
            absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype)

        if out is None:
            out = torch.zeros(((n + 1) // 2), dtype=torch.uint8, device=A.device)

        rem = n % blocksize
        has_rem = rem > 0

        # Scale tensor to [-1, 1]
        A_reshaped = A.reshape(n)
        A_com = A_reshaped[: n - rem]
        A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
        absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
        scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1)
        scaled_A = scaled_A.reshape(-1)
        if has_rem:
            absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
            scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
            scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)
        # map [-1, 1] to nf4
        out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8, device=A.device)
        if quant_type == "nf4":
            for i in range(len(NF4_QUANT_TABLE)):
                out_uint8[scaled_A > NF4_QUANT_TABLE[i]] = i
        elif quant_type == "int8":
            map = torch.tensor(INT8_QUANT_TABLE, device=scaled_A.device)
            diff = torch.abs(scaled_A.unsqueeze(-1) - map)
            out_uint8 = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device)

        if quant_type == "int8":
            out = out_uint8
            code = torch.Tensor(INT8_QUANT_TABLE).to(A.device)
        else:
            if out_uint8.size(-1) % 2:
                out_uint8 = torch.nn.functional.pad(out_uint8, (0, 1), value=0)
            # To align with HPU dequantize operator
            out[:] = out_uint8[1::2].bitwise_left_shift(4).bitwise_or_(out_uint8[::2])
            code = get_4bit_type(quant_type, device=A.device)

        if compress_statistics:
            raise AssertionError("Double quantization is not supported for HPU backend")
            offset = absmax.mean()
            absmax -= offset
            qabsmax, state2 = self.hpu_quantize_4bit_impl(absmax, blocksize=256, quant_type="int8")
            del absmax
            state = QuantState(
                absmax=qabsmax,
                shape=input_shape,
                dtype=A.dtype,
                blocksize=blocksize,
                code=code,
                quant_type=quant_type,
                offset=offset,
                state2=state2,
            )
        else:
            state = QuantState(
                absmax=absmax,
                shape=input_shape,
                dtype=A.dtype,
                blocksize=blocksize,
                code=code,
                quant_type=quant_type,
            )
        return out, state

    def dequantize_nf4_impl(
        self,
        input: torch.Tensor,
        absmax: torch.Tensor,
        blocksize: int,
        quant_state: QuantState,
    ) -> torch.Tensor:
        """
        HPU dequantization function for NF4 quantized tensors.
        """
        assert_on_hpu([input, absmax])
        out_shape = (math.prod(quant_state.shape), )
        out_dq = torch.ops.hpu.dequantize_nf4(input, absmax, blocksize,
                                              out_shape=out_shape,
                                              out_dtype=quant_state.dtype)
        output = out_dq.reshape(quant_state.shape).T
        return output

    def dequantize_4bit(
        self,
        A: torch.Tensor,
        quant_state: Optional[QuantState] = None,
        absmax: Optional[torch.Tensor] = None,
        out: Optional[torch.Tensor] = None,
        blocksize: int = 64,
        quant_type: Literal["nf4"] = "nf4",
    ) -> torch.Tensor:
        if blocksize is None:
            blocksize = 64

        assert_on_hpu([A, absmax, out])
        if quant_state.nested:
            raise AssertionError("Double quantization is not supported for HPU backend")
            absmax = dequant_8bit(absmax, quant_state.offset, quant_state.state2)
        return self.dequantize_nf4_impl(A, absmax, blocksize, quant_state)

    def gemv_4bit(
        self,
        A: torch.Tensor,
        B: torch.Tensor,
        out: Optional[torch.Tensor] = None,
        transposed_A=False,
        transposed_B=False,
        state: QuantState = None,
    ) -> torch.Tensor:
        assert_on_hpu([A, B, out])
        if state is None:
            raise ValueError(
                "state cannot be None. gemv_4bit() requires the state from quantize_4bit()"
            )
        dqB = self.dequantize_nf4_impl(B, state.absmax, state.blocksize, state)
        output = torch.matmul(A, dqB.to(A.dtype))
        if out is not None:
            out.copy_(output)
        else:
            out = output
        return out

    def int8_vectorwise_dequant(self, A: torch.Tensor, stats: torch.Tensor):
        raise NotImplementedError("Not yet implemented for HPU backend")

    def int8_vectorwise_quant(self, A: torch.Tensor, threshold=0.0):
        raise NotImplementedError("Not yet implemented for HPU backend")

    def dequantize_blockwise(
        self,
        A: torch.Tensor,
        quant_state: Optional[QuantState] = None,
        absmax: Optional[torch.Tensor] = None,
        code: Optional[torch.Tensor] = None,
        out: Optional[torch.Tensor] = None,
        blocksize: int = 4096,
        nested=False,
    ) -> torch.Tensor:
        raise NotImplementedError("Not yet implemented for HPU backend")

    def quantize_blockwise(
        self,
        A: torch.Tensor,
        code: Optional[torch.Tensor] = None,
        absmax: Optional[torch.Tensor] = None,
        out: Optional[torch.Tensor] = None,
        blocksize=4096,
        nested=False,
    ) -> Tuple[torch.Tensor, QuantState]:
        raise NotImplementedError("Not yet implemented for HPU backend")

    def optimizer_update_8bit_blockwise(
        self,
        optimizer_name: str,
        g: torch.Tensor,
        p: torch.Tensor,
        state1: torch.Tensor,
        state2: Optional[torch.Tensor],
        beta1: float,
        beta2: float,
        eps: float,
        step: int,
        lr: float,
        qmap1: torch.Tensor,
        qmap2: Optional[torch.Tensor],
        absmax1: torch.Tensor,
        absmax2: Optional[torch.Tensor],
        weight_decay: float = 0.0,
        gnorm_scale: float = 1.0,
        skip_zeros=False,
    ) -> None:
        raise NotImplementedError("Not yet implemented for HPU backend")

    def optimizer_update_32bit(
        self,
        optimizer_name: str,
        g: torch.Tensor,
        p: torch.Tensor,
        state1: torch.Tensor,
        beta1: float,
        eps: float,
        step: int,
        lr: float,
        state2: Optional[torch.Tensor] = None,
        beta2: float = 0.0,
        weight_decay: float = 0.0,
        gnorm_scale: float = 1.0,
        unorm_vec: Optional[torch.Tensor] = None,
        max_unorm: float = 0.0,
        skip_zeros=False,
    ) -> None:
        raise NotImplementedError("Not yet implemented for HPU backend")
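
In this initial version only the NF4 path is wired up: `quantize_4bit` scales each block by its absmax, maps the scaled values onto indices of `NF4_QUANT_TABLE`, and packs two 4-bit codes per `uint8` byte (first element in the low nibble, to align with `torch.ops.hpu.dequantize_nf4`), while double quantization (`compress_statistics`) and the int8/blockwise/optimizer entry points raise. A rough sketch of a quantize/dequantize round trip through `bitsandbytes.functional`, assuming an available Gaudi device and that the functional wrappers dispatch on the tensor's device type; shapes and dtypes are illustrative:

# Sketch only: NF4 round trip on HPU (assumes a Gaudi device is available and
# that bitsandbytes routes "hpu" tensors to HPUBackend).
import torch
import bitsandbytes.functional as F

W = torch.randn(1024, 1024, dtype=torch.bfloat16, device="hpu")

# HPU constraints: quant_type="nf4", uint8 storage, compress_statistics=False.
qW, state = F.quantize_4bit(W, blocksize=64, quant_type="nf4", compress_statistics=False)
print(qW.dtype, qW.numel())  # uint8, two 4-bit codes per byte -> W.numel() // 2 values

# Dequantization goes through the Gaudi kernel torch.ops.hpu.dequantize_nf4,
# driven by the absmax, blocksize, and shape recorded in `state`.
W_dq = F.dequantize_4bit(qW, state)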

bitsandbytes/nn/modules.py

Lines changed: 1 addition & 1 deletion
@@ -345,7 +345,7 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
     def to(self, *args, **kwargs):
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

-        if device is not None and device.type in ["cuda", "cpu", "npu", "xpu"] and not self.bnb_quantized:
+        if device is not None and device.type in ["cuda", "cpu", "npu", "xpu", "hpu"] and not self.bnb_quantized:
             return self._quantize(device)
         else:
             if self.quant_state is not None:
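
With "hpu" added to this allow-list, moving a not-yet-quantized `Params4bit` weight to a Gaudi device now takes the `self._quantize(device)` branch, the same on-the-fly quantization path used for the other backends. A hedged sketch of the module-level flow, assuming an HPU device is available; layer sizes are illustrative and the non-default arguments reflect the HPU backend's current NF4-only, no-double-quantization support:

# Sketch only: a 4-bit linear layer quantized on the first .to("hpu")
# (assumes an available Gaudi device; dimensions are illustrative).
import torch
import bitsandbytes as bnb

layer = bnb.nn.Linear4bit(
    4096, 4096,
    bias=False,
    quant_type="nf4",           # only NF4 is implemented by HPUBackend
    compress_statistics=False,  # double quantization is rejected on HPU
    compute_dtype=torch.bfloat16,
)
layer = layer.to("hpu")  # Params4bit.to() now hits the "hpu" branch and calls _quantize()

x = torch.randn(1, 4096, dtype=torch.bfloat16, device="hpu")
y = layer(x)  # the forward pass dequantizes the weight via torch.ops.hpu.dequantize_nf4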
