
Commit e9c79cf

Implement 4bit quant/dequant ops
1 parent 4ad1d9e commit e9c79cf

3 files changed: +275 -206 lines changed


bitsandbytes/_ops.py

Lines changed: 53 additions & 3 deletions
@@ -1,5 +1,5 @@
 from math import prod
-from typing import Optional, Tuple
+from typing import Optional, Sequence, Tuple
 
 import torch
 
@@ -37,7 +37,7 @@ def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtyp
 
 torch.library.define(
     "bitsandbytes::int8_vectorwise_quant",
-    "(Tensor A, Scalar threshold=0.0) -> (Tensor, Tensor, Tensor?)",
+    "(Tensor A, float threshold=0.0) -> (Tensor, Tensor, Tensor?)",
 )
 
 
@@ -90,7 +90,7 @@ def _(
 
 torch.library.define(
     "bitsandbytes::int8_double_quant",
-    "(Tensor A, Tensor? col_stats, Tensor? row_stats, Tensor? out_col, Tensor? out_row, Scalar threshold=0.0) -> (Tensor, Tensor, Tensor, Tensor, Tensor?)",
+    "(Tensor A, Tensor? col_stats, Tensor? row_stats, Tensor? out_col, Tensor? out_row, float threshold=0.0) -> (Tensor, Tensor, Tensor, Tensor, Tensor?)",
 )
 
 
@@ -110,3 +110,53 @@ def _(
     outlier_n = torch.library.get_ctx().new_dynamic_size()
     outlier_cols = A.new_empty(outlier_n, dtype=torch.int64)
     return out_row, out_col, row_stats, col_stats, outlier_cols
+
+
+torch.library.define(
+    "bitsandbytes::dequantize_4bit",
+    "(Tensor A, Tensor absmax, int blocksize, str quant_type, int[] shape, ScalarType dtype) -> Tensor",
+)
+
+
+@register_fake("bitsandbytes::dequantize_4bit")
+def _(
+    A: torch.Tensor, absmax: torch.Tensor, blocksize: int, quant_type: str, shape: Sequence[int], dtype: torch.dtype
+) -> torch.Tensor:
+    return torch.empty(shape, dtype=dtype, device=A.device)
+
+
+torch.library.define(
+    "bitsandbytes::quantize_4bit",
+    "(Tensor A, int blocksize, str quant_type, ScalarType quant_storage) -> (Tensor, Tensor)",
+)
+
+
+def _(
+    A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    n = A.numel()
+    blocks = -(n // -blocksize)
+    absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
+    out = torch.zeros(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage)
+    return out, absmax
+
+
+torch.library.define(
+    "bitsandbytes::dequantize_blockwise",
+    "(Tensor A, Tensor absmax, Tensor code, int blocksize, ScalarType dtype) -> Tensor",
+)
+
+
+def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
+    return torch.empty_like(A, dtype=dtype)
+
+
+torch.library.define("bitsandbytes::quantize_blockwise", "(Tensor A, Tensor code, int blocksize) -> (Tensor, Tensor)")
+
+
+def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    n = A.numel()
+    blocks = -(n // -blocksize)
+    absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
+    out = torch.zeros_like(A, dtype=torch.uint8)
+    return out, absmax
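
The schemas above register the new ops under the bitsandbytes:: namespace, so they are invoked through torch.ops once a backend kernel (such as the CUDA one below) is available. A minimal usage sketch, not part of the commit, assuming a CUDA build where importing the package registers these kernels, and using an illustrative blocksize of 64 with NF4 quantization and uint8 storage:

import torch
import bitsandbytes  # assumed to register the bitsandbytes:: ops and CUDA kernels

A = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")

# Quantize to packed 4-bit NF4: returns the packed storage tensor and one
# fp32 absmax scale per 64-element block.
packed, absmax = torch.ops.bitsandbytes.quantize_4bit(A, 64, "nf4", torch.uint8)

# Dequantize back to float16; shape and dtype are passed explicitly because
# the packed tensor no longer carries them.
restored = torch.ops.bitsandbytes.dequantize_4bit(packed, absmax, 64, "nf4", list(A.shape), torch.float16)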

bitsandbytes/backends/cuda/ops.py

Lines changed: 174 additions & 29 deletions
@@ -1,10 +1,10 @@
 import ctypes as ct
 from math import prod
-from typing import Optional, Tuple
+from typing import Optional, Sequence, Tuple
 
 import torch
 
-from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr, is_on_gpu
+from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr
 
 from ..._ops import register_kernel
 from ...cextension import lib
@@ -17,12 +17,12 @@ def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtyp
     shapeA = A.shape
     shapeB = B.shape
 
-    assert A.dtype == torch.int8
-    assert B.dtype == torch.int8
-    assert A.ndim == 2, "Only two dimensional matrices are supported for argument B"
-    assert B.ndim in [2, 3], "Only two or three dimensional matrices are supported for argument A"
-    assert prod(shapeB) > 0, f"Input tensor dimensions need to be > 0: {shapeB}"
-    assert out is None or out.dtype == dtype
+    torch._check(A.dtype == torch.int8, "B must be int8")
+    torch._check(B.dtype == torch.int8, "A must be int8")
+    torch._check(A.ndim == 2, "Only two dimensional matrices are supported for argument B")
+    torch._check(B.ndim in [2, 3], "Only two or three dimensional matrices are supported for argument A")
+    torch._check(prod(shapeB) > 0, f"Input tensor dimensions need to be > 0: {shapeB}")
+    torch._check(out is None or out.dtype == dtype)
 
     shapeC = (*shapeB[:-1], shapeA[0])
 
@@ -32,9 +32,10 @@ def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtyp
     ldb = shapeB[-1]  # Activations (batch, tokens, inputs)
     ldc = shapeC[-1]  # Output (batch, tokens, outputs)
 
-    assert (
-        lda == ldb
-    ), f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}"
+    torch._check(
+        lda == ldb,
+        f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}",
+    )
 
     # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4.
     # We'll fall back to a slower fp32 calculation in this circumstance.
@@ -48,8 +49,6 @@ def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtyp
     if out is None:
         out = torch.empty(shapeC, device=A.device, dtype=dtype)
 
-    is_on_gpu([A, B, out])
-
     with _cuda_device_of(A):
         ctx = CUBLAS_Context.get_instance().get_context(A.device)
         ptrA = get_ptr(A)
@@ -69,16 +68,18 @@ def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtyp
         else:
             has_error = lib.cigemmlt_8(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream)
 
-    if has_error == 100:  # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`
-        raise NotImplementedError("int8_linear_matmul not implemented!")
-
     if has_error:
-        raise RuntimeError(
-            f"cublasLt ran into an error!\n"
-            f"\t{shapeA=}, {shapeB=}, {shapeC=}\n"
-            f"\t{(lda, ldb, ldc)=}\n"
-            f"\t{(m, n, k)=}"
-        )
+        if has_error == 100:
+            # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`
+            # TODO: Warn and implement a fallback to fp32 compute?
+            raise NotImplementedError("int8_linear_matmul not implemented!")
+        else:
+            raise RuntimeError(
+                f"cublasLt ran into an error!\n"
+                f"\t{shapeA=}, {shapeB=}, {shapeC=}\n"
+                f"\t{(lda, ldb, ldc)=}\n"
+                f"\t{(m, n, k)=}"
+            )
 
     return out
 
@@ -91,10 +92,10 @@ def _(
     out: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-    assert A.dtype == torch.int32
+    torch._check(A.dtype == torch.int32, "A must be int32")
 
     if bias is not None:
-        assert bias.dtype == torch.float16
+        torch._check(bias.dtype == torch.float16)
 
     if out is None:
         out = torch.empty_like(A, dtype=torch.float16)
@@ -107,8 +108,6 @@ def _(
     numRows = ct.c_int32(prod(A.shape[:-1]))
    numCols = ct.c_int32(A.shape[-1])
 
-    is_on_gpu([A, row_stats, col_stats, out, bias])
-
     with _cuda_device_of(A):
         lib.cdequant_mm_int32_fp16(
             ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A)
@@ -119,8 +118,7 @@ def _(
 
 @register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda")
 def _(A: torch.Tensor, threshold=0.0):
-    assert A.dtype == torch.half
-    is_on_gpu([A])
+    torch._check(A.dtype == torch.float16, "A must be float16")
 
     rows = prod(A.shape[:-1])
     cols = A.shape[-1]
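
Several hunks in this file swap plain assert statements for torch._check. Unlike assert, which is stripped under python -O, torch._check raises a RuntimeError when the condition is false and is recognized by PyTorch's fake-tensor and symbolic-shape machinery, which is presumably the motivation for the change here. A standalone sketch of the pattern (illustrative function, not from the diff; the message is given as a callable, the form the PyTorch docs describe):

import torch

def int8_rowwise_scale(A: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    # Graph-friendly validation: these checks survive -O and can be traced.
    torch._check(A.dtype == torch.int8, lambda: "A must be int8")
    torch._check(scales.numel() == A.shape[0], lambda: "expected one scale per row")
    return A * scales.unsqueeze(1)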
@@ -188,7 +186,7 @@ def _get_col_absmax(
     A: torch.Tensor,
     threshold=0.0,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-    assert A.is_floating_point()
+    torch._check(A.is_floating_point())
 
     outlier_mask = None
 
@@ -203,3 +201,150 @@ def _get_col_absmax(
     col_stats = absA.amax(dim=0, keepdim=False).float()
 
     return col_stats, outlier_mask
+
+
+@register_kernel("bitsandbytes::quantize_blockwise", "cuda")
+def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
+
+    n = A.numel()
+    blocks = -(n // -blocksize)
+    absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
+    out = torch.zeros_like(A, dtype=torch.uint8)
+
+    with _cuda_device_of(A):
+        args = (
+            get_ptr(code),
+            get_ptr(A),
+            get_ptr(absmax),
+            get_ptr(out),
+            ct.c_int32(blocksize),
+            ct.c_int(A.numel()),
+        )
+
+        if A.dtype == torch.float16:
+            lib.cquantize_blockwise_fp16(*args)
+        elif A.dtype == torch.bfloat16:
+            lib.cquantize_blockwise_bf16(*args)
+        elif A.dtype == torch.float32:
+            lib.cquantize_blockwise_fp32(*args)
+        else:
+            raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
+
+    return out, absmax
+
+
+@register_kernel("bitsandbytes::dequantize_blockwise", "cuda")
+def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
+    torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
+
+    out = torch.empty_like(A, dtype=dtype)
+
+    with _cuda_device_of(A):
+        args = (
+            get_ptr(code),
+            get_ptr(A),
+            get_ptr(absmax),
+            get_ptr(out),
+            ct.c_int(blocksize),
+            ct.c_int(A.numel()),
+            _get_tensor_stream(A),
+        )
+
+        if dtype == torch.float16:
+            lib.cdequantize_blockwise_fp16(*args)
+        elif dtype == torch.bfloat16:
+            lib.cdequantize_blockwise_bf16(*args)
+        elif dtype == torch.float32:
+            lib.cdequantize_blockwise_fp32(*args)
+        else:
+            raise ValueError(f"Blockwise dequantization only supports 16/32-bit floats, but got {dtype}")
+
+    return out
+
+
+@register_kernel("bitsandbytes::quantize_4bit", "cuda")
+def _(
+    A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
+    torch._check(quant_type in ["fp4", "nf4"])
+
+    n = A.numel()
+    blocks = -(n // -blocksize)
+    absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
+    out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage)
+
+    with _cuda_device_of(A):
+        args = (
+            None,
+            get_ptr(A),
+            get_ptr(absmax),
+            get_ptr(out),
+            ct.c_int32(blocksize),
+            ct.c_int(n),
+        )
+
+        if A.dtype == torch.bfloat16:
+            if quant_type == "fp4":
+                lib.cquantize_blockwise_bf16_fp4(*args)
+            else:
+                lib.cquantize_blockwise_bf16_nf4(*args)
+        elif A.dtype == torch.float16:
+            if quant_type == "fp4":
+                lib.cquantize_blockwise_fp16_fp4(*args)
+            else:
+                lib.cquantize_blockwise_fp16_nf4(*args)
+        elif A.dtype == torch.float32:
+            if quant_type == "fp4":
+                lib.cquantize_blockwise_fp32_fp4(*args)
+            else:
+                lib.cquantize_blockwise_fp32_nf4(*args)
+        else:
+            raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
+
+    return out, absmax
+
+
+@register_kernel("bitsandbytes::dequantize_4bit", "cuda")
+def _(
+    A: torch.Tensor, absmax: torch.Tensor, blocksize: int, quant_type: str, shape: Sequence[int], dtype: torch.dtype
+) -> torch.Tensor:
+    torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
+    torch._check(quant_type in ["fp4", "nf4"])
+
+    out = torch.empty(shape, dtype=dtype, device=A.device)
+    n = out.numel()
+
+    stream = _get_tensor_stream(A)
+
+    with _cuda_device_of(A):
+        args = (
+            None,
+            get_ptr(A),
+            get_ptr(absmax),
+            get_ptr(out),
+            ct.c_int(blocksize),
+            ct.c_int(n),
+            stream,
+        )
+
+        if out.dtype == torch.bfloat16:
+            if quant_type == "fp4":
+                lib.cdequantize_blockwise_bf16_fp4(*args)
+            else:
+                lib.cdequantize_blockwise_bf16_nf4(*args)
+        elif out.dtype == torch.float16:
+            if quant_type == "fp4":
+                lib.cdequantize_blockwise_fp16_fp4(*args)
+            else:
+                lib.cdequantize_blockwise_fp16_nf4(*args)
+        elif out.dtype == torch.float32:
+            if quant_type == "fp4":
+                lib.cdequantize_blockwise_fp32_fp4(*args)
+            else:
+                lib.cdequantize_blockwise_fp32_nf4(*args)
+        else:
+            raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")
+
+    return out
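
As a sanity check on the allocation sizes used in quantize_4bit above: two 4-bit codes pack into one byte, the packed buffer is laid out as a column of quant_storage elements, and absmax holds one fp32 scale per block of blocksize elements. A small worked sketch of that arithmetic (pure Python, no GPU required; the helper name is made up for illustration):

def packed_4bit_sizes(n: int, blocksize: int, storage_itemsize: int):
    # blocks = ceil(n / blocksize), written as negative floor division
    blocks = -(n // -blocksize)
    # two 4-bit values per byte, grouped into quant_storage-sized elements
    packed_elems = (n + 1) // (storage_itemsize * 2)
    return blocks, (packed_elems, 1)

# Example: a 4096x4096 weight with blocksize 64 and uint8 storage (itemsize 1)
# yields 262144 absmax scales and a packed tensor of shape (8388608, 1).
print(packed_4bit_sizes(4096 * 4096, 64, 1))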
