Skip to content

Commit 34dd672

Browse files
committed
Add npu support for LLM.int8 forward
1 parent 89373b8 commit 34dd672

File tree

6 files changed

+1868
-22
lines changed

6 files changed

+1868
-22
lines changed

bitsandbytes/backends/npu.py

Lines changed: 149 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from bitsandbytes.cextension import lib
1313
from bitsandbytes.functional import (
14+
COOSparseTensor,
1415
get_4bit_type,
1516
get_ptr,
1617
)
@@ -28,6 +29,43 @@ def assert_on_npu(tensors):
2829
return True
2930

3031

32+
def coo_zeros(rows, cols, rowidx, colidx, values, nnz, device, dtype=torch.half):
    """Wrap pre-computed COO components into a ``COOSparseTensor``.

    Despite the name, nothing is zero-filled here: the supplied index and
    value tensors are converted and handed straight to the constructor.

    Args:
        rows: Row count, passed through to ``COOSparseTensor`` unchanged.
        cols: Column count, passed through unchanged.
        rowidx: Row index tensor; converted to ``torch.int32``.
        colidx: Column index tensor; converted to ``torch.int32``.
        values: Value tensor; moved to ``device`` and then cast to ``dtype``.
        nnz: Passed through unchanged (presumably the number of stored
            entries -- verify against ``COOSparseTensor``).
        device: Target device for ``values`` only; the index tensors stay on
            whatever device they already occupy.
        dtype: Target dtype for ``values``.

    Returns:
        The ``COOSparseTensor`` built from the converted tensors.
    """
    row_indices, col_indices = (index.to(torch.int32) for index in (rowidx, colidx))
    coo_values = values.to(device).to(dtype)
    return COOSparseTensor(rows, cols, nnz, row_indices, col_indices, coo_values)
37+
38+
39+
def row_col_stats(A, threshold):
    """Launch the native ``cget_col_row_stats`` kernel for ``A``.

    ``A`` is treated as a matrix: the last dimension is the column count, and
    a 3-D input has its first two dimensions folded into the row count.

    Args:
        A: 2-D or 3-D tensor whose data pointer is handed to the native
            library. Presumably it must live on an NPU -- TODO confirm
            against the kernel.
        threshold: Forwarded to the kernel as a C ``float``.

    Returns:
        A ``(row_max, col_max, outlier_num)`` tuple of zero-initialised
        tensors that the kernel receives as outputs: ``float32`` of length
        ``rows``, ``float32`` of length ``cols`` and a one-element ``int32``
        tensor. Their exact contents are defined by the kernel (presumably
        per-row / per-column absolute maxima and an outlier count -- verify).
    """
    cols = A.shape[-1]
    rows = A.shape[0] * A.shape[1] if len(A.shape) == 3 else A.shape[0]

    # Allocate the outputs on A's own device rather than the bare "npu"
    # default: with several NPUs the current device may differ from A's, and
    # the kernel must not be given pointers that belong to different devices.
    device = A.device
    row_max = torch.zeros(rows, dtype=torch.float32, device=device)
    col_max = torch.zeros(cols, dtype=torch.float32, device=device)
    outlier_num = torch.zeros(1, dtype=torch.int32, device=device)

    # NOTE(review): the launch uses the *current* stream, so callers are still
    # expected to have made A's device current before calling -- confirm.
    lib.cget_col_row_stats(
        get_ptr(A),
        get_ptr(row_max),
        get_ptr(col_max),
        get_ptr(outlier_num),
        ct.c_float(threshold),
        ct.c_int32(rows),
        ct.c_int32(cols),
        torch.npu.current_stream(),
    )
    return row_max, col_max, outlier_num
60+
61+
62+
class Int8AB:
    """Plain holder for two tensors plus the device of the first one."""

    def __init__(self, A: torch.Tensor, B: torch.Tensor):
        """Store ``A`` and ``B`` as given and record ``A.device``."""
        self.A, self.B = A, B
        # Only the first operand determines the recorded device.
        self.device = A.device
67+
68+
3169
class NPUBackend(Backend):
3270
def int8_double_quant(
3371
self,
@@ -38,7 +76,53 @@ def int8_double_quant(
3876
out_row: Optional[torch.Tensor] = None,
3977
threshold=0.0,
4078
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
41-
raise NotImplementedError
79+
past_device = None
80+
device = A.device
81+
assert A.dtype == torch.half
82+
assert device.type == "npu"
83+
84+
cols = A.shape[-1]
85+
if len(A.shape) == 3:
86+
rows = A.shape[0] * A.shape[1]
87+
else:
88+
rows = A.shape[0]
89+
90+
if past_device != str(A.device):
91+
torch.npu.set_device(A.device) # reset context
92+
past_device = str(A.device)
93+
94+
row_stats, col_stats, cnt_npu = row_col_stats(A, threshold)
95+
96+
quant_row = torch.empty((rows, cols), dtype=torch.int8, device=device)
97+
quant_col = torch.empty((rows, cols), dtype=torch.int8, device=device)
98+
outliers_row_idx = torch.zeros(rows, dtype=torch.int32, device=device)
99+
outliers_col_idx = torch.zeros(40 * cols, dtype=torch.int32, device=device) - 1
100+
outliers_value = torch.empty(0, dtype=torch.float16, device=device)
101+
102+
lib.cdouble_rowcol_quant(
103+
get_ptr(A),
104+
get_ptr(row_stats),
105+
get_ptr(col_stats),
106+
get_ptr(quant_row),
107+
get_ptr(quant_col),
108+
get_ptr(outliers_row_idx),
109+
get_ptr(outliers_col_idx),
110+
get_ptr(outliers_value),
111+
ct.c_int(cols),
112+
ct.c_float(threshold),
113+
ct.c_int32(rows),
114+
ct.c_int32(cols),
115+
torch.npu.current_stream()
116+
)
117+
118+
colidx_tmp = torch.unique(outliers_col_idx)
119+
colidx = colidx_tmp[colidx_tmp != -1]
120+
121+
coo_tensor = None
122+
if threshold != 0.0:
123+
coo_tensor = coo_zeros(rows, cols, outliers_row_idx, colidx, outliers_value, cnt_npu, device, dtype=torch.half)
124+
125+
return quant_row, quant_col, row_stats, col_stats, coo_tensor
42126

43127
def int8_vectorwise_dequant(self, A, stats):
44128
return super().int8_vectorwise_dequant(A, stats)
@@ -48,7 +132,35 @@ def int8_vectorwise_quant(
48132
A: torch.Tensor,
49133
threshold=0.0,
50134
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
51-
raise NotImplementedError
135+
device = A.device
136+
assert A.dtype == torch.half
137+
assert device.type == "npu"
138+
139+
cols = A.shape[-1]
140+
if len(A.shape) == 3:
141+
rows = A.shape[0] * A.shape[1]
142+
else:
143+
rows = A.shape[0]
144+
145+
A_no_threshold = None
146+
if threshold > 0.0:
147+
zero = torch.tensor(0.0, dtype=torch.half, device=device)
148+
A_no_threshold = torch.where(A.view(rows, cols).abs() < threshold, A.view(rows, cols), zero)
149+
row_stats = torch.amax(A_no_threshold.abs(), dim=1, keepdim=True).to(device)
150+
out_row = torch.round(A_no_threshold * 127.0 / row_stats).to(torch.int8)
151+
else:
152+
row_stats = torch.amax(A.view(rows, cols).abs(), dim=1, keepdim=True).to(device)
153+
out_row = torch.round(A * 127.0 / row_stats).to(torch.int8)
154+
155+
outlier_cols = None
156+
if threshold > 0.0:
157+
# TODO we could improve perf of this
158+
outliers = A.abs() >= threshold
159+
160+
if outliers.any():
161+
outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1)
162+
163+
return out_row, row_stats, outlier_cols
52164

53165
def transform(
54166
self,
@@ -69,7 +181,7 @@ def int8_linear_matmul(
69181
out: Optional[torch.Tensor] = None,
70182
dtype=torch.int32,
71183
) -> torch.Tensor:
72-
raise NotImplementedError
184+
return Int8AB(A, B)
73185

74186
def int8_mm_dequant(
75187
self,
@@ -79,7 +191,15 @@ def int8_mm_dequant(
79191
out: Optional[torch.Tensor] = None,
80192
bias: Optional[torch.Tensor] = None,
81193
) -> torch.Tensor:
82-
raise NotImplementedError
194+
A, B = A.A, A.B
195+
out = torch_npu.npu_quant_matmul(
196+
A,
197+
B.t(),
198+
scale=col_stats.float() / 127.0,
199+
pertoken_scale=row_stats.float().view(-1) / 127.0,
200+
output_dtype=torch.float16
201+
)
202+
return out
83203

84204
def extract_outliers(
85205
self,
@@ -106,6 +226,10 @@ def quantize_4bit(
106226
if blocksize is None:
107227
blocksize = 128
108228

229+
total_blocks = A.numel() // blocksize
230+
chunks = 8 if A.numel() > 2048 * 2048 else 1
231+
chunksize = (total_blocks + chunks - 1) // chunks
232+
109233
prev_device = torch.npu.current_device()
110234
torch.npu.set_device(A.device)
111235
if A.dtype in [torch.float32, torch.float16, torch.bfloat16]:
@@ -128,12 +252,27 @@ def quantize_4bit(
128252
1.0,
129253
]
130254
data = torch.tensor(data, device="npu", dtype=torch.float32).view(1, -1)
131-
absmax = A.view(-1, blocksize).abs().max(dim=1, keepdim=True).values
132-
a = A.view(-1, blocksize) / absmax.float()
133-
diff = torch.abs(a.unsqueeze(-1) - data)
134-
out = (torch.argmin(diff, dim=-1) + 8) % 16
135-
out = out.reshape(-1, 2)
136-
out = (out[:, 0] + out[:, 1] * 16).to(torch.uint8)
255+
chunks_absmax = []
256+
chunks_out = []
257+
258+
for i in range(chunks):
259+
start = i * chunksize * blocksize
260+
end = min((i + 1) * chunksize * blocksize, A.numel())
261+
chunk_data = A.view(-1)[start:end].view(-1, blocksize)
262+
263+
absmax = chunk_data.abs().max(dim=1, keepdim=True).values
264+
chunks_absmax.append(absmax)
265+
266+
a = chunk_data / absmax.float()
267+
diff = torch.abs(a.unsqueeze(-1) - data)
268+
out = (torch.argmin(diff, dim=-1) + 8) % 16
269+
270+
out = out.reshape(-1, 2)
271+
out = (out[:, 0] + out[:, 1] * 16).to(torch.uint8)
272+
chunks_out.append(out)
273+
274+
absmax = torch.cat(chunks_absmax, dim=0)
275+
out = torch.cat(chunks_out, dim=0)
137276
else:
138277
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
139278
assert_on_npu([A, absmax, out])

bitsandbytes/nn/modules.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55
import copy
6+
import importlib
67
from typing import Any, Dict, Optional, TypeVar, Union, overload
78
import warnings
89

@@ -320,9 +321,6 @@ def cpu(self, non_blocking: bool = False):
320321
return self.to(device="cpu", non_blocking=non_blocking)
321322

322323
def npu(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
323-
# `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
324-
if isinstance(device, int):
325-
device = f"npu:{device}"
326324
return self.to(device="npu" if device is None else device, non_blocking=non_blocking)
327325

328326
def xpu(self, non_blocking: bool = False):
@@ -345,7 +343,10 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
345343
def to(self, *args, **kwargs):
346344
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
347345

348-
if device is not None and device.type in ["cuda", "cpu", "npu", "xpu"] and not self.bnb_quantized:
346+
# `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
347+
if importlib.util.find_spec("torch_npu") and device.type == "cuda" and not self.bnb_quantized:
348+
return self._quantize(f"npu:{device}" if isinstance(device, int) else str(device).replace("cuda", "npu"))
349+
elif device is not None and device.type in ["cuda", "cpu", "npu", "xpu"] and not self.bnb_quantized:
349350
return self._quantize(device)
350351
else:
351352
if self.quant_state is not None:
@@ -677,6 +678,19 @@ def xpu(self, device):
677678
self.SCB = SCB
678679
return self
679680

681+
def npu(self, device):
    """Quantize this parameter's data on an NPU and keep the row-wise result.

    The data is made contiguous, cast to ``float16``, moved with
    ``.npu(device)`` and passed through ``bnb.functional.double_quant``.
    Only the first and third outputs are retained.

    Args:
        device: Forwarded unchanged to ``Tensor.npu``.

    Returns:
        ``self``, with ``data`` and ``CB`` both set to the first
        ``double_quant`` output and ``SCB`` set to the third.
    """
    # we store the 8-bit rows-major weight
    weight_fp16 = self.data.contiguous().to(torch.float16).npu(device)
    row_quant, col_quant, row_scale, col_scale, _outliers = bnb.functional.double_quant(weight_fp16)

    # The second and fourth outputs are not needed; drop the local references
    # right away (deleting a name bound to None is harmless).
    del col_quant, col_scale

    self.data = row_quant
    self.CB = row_quant
    self.SCB = row_scale
    return self
693+
680694
@overload
681695
def to(
682696
self: T,
@@ -695,7 +709,10 @@ def to(self, *args, **kwargs):
695709
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
696710

697711
if device is not None:
698-
if device.type == "cuda" and self.data.device.type == "cpu":
712+
# `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
713+
if importlib.util.find_spec("torch_npu") and device.type == "cuda":
714+
return self.npu(f"npu:{device}" if isinstance(device, int) else str(device).replace("cuda", "npu"))
715+
elif device.type == "cuda" and self.data.device.type == "cpu":
699716
return self.cuda(device)
700717
elif device.type == "cpu":
701718
if self.data.dtype == torch.int8:

0 commit comments

Comments
 (0)