Skip to content

Commit 116e1cf

Browse files
authored
Merge pull request #2 from bhargaveede/origin/habana-main/hpu-backend
NF4 quantization and dequantization on HPU
2 parents 074ab44 + 622c811 commit 116e1cf

File tree

3 files changed

+111
-32
lines changed

3 files changed

+111
-32
lines changed

bitsandbytes/backends/cpu_xpu_common.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import subprocess
22
from typing import Optional
33
import warnings
4+
import os
45

56
import torch
67

@@ -55,7 +56,7 @@ def _ipex_xpu_version_prereq(major, minor):
5556

5657
def _maybe_torch_compile(func):
5758
# torch.compile requires g++ and pytorch >= 2.0
58-
if gxx_available and _torch_version_prereq(2, 0):
59+
if gxx_available and _torch_version_prereq(2, 0) and os.getenv('PT_HPU_LAZY_MODE',1)==0:
5960
options = {}
6061
# fx_graph_cache requires pytorch >= 2.2
6162
if _torch_version_prereq(2, 2):
@@ -342,7 +343,7 @@ def quantize_4bit_impl(
342343
scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
343344
scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)
344345
# map [-1, 1] to nf4/fp4
345-
out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8)
346+
out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8, device=A.device)
346347
if quant_type == "nf4":
347348
for i in range(len(NF4_QUANT_TABLE)):
348349
out_uint8[scaled_A > NF4_QUANT_TABLE[i]] = i
@@ -372,7 +373,6 @@ def quantize_4bit_impl(
372373

373374
return out.unsqueeze(0), state
374375

375-
376376
@_maybe_torch_compile
377377
def dequantize_4bit_impl(
378378
A: Tensor,
@@ -452,7 +452,7 @@ def dequantize_4bit_impl(
452452
out_uint8 = torch.empty(A.size(0) * 2, dtype=torch.uint8, device=A.device)
453453
out_uint8[::2] = A.bitwise_and(0xF)
454454
out_uint8[1::2] = A.bitwise_right_shift(4)
455-
out_dq = torch.empty(out_uint8.shape).to(quant_state.dtype)
455+
out_dq = torch.empty(out_uint8.shape, dtype=quant_state.code.dtype, device= quant_state.code.device)
456456
for i in range(len(quant_state.code)):
457457
out_dq[out_uint8 == i] = quant_state.code[i]
458458

bitsandbytes/backends/hpu.py

Lines changed: 106 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .base import Backend
88
from .cpu_xpu_common import (
99
dequantize_4bit_impl,
10+
double_quant_impl,
1011
gemm_4bit_impl,
1112
igemmlt_impl,
1213
mm_dequant_impl,
@@ -16,10 +17,35 @@
1617
Tensor = torch.Tensor
1718

1819

20+
def assert_on_hpu(tensors):
    """Check that every non-``None`` entry of ``tensors`` is on an HPU device.

    Args:
        tensors: Iterable of tensor-like objects exposing ``.device.type``;
            ``None`` entries are skipped.

    Returns:
        ``True`` when every non-``None`` entry reports ``device.type == "hpu"``.

    Raises:
        TypeError: If any non-``None`` entry is on a different device. The
            message lists ``(shape, device)`` for each ``Tensor`` entry and
            ``None`` for anything else.
    """
    # None stands in for an omitted optional argument, so it is not checked.
    # The flags are materialised first so that every entry's device is read,
    # exactly as a plain loop over the whole input would do.
    device_flags = [t.device.type == "hpu" for t in tensors if t is not None]
    on_hpu = all(device_flags)
    if on_hpu:
        return on_hpu
    raise TypeError(
        "All input tensors need to be on HPU, but found some tensors to not be on HPU:\n"
        f" {[(t.shape, t.device) if isinstance(t, Tensor) else None for t in tensors]}"
    )
32+
33+
1934
class HPUBackend(Backend):
2035
mm_dequant_compute_dtype = torch.bfloat16
2136
mm_dequant_output_dtype = torch.bfloat16
2237

38+
def double_quant(
    self,
    A: torch.Tensor,
    col_stats: Optional[torch.Tensor] = None,
    row_stats: Optional[torch.Tensor] = None,
    out_col: Optional[torch.Tensor] = None,
    out_row: Optional[torch.Tensor] = None,
    threshold=0.0,
):
    """Placeholder for the backend ``double_quant`` operation on HPU.

    All arguments are accepted so the signature matches the other backends,
    but none of them is inspected.

    Raises:
        NotImplementedError: Always; the HPU backend does not provide this
            operation yet.
    """
    message = "Not yet implemented for HPU backend"
    raise NotImplementedError(message)
48+
2349
def transform(
2450
self,
2551
A: torch.Tensor,
@@ -32,20 +58,10 @@ def transform(
3258
):
3359
"""
3460
Transform tensor A to to_order. It is originally designed for CUDA.
35-
For HPU, it returns the original tensor if transpose=False.
61+
For HPU, this is not yet implemented and raises NotImplementedError; on CPU, it returns the original tensor if transpose=False.
3662
Otherwise, it returns the transpose of A
3763
"""
38-
if transpose:
39-
if out is not None:
40-
out.copy_(A.T)
41-
else:
42-
out = A.T
43-
else:
44-
if out is not None:
45-
out.copy_(A)
46-
else:
47-
out = A
48-
return out, state
64+
raise NotImplementedError("Not yet implemented for HPU backend")
4965

5066
def igemmlt(
5167
self,
@@ -56,9 +72,8 @@ def igemmlt(
5672
out: Optional[torch.Tensor] = None,
5773
Sout: Optional[Tuple[torch.Size, str]] = None,
5874
dtype=torch.int32,
59-
) -> Union[torch.Tensor, Tuple[Optional[Tuple[torch.Tensor, Tuple[torch.Size,
60-
str]]]]]:
61-
75+
) -> Union[torch.Tensor, Tuple[Optional[Tuple[torch.Tensor, Tuple[torch.Size, str]]]]]:
76+
assert_on_hpu([A, B])
6277
return igemmlt_impl(A, B, SA, SB, out, Sout, dtype)
6378

6479
def mm_dequant(
@@ -72,7 +87,7 @@ def mm_dequant(
7287
new_col_stats: Optional[torch.Tensor] = None,
7388
bias: Optional[torch.Tensor] = None,
7489
) -> torch.Tensor:
75-
90+
assert_on_hpu([A, row_stats, col_stats, out, bias])
7691
return mm_dequant_impl(
7792
A,
7893
quant_state,
@@ -95,7 +110,7 @@ def extract_outliers(
95110
"""
96111
Extract columns of A by idx
97112
"""
98-
113+
assert_on_hpu([A])
99114
return A[:, idx].contiguous()
100115

101116
def quantize_4bit(
@@ -108,12 +123,12 @@ def quantize_4bit(
108123
quant_type: Literal["fp4", "nf4"] = "fp4",
109124
quant_storage=torch.uint8,
110125
) -> Tuple[torch.Tensor, QuantState]:
111-
112126
if blocksize is None:
113127
blocksize = 64
114-
assert quant_storage == torch.uint8
115-
return quantize_4bit_impl(
116-
A, absmax, out, blocksize, compress_statistics, quant_type)
128+
129+
assert_on_hpu([A, absmax, out])
130+
assert quant_storage == torch.uint8, "HPU backend only supports uint8 quant_storage"
131+
return quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type)
117132

118133
def dequantize_4bit(
119134
self,
@@ -124,9 +139,10 @@ def dequantize_4bit(
124139
blocksize: int = 64,
125140
quant_type: Literal["fp4", "nf4"] = "fp4",
126141
) -> torch.Tensor:
127-
128142
if blocksize is None:
129143
blocksize = 64
144+
145+
assert_on_hpu([A, absmax, out])
130146
return dequantize_4bit_impl(A, quant_state, absmax, out, blocksize, quant_type)
131147

132148
def gemv_4bit(
@@ -138,10 +154,73 @@ def gemv_4bit(
138154
transposed_B=False,
139155
state: QuantState = None,
140156
) -> torch.Tensor:
141-
157+
assert_on_hpu([A, B, out])
142158
if state is None:
143-
raise ValueError(
144-
"state cannot be None. gemv_4bit() requires the state from quantize_4bit()"
145-
)
159+
raise ValueError("state cannot be None. gemv_4bit() requires the state from quantize_4bit()")
146160

147-
return gemm_4bit_impl(A, B, out, transposed_A, transposed_B, state)
161+
return gemm_4bit_impl(A, B, out, transposed_A, transposed_B, state)
162+
163+
def dequantize_blockwise(
164+
self,
165+
A: torch.Tensor,
166+
quant_state: Optional[QuantState] = None,
167+
absmax: Optional[torch.Tensor] = None,
168+
code: Optional[torch.Tensor] = None,
169+
out: Optional[torch.Tensor] = None,
170+
blocksize: int = 4096,
171+
nested=False,
172+
) -> torch.Tensor:
173+
raise NotImplementedError("Not yet implemented for HPU backend")
174+
175+
def quantize_blockwise(
176+
self,
177+
A: torch.Tensor,
178+
code: Optional[torch.Tensor] = None,
179+
absmax: Optional[torch.Tensor] = None,
180+
out: Optional[torch.Tensor] = None,
181+
blocksize=4096,
182+
nested=False,
183+
) -> Tuple[torch.Tensor, QuantState]:
184+
raise NotImplementedError("Not yet implemented for HPU backend")
185+
186+
def optimizer_update_8bit_blockwise(
187+
self,
188+
optimizer_name: str,
189+
g: torch.Tensor,
190+
p: torch.Tensor,
191+
state1: torch.Tensor,
192+
state2: Optional[torch.Tensor],
193+
beta1: float,
194+
beta2: float,
195+
eps: float,
196+
step: int,
197+
lr: float,
198+
qmap1: torch.Tensor,
199+
qmap2: Optional[torch.Tensor],
200+
absmax1: torch.Tensor,
201+
absmax2: Optional[torch.Tensor],
202+
weight_decay: float = 0.0,
203+
gnorm_scale: float = 1.0,
204+
skip_zeros=False,
205+
) -> None:
206+
raise NotImplementedError("Not yet implemented for HPU backend")
207+
208+
def optimizer_update_32bit(
209+
self,
210+
optimizer_name: str,
211+
g: torch.Tensor,
212+
p: torch.Tensor,
213+
state1: torch.Tensor,
214+
beta1: float,
215+
eps: float,
216+
step: int,
217+
lr: float,
218+
state2: Optional[torch.Tensor] = None,
219+
beta2: float = 0.0,
220+
weight_decay: float = 0.0,
221+
gnorm_scale: float = 1.0,
222+
unorm_vec: Optional[torch.Tensor] = None,
223+
max_unorm: float = 0.0,
224+
skip_zeros=False,
225+
) -> None:
226+
raise NotImplementedError("Not yet implemented for HPU backend")

bitsandbytes/nn/modules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
331331
def to(self, *args, **kwargs):
332332
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
333333

334-
if device is not None and device.type in ["cuda", "cpu"] and not self.bnb_quantized:
334+
if device is not None and device.type in ["cuda", "cpu", "hpu"] and not self.bnb_quantized:
335335
return self._quantize(device)
336336
else:
337337
if self.quant_state is not None:

0 commit comments

Comments
 (0)