Commit 45ead33

Test improvements
1 parent f360a08 commit 45ead33

5 files changed: 202 additions, 80 deletions
bitsandbytes/_ops.py

Lines changed: 4 additions & 4 deletions
@@ -59,7 +59,7 @@ def _(A: torch.Tensor, threshold=0.0):
 
 @register_fake("bitsandbytes::int8_vectorwise_dequant")
 def _(A: torch.Tensor, stats: torch.Tensor) -> torch.Tensor:
-    torch._check(A.dtype == torch.int8, "A must be int8")
+    torch._check(A.dtype == torch.int8, lambda: "A must be int8")
     return torch.empty_like(A, dtype=torch.float32)
 
 
@@ -84,7 +84,7 @@ def _(
     out: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-    torch._check(A.dtype == torch.int32, "A must be int32")
+    torch._check(A.dtype == torch.int32, lambda: "A must be int32")
     return torch.empty_like(A, dtype=torch.float16)
 
 
@@ -137,8 +137,8 @@ def _(
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     n = A.numel()
     blocks = -(n // -blocksize)
-    absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
-    out = torch.zeros(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage)
+    absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
+    out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage)
     return out, absmax
 
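The _ops.py changes swap eager message strings for callables, matching torch._check's documented signature (the message argument is a callable returning a str). A minimal sketch of the pattern; the check_int8 helper below is illustrative and not part of the library:

import torch

def check_int8(A: torch.Tensor) -> None:
    # The lambda is only evaluated when the check fails, so the f-string
    # (and the A.dtype access it needs) costs nothing on the success path.
    torch._check(A.dtype == torch.int8, lambda: f"A must be int8, got {A.dtype}")

check_int8(torch.zeros(4, dtype=torch.int8))   # passes silently
# check_int8(torch.zeros(4))                   # would raise RuntimeError: A must be int8, got torch.float32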

bitsandbytes/backends/cpu/ops.py

Lines changed: 47 additions & 1 deletion
@@ -1,8 +1,12 @@
-from typing import Optional
+import ctypes as ct
+from typing import Optional, Tuple
 
 import torch
 
+from bitsandbytes.functional import get_ptr
+
 from ..._ops import register_kernel
+from ...cextension import lib
 
 
 @register_kernel("bitsandbytes::int8_linear_matmul", "cpu")
@@ -12,3 +16,45 @@ def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtyp
     if out is not None:
         result = out.copy_(result)
     return result
+
+
+@register_kernel("bitsandbytes::quantize_blockwise", "cpu")
+def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    torch._check_is_size(blocksize)
+    torch._check(A.dtype == torch.float32, "A must be float32")
+
+    n = A.numel()
+    blocks = -(n // -blocksize)
+
+    absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
+    out = torch.empty_like(A, dtype=torch.uint8)
+
+    lib.cquantize_blockwise_cpu_fp32(
+        get_ptr(code),
+        get_ptr(A),
+        get_ptr(absmax),
+        get_ptr(out),
+        ct.c_longlong(blocksize),
+        ct.c_longlong(n),
+    )
+
+    return out, absmax
+
+
+@register_kernel("bitsandbytes::dequantize_blockwise", "cpu")
+def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
+    torch._check_is_size(blocksize)
+    torch._check(dtype == torch.float32, "A must be float32 on cpu")
+
+    out = torch.empty_like(A, dtype=dtype)
+
+    lib.cdequantize_blockwise_cpu_fp32(
+        get_ptr(code),
+        get_ptr(A),
+        get_ptr(absmax),
+        get_ptr(out),
+        ct.c_longlong(blocksize),
+        ct.c_longlong(A.numel()),
+    )
+
+    return out
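For context on what the newly bound CPU kernels compute, here is a rough pure-PyTorch sketch of 8-bit blockwise quantization. It assumes the usual absmax-scaling-plus-codebook-lookup scheme; the function name and the clamp against all-zero blocks are illustrative, not taken from the library:

import torch

def quantize_blockwise_reference(A: torch.Tensor, code: torch.Tensor, blocksize: int):
    # Each block is scaled by its own absmax, then every value is mapped to the
    # index of the nearest entry in the 256-entry float32 code table.
    n = A.numel()
    blocks = -(n // -blocksize)                      # ceiling division, as in the diff
    flat = A.reshape(-1).float()
    absmax = torch.empty(blocks, dtype=torch.float32)
    out = torch.empty(n, dtype=torch.uint8)
    for i in range(blocks):
        chunk = flat[i * blocksize : (i + 1) * blocksize]
        absmax[i] = chunk.abs().max()
        scaled = chunk / absmax[i].clamp_min(1e-12)  # guard against all-zero blocks
        idx = (scaled.unsqueeze(1) - code.unsqueeze(0)).abs().argmin(dim=1)
        out[i * blocksize : i * blocksize + chunk.numel()] = idx.to(torch.uint8)
    return out.reshape(A.shape), absmax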

bitsandbytes/backends/cuda/ops.py

Lines changed: 28 additions & 17 deletions
@@ -17,11 +17,11 @@ def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtyp
     shapeA = A.shape
     shapeB = B.shape
 
-    torch._check(A.dtype == torch.int8, "B must be int8")
-    torch._check(B.dtype == torch.int8, "A must be int8")
-    torch._check(A.ndim == 2, "Only two dimensional matrices are supported for argument B")
-    torch._check(B.ndim in [2, 3], "Only two or three dimensional matrices are supported for argument A")
-    torch._check(prod(shapeB) > 0, f"Input tensor dimensions need to be > 0: {shapeB}")
+    torch._check(A.dtype == torch.int8, lambda: "B must be int8")
+    torch._check(B.dtype == torch.int8, lambda: "A must be int8")
+    torch._check(A.ndim == 2, lambda: "Only two dimensional matrices are supported for argument B")
+    torch._check(B.ndim in [2, 3], lambda: "Only two or three dimensional matrices are supported for argument A")
+    torch._check(prod(shapeB) > 0, lambda: f"Input tensor dimensions need to be > 0: {shapeB}")
     torch._check(out is None or out.dtype == dtype)
 
     shapeC = (*shapeB[:-1], shapeA[0])
@@ -34,7 +34,7 @@ def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtyp
 
     torch._check(
         lda == ldb,
-        f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}",
+        lambda: f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}",
     )
 
     # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4.
@@ -92,10 +92,12 @@ def _(
     out: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-    torch._check(A.dtype == torch.int32, "A must be int32")
+    torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}")
+    torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}")
+    torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}")
 
     if bias is not None:
-        torch._check(bias.dtype == torch.float16)
+        torch._check(bias.dtype == torch.float16, lambda: f"Only fp16 bias is supported, got {bias.dtype}")
 
     if out is None:
         out = torch.empty_like(A, dtype=torch.float16)
@@ -118,7 +120,8 @@ def _(
 
 @register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda")
 def _(A: torch.Tensor, threshold=0.0):
-    torch._check(A.dtype == torch.float16, "A must be float16")
+    torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}")
+    torch._check(threshold >= 0.0, lambda: "threshold must be non-negative")
 
     rows = prod(A.shape[:-1])
     cols = A.shape[-1]
@@ -205,12 +208,14 @@ def _get_col_absmax(
 
 @register_kernel("bitsandbytes::quantize_blockwise", "cuda")
 def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    torch._check_is_size(blocksize)
     torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
+    torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")
 
     n = A.numel()
     blocks = -(n // -blocksize)
-    absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
-    out = torch.zeros_like(A, dtype=torch.uint8)
+    absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
+    out = torch.empty_like(A, dtype=torch.uint8)
 
     with _cuda_device_of(A):
         args = (
@@ -237,6 +242,10 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor
 @register_kernel("bitsandbytes::dequantize_blockwise", "cuda")
 def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
     torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
+    torch._check(
+        dtype in [torch.float16, torch.bfloat16, torch.float32],
+        lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}",
+    )
 
     out = torch.empty_like(A, dtype=dtype)
 
@@ -257,8 +266,6 @@ def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int,
             lib.cdequantize_blockwise_bf16(*args)
         elif dtype == torch.float32:
             lib.cdequantize_blockwise_fp32(*args)
-        else:
-            raise ValueError(f"Blockwise dequantization only supports 16/32-bit floats, but got {dtype}")
 
     return out
 
@@ -269,6 +276,10 @@ def _(
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
     torch._check(quant_type in ["fp4", "nf4"])
+    torch._check(
+        A.dtype in [torch.bfloat16, torch.float16, torch.float32],
+        lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
+    )
 
     n = A.numel()
     blocks = -(n // -blocksize)
@@ -300,8 +311,6 @@ def _(
                 lib.cquantize_blockwise_fp32_fp4(*args)
             else:
                 lib.cquantize_blockwise_fp32_nf4(*args)
-        else:
-            raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
 
     return out, absmax
 
@@ -312,6 +321,10 @@ def _(
 ) -> torch.Tensor:
     torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
     torch._check(quant_type in ["fp4", "nf4"])
+    torch._check(
+        dtype in [torch.bfloat16, torch.float16, torch.float32],
+        lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
+    )
 
     out = torch.empty(shape, dtype=dtype, device=A.device)
     n = out.numel()
@@ -344,7 +357,5 @@ def _(
             lib.cdequantize_blockwise_fp32_fp4(*args)
         else:
             lib.cdequantize_blockwise_fp32_nf4(*args)
-        else:
-            raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")
 
     return out
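Since both backends register the same op names, the quantize/dequantize round trip can be exercised through the dispatcher and routed by input device. A hypothetical smoke test, assuming the ops are exposed under torch.ops.bitsandbytes as the registration strings suggest, and using a stand-in linear code table rather than the library's real one:

import torch

code = torch.linspace(-1.0, 1.0, 256, dtype=torch.float32)  # stand-in code table (assumption)
A = torch.randn(16, 64, dtype=torch.float32)                # CPU path; move to CUDA to hit the CUDA kernels

packed, absmax = torch.ops.bitsandbytes.quantize_blockwise(A, code, 64)
restored = torch.ops.bitsandbytes.dequantize_blockwise(packed, absmax, code, 64, torch.float32)
print((A - restored).abs().max())  # small, bounded by the 8-bit code table's resolution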

tests/helpers.py

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,8 @@ def format_with_label(label: str, value: Any) -> str:
         formatted = "T" if value else "F"
     elif isinstance(value, (list, tuple)) and all(isinstance(v, bool) for v in value):
         formatted = "".join("T" if b else "F" for b in value)
+    elif isinstance(value, torch.dtype):
+        formatted = describe_dtype(value)
     else:
         formatted = str(value)
     return f"{label}={formatted}"
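The helpers change routes torch.dtype values through describe_dtype so parametrized test IDs stay short. A self-contained sketch with a stand-in describe_dtype (the real helper lives in tests/helpers.py and may differ):

import torch

def describe_dtype(dtype: torch.dtype) -> str:
    # Stand-in: shorten "torch.float16" -> "fp16", "torch.bfloat16" -> "bf16", etc.
    return str(dtype).replace("torch.", "").replace("bfloat", "bf").replace("float", "fp")

def format_with_label(label: str, value) -> str:
    if isinstance(value, torch.dtype):
        formatted = describe_dtype(value)
    else:
        formatted = str(value)
    return f"{label}={formatted}"

print(format_with_label("dtype", torch.bfloat16))  # dtype=bf16
print(format_with_label("dtype", torch.float32))   # dtype=fp32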

0 commit comments
