Commit 4ad1d9e

Int8 ops updates; tests
1 parent 2813571 commit 4ad1d9e

File tree

5 files changed: +294 -20 lines

bitsandbytes/_ops.py
bitsandbytes/backends/cpu/ops.py
bitsandbytes/backends/cuda/ops.py
bitsandbytes/functional.py
tests/test_ops.py

bitsandbytes/_ops.py

Lines changed: 8 additions & 1 deletion

@@ -63,9 +63,16 @@ def _(A: torch.Tensor, stats: torch.Tensor) -> torch.Tensor:
     return torch.empty_like(A, dtype=torch.float32)
 
 
+# Default PyTorch-native implementation
+@register_kernel("bitsandbytes::int8_vectorwise_dequant", None)
+def _(A: torch.Tensor, stats: torch.Tensor):
+    # To dequantize we divide by 127, or multiply by the reciprocal.
+    return A * stats.view(-1, 1) * 7.874015718698502e-3
+
+
 torch.library.define(
     "bitsandbytes::int8_mm_dequant",
-    "(Tensor A, Tensor row_stats, Tensor col_stats, Tensor? out, Tensor? bias) -> Tensor",
+    "(Tensor A, Tensor row_stats, Tensor col_stats, Tensor? out=None, Tensor? bias=None) -> Tensor",
 )
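Because the new default kernel is registered for the `None` device key, the op works anywhere the plain PyTorch expression does. A minimal sketch of exercising it through the dispatcher (shapes and values here are arbitrary; it assumes `import bitsandbytes` has registered the ops, which the package does on import):

    import torch
    import bitsandbytes  # noqa: F401  # registers the bitsandbytes:: ops

    A = torch.randint(-128, 127, (4, 8), dtype=torch.int8)
    stats = torch.rand(4, dtype=torch.float32) * 3  # per-row absmax scales

    out = torch.ops.bitsandbytes.int8_vectorwise_dequant(A, stats)

    # The constant 7.874015718698502e-3 is float32(1/127), so the result matches
    # dividing by 127 up to float rounding:
    ref = A.float() * stats.view(-1, 1) / 127.0
    assert torch.allclose(out, ref)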

bitsandbytes/backends/cpu/ops.py

Lines changed: 14 additions & 0 deletions

@@ -0,0 +1,14 @@
+from typing import Optional
+
+import torch
+
+from ..._ops import register_kernel
+
+
+@register_kernel("bitsandbytes::int8_linear_matmul", "cpu")
+def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32):
+    # Naive implementation: perform matmul in fp32
+    result = torch.matmul(A.float(), B.float().t()).to(torch.int32)
+    if out is not None:
+        result = out.copy_(result)
+    return result
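One subtlety of the naive fp32 path: int8 products are at most 128^2 = 16384, so the fp32 accumulator stays exact until a running sum can exceed 2^24, i.e. roughly until the inner dimension passes 1024. A quick standalone check of the equivalence (not part of the commit; shapes arbitrary):

    import torch

    A = torch.randint(-128, 127, (4, 64), dtype=torch.int8)
    B = torch.randint(-128, 127, (5, 64), dtype=torch.int8)

    ref = torch.matmul(A.int(), B.int().t())                         # exact integer matmul
    fp32 = torch.matmul(A.float(), B.float().t()).to(torch.int32)    # the CPU kernel's path
    assert torch.equal(ref, fp32)  # exact while k * 128^2 fits in fp32's 24-bit mantissa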

bitsandbytes/backends/cuda/ops.py

Lines changed: 205 additions & 0 deletions

@@ -0,0 +1,205 @@
+import ctypes as ct
+from math import prod
+from typing import Optional, Tuple
+
+import torch
+
+from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr, is_on_gpu
+
+from ..._ops import register_kernel
+from ...cextension import lib
+
+
+@register_kernel("bitsandbytes::int8_linear_matmul", "cuda")
+def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32):
+    A, B = B, A
+
+    shapeA = A.shape
+    shapeB = B.shape
+
+    assert A.dtype == torch.int8
+    assert B.dtype == torch.int8
+    assert A.ndim == 2, "Only two dimensional matrices are supported for argument B"
+    assert B.ndim in [2, 3], "Only two or three dimensional matrices are supported for argument A"
+    assert prod(shapeB) > 0, f"Input tensor dimensions need to be > 0: {shapeB}"
+    assert out is None or out.dtype == dtype
+
+    shapeC = (*shapeB[:-1], shapeA[0])
+
+    k, m = shapeA
+    n = prod(shapeB[:-1])
+    lda = shapeA[-1]  # Weights (outputs, inputs)
+    ldb = shapeB[-1]  # Activations (batch, tokens, inputs)
+    ldc = shapeC[-1]  # Output (batch, tokens, outputs)
+
+    assert (
+        lda == ldb
+    ), f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}"
+
+    # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4.
+    # We'll fall back to a slower fp32 calculation in this circumstance.
+    # Fortunately, this should not be very common.
+    if lda % 4 != 0:
+        result = torch.matmul(B.float(), A.float().t()).to(torch.int32)
+        if out is not None:
+            result = out.copy_(result)
+        return result
+
+    if out is None:
+        out = torch.empty(shapeC, device=A.device, dtype=dtype)
+
+    is_on_gpu([A, B, out])
+
+    with _cuda_device_of(A):
+        ctx = CUBLAS_Context.get_instance().get_context(A.device)
+        ptrA = get_ptr(A)
+        ptrB = get_ptr(B)
+        ptrC = get_ptr(out)
+        ptrRowScale = None
+        m = ct.c_int32(m)
+        n = ct.c_int32(n)
+        k = ct.c_int32(k)
+        lda = ct.c_int32(lda)
+        ldb = ct.c_int32(ldb)
+        ldc = ct.c_int32(ldc)
+        stream = _get_tensor_stream(A)
+
+        if dtype == torch.int32:
+            has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream)
+        else:
+            has_error = lib.cigemmlt_8(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream)
+
+    if has_error == 100:  # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`
+        raise NotImplementedError("int8_linear_matmul not implemented!")
+
+    if has_error:
+        raise RuntimeError(
+            f"cublasLt ran into an error!\n"
+            f"\t{shapeA=}, {shapeB=}, {shapeC=}\n"
+            f"\t{(lda, ldb, ldc)=}\n"
+            f"\t{(m, n, k)=}"
+        )
+
+    return out
+
+
+@register_kernel("bitsandbytes::int8_mm_dequant", "cuda")
+def _(
+    A: torch.Tensor,
+    row_stats: torch.Tensor,
+    col_stats: torch.Tensor,
+    out: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    assert A.dtype == torch.int32
+
+    if bias is not None:
+        assert bias.dtype == torch.float16
+
+    if out is None:
+        out = torch.empty_like(A, dtype=torch.float16)
+
+    ptrA = get_ptr(A)
+    ptrOut = get_ptr(out)
+    ptrRowStats = get_ptr(row_stats)
+    ptrColStats = get_ptr(col_stats)
+    ptrBias = get_ptr(bias)
+    numRows = ct.c_int32(prod(A.shape[:-1]))
+    numCols = ct.c_int32(A.shape[-1])
+
+    is_on_gpu([A, row_stats, col_stats, out, bias])
+
+    with _cuda_device_of(A):
+        lib.cdequant_mm_int32_fp16(
+            ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A)
+        )
+
+    return out
+
+
+@register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda")
+def _(A: torch.Tensor, threshold=0.0):
+    assert A.dtype == torch.half
+    is_on_gpu([A])
+
+    rows = prod(A.shape[:-1])
+    cols = A.shape[-1]
+
+    row_stats = torch.empty(rows, device=A.device, dtype=torch.float32)
+    out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8)
+
+    outlier_cols = None
+
+    if threshold > 0.0:
+        # TODO we could improve perf of this
+        outliers = A.abs() >= threshold
+
+        if outliers.any():
+            outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1)
+
+    with _cuda_device_of(A):
+        lib.cint8_vector_quant(
+            get_ptr(A),
+            get_ptr(out_row),
+            get_ptr(row_stats),
+            ct.c_float(threshold),
+            ct.c_int32(rows),
+            ct.c_int32(cols),
+            _get_tensor_stream(A),
+        )
+
+    # Zero out values from outlier columns across all rows.
+    # The kernel will handle this for outliers themselves, so we can optimize for rows=1.
+    if rows > 1 and outlier_cols is not None:
+        out_row[:, outlier_cols] = 0
+
+    return out_row, row_stats, outlier_cols
+
+
+@register_kernel("bitsandbytes::int8_double_quant", "cuda")
+def _(
+    A: torch.Tensor,
+    col_stats: Optional[torch.Tensor] = None,
+    row_stats: Optional[torch.Tensor] = None,
+    out_col: Optional[torch.Tensor] = None,
+    out_row: Optional[torch.Tensor] = None,
+    threshold=0.0,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    # TODO: Optimize/write CUDA kernel for this?
+
+    # Use CUDA kernel for rowwise and COO tensor
+    quant_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant(A, threshold=threshold)
+
+    # PyTorch impl for colwise
+    col_stats, outlier_mask = _get_col_absmax(A, threshold=threshold)
+    if threshold > 0.0 and outlier_mask is not None:
+        A = A.masked_fill(outlier_mask, 0.0)
+    quant_col = torch.round(A.mul(127.0) / col_stats.unsqueeze(0)).to(torch.int8)
+
+    if out_row is not None:
+        quant_row = out_row.copy_(quant_row)
+    if out_col is not None:
+        quant_col = out_col.copy_(quant_col)
+
+    return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols
+
+
+def _get_col_absmax(
+    A: torch.Tensor,
+    threshold=0.0,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    assert A.is_floating_point()
+
+    outlier_mask = None
+
+    absA = A.abs().view(-1, A.shape[-1])
+
+    if threshold > 0.0:
+        # Filter outliers from stats when enabled
+        outlier_mask = absA >= threshold
+        absA.masked_fill_(outlier_mask, 0.0)
+
+    # shape [cols]; unsqueeze(0) gives [1,cols]
+    col_stats = absA.amax(dim=0, keepdim=False).float()
+
+    return col_stats, outlier_mask
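The `A, B = B, A` swap at the top of the CUDA matmul reflects the calling convention: the op computes activations @ weights^T, so for activations of shape (..., k) and int8 weights of shape (m, k) the output is (..., m) in int32 (as the assert message notes, row-major tensors appear transposed to cuBLASLt, hence B^T @ A internally). A shape-level sketch of that contract, assuming a CUDA device with the library built:

    import torch
    import bitsandbytes  # noqa: F401

    A = torch.randint(-128, 127, (3, 7, 64), dtype=torch.int8, device="cuda")  # activations
    B = torch.randint(-128, 127, (32, 64), dtype=torch.int8, device="cuda")    # weights

    out = torch.ops.bitsandbytes.int8_linear_matmul(A, B)
    assert out.shape == (3, 7, 32) and out.dtype == torch.int32

    # Agrees with the fp32 reference used by the k % 4 != 0 fallback:
    ref = torch.matmul(A.float(), B.float().t()).to(torch.int32)
    assert torch.equal(out, ref)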

bitsandbytes/functional.py

Lines changed: 2 additions & 19 deletions

@@ -2498,24 +2498,7 @@ def int8_double_quant(
         - `torch.Tensor` with dtype `torch.float32`: The column-wise quantization scales.
         - `torch.Tensor` with dtype `torch.int32`, *optional*: A list of column indices which contain outlier features.
     """
-
-    # TODO: Optimize/write CUDA kernel for this?
-
-    # Use CUDA kernel for rowwise and COO tensor
-    quant_row, row_stats, outlier_cols = int8_vectorwise_quant(A, threshold=threshold)
-
-    # PyTorch impl for colwise
-    _, col_stats, outlier_mask = get_colrow_absmax(A, threshold=threshold)
-    if threshold > 0.0 and outlier_mask is not None:
-        A = A.masked_fill(outlier_mask, 0.0)
-    quant_col = torch.round(A.mul(C) / col_stats.unsqueeze(0)).to(torch.int8)
-
-    if out_row is not None:
-        quant_row = out_row.copy_(quant_row)
-    if out_col is not None:
-        quant_col = out_col.copy_(quant_col)
-
-    return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols
+    return torch.ops.bitsandbytes.int8_double_quant(A, col_stats, row_stats, out_col, out_row, threshold)
 
 
 def int8_vectorwise_dequant(A: torch.Tensor, stats: torch.Tensor):

@@ -2529,7 +2512,7 @@ def int8_vectorwise_dequant(A: torch.Tensor, stats: torch.Tensor):
         `torch.Tensor` with dtype `torch.float32`: The dequantized tensor.
     """
     # To dequantize we divide by 127, or multiply by the reciprocal.
-    return A * stats.view(-1, 1) * 7.874015718698502e-3
+    return torch.ops.bitsandbytes.int8_vectorwise_dequant(A, stats)
 
 
 def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
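With the bodies moved behind torch.library, the public functional wrappers and the raw ops become interchangeable, and backend selection comes from the op registry rather than Python branching. A hedged sketch of the equivalence (variable names here are illustrative; assumes a CUDA device):

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(4, 16, dtype=torch.float16, device="cuda")

    # Both routes resolve to the kernel registered for A.device:
    out_wrapper = F.int8_double_quant(A, threshold=6.0)
    out_op = torch.ops.bitsandbytes.int8_double_quant(A, threshold=6.0)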

tests/test_ops.py

Lines changed: 65 additions & 0 deletions

@@ -0,0 +1,65 @@
+import pytest
+import torch
+
+import bitsandbytes  # noqa: F401
+
+
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+def test_int8_linear_matmul(device):
+    A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device)
+    B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device)
+    out = torch.ops.bitsandbytes.int8_linear_matmul(A, B)
+
+    assert out.shape == (10, 30)
+    assert out.dtype == torch.int32
+    assert out.device == A.device
+
+    torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul, (A, B))
+
+
+@pytest.mark.parametrize("threshold", [0.0, 6.0])
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+def test_int8_vectorwise_quant(threshold, device):
+    if device == "cpu":
+        pytest.skip("CPU implementation is not available")
+
+    A = torch.randn(10, 20, dtype=torch.float16, device=device)
+    A[1][0] = 1000.0
+
+    out_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant(A, threshold=threshold)
+
+    assert out_row.shape == (10, 20)
+    assert out_row.dtype == torch.int8
+    assert out_row.device == A.device
+    assert row_stats.shape == (10,)
+    assert row_stats.dtype == torch.float32
+    assert row_stats.device == A.device
+
+    if threshold > 0.0:
+        assert outlier_cols is not None
+        assert outlier_cols.dim() == 1
+        assert outlier_cols.shape[0] <= A.shape[1]
+        assert outlier_cols.device == A.device
+    else:
+        assert outlier_cols is None
+
+    torch.library.opcheck(torch.ops.bitsandbytes.int8_vectorwise_quant, (A,))
+
+    torch.library.opcheck(torch.ops.bitsandbytes.int8_vectorwise_quant, (A, threshold))
+
+
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+def test_int8_mm_dequant(device):
+    if device == "cpu":
+        pytest.skip("CPU implementation is not available")
+
+    A = torch.randint(-128, 127, (10, 20), dtype=torch.int32, device=device)
+    row_stats = torch.randn(10, dtype=torch.float16, device=device)
+    col_stats = torch.randn(20, dtype=torch.float16, device=device)
+    out = torch.ops.bitsandbytes.int8_mm_dequant(A, row_stats, col_stats)
+
+    assert out.shape == A.shape
+    assert out.dtype == torch.float16
+    assert out.device == A.device
+
+    torch.library.opcheck(torch.ops.bitsandbytes.int8_mm_dequant, (A, row_stats, col_stats))
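The threshold=6.0 case exercises the LLM.int8() outlier convention: any column holding a |value| >= threshold is reported in outlier_cols and zeroed in the int8 output so it can be handled separately in fp16. A small CUDA illustration of what the test asserts (the planted 1000.0 makes column 0 the lone outlier with overwhelming probability, since P(|randn| >= 6) is negligible):

    import torch
    import bitsandbytes  # noqa: F401

    A = torch.randn(10, 20, dtype=torch.float16, device="cuda")
    A[1, 0] = 1000.0  # plant one outlier in column 0

    out_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant(A, threshold=6.0)

    assert outlier_cols.tolist() == [0]   # column 0 flagged
    assert (out_row[:, 0] == 0).all()     # and zeroed across all rows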
