Commit 04482ff

Initial int8 op registration
1 parent d5df4c6 commit 04482ff

3 files changed: +67 −241 lines changed


bitsandbytes/__init__.py

Lines changed: 4 additions & 1 deletion

@@ -3,7 +3,8 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-from . import research, utils
+
+from . import _ops, research, utils
 from .autograd._functions import (
     MatmulLtState,
     bmm_cublas,
@@ -12,6 +13,8 @@
     matmul_cublas,
     mm_cublas,
 )
+from .backends.cpu import ops as cpu_ops
+from .backends.cuda import ops as cuda_ops  ## TODO: We would guard this for CUDA only
 from .nn import modules
 from .optim import adam
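
Note: importing `bitsandbytes._ops` declares the operator schemas and their fake (meta) implementations, while importing the backend modules registers concrete kernels as an import side effect; the `cuda_ops` import still carries a TODO to guard it for CUDA-only environments. Below is a minimal sketch of what such a backend module might contain — the contents are an assumption for illustration (the backend files are not part of this three-file diff), using PyTorch's `torch.library.impl` for the op that `_ops.py` defines:

    # Hypothetical sketch of a backend module such as bitsandbytes/backends/cpu/ops.py
    # (not shown in this diff). Importing it registers a concrete CPU kernel for the
    # operator that bitsandbytes/_ops.py only defines.
    from typing import Optional

    import torch


    @torch.library.impl("bitsandbytes::int8_linear_matmul", "cpu")
    def _int8_linear_matmul_cpu(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32):
        # Naive reference kernel: upcast to fp32, compute A @ B.T, cast to the accumulator dtype.
        result = torch.matmul(A.float(), B.float().t()).to(dtype)
        if out is not None:
            result = out.copy_(result)
        return result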

bitsandbytes/_ops.py

Lines changed: 60 additions & 97 deletions

@@ -1,12 +1,8 @@
-import ctypes as ct
 from math import prod
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 
-from .cextension import lib
-from .functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr, is_on_gpu
-
 _IS_TORCH_GTE_24 = False
 
 if hasattr(torch.library, "register_fake"):
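
Note: the hunk ends at the `hasattr` guard, so the shim body is not shown. A plausible sketch of what it resolves to — an assumption, since `torch.library.register_fake`/`register_kernel` appear around PyTorch 2.4 while older releases expose `impl_abstract`/`impl`:

    # Assumed shape of the version shim (the guarded body is outside this hunk).
    import torch

    _IS_TORCH_GTE_24 = False

    if hasattr(torch.library, "register_fake"):
        _IS_TORCH_GTE_24 = True
        register_fake = torch.library.register_fake
        register_kernel = torch.library.register_kernel
    else:
        register_fake = torch.library.impl_abstract
        register_kernel = torch.library.impl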
@@ -27,11 +23,10 @@
 # return () instead of `None` for compatibility, see here: https://github.com/pytorch/pytorch/issues/125044
 torch.library.define(
     "bitsandbytes::int8_linear_matmul",
-    "(Tensor A, Tensor B, Tensor(a!)? out=None, ScalarType dtype=int32) -> Tensor(a!)",
+    "(Tensor A, Tensor B, Tensor? out=None, ScalarType dtype=int32) -> Tensor",
 )
 
 
-# Fake/abstract op
 @register_fake("bitsandbytes::int8_linear_matmul")
 def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32):
     shapeC = (*A.shape[:-1], B.shape[0])
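
Note: two things happen in this hunk. The schema drops the `Tensor(a!)` aliasing annotation, so `int8_linear_matmul` is registered as a purely functional op, and the `@register_fake` implementation gives the dispatcher a shape/dtype-only version for tracing. A small sketch of exercising that fake implementation on meta tensors, assuming `import bitsandbytes` performs the registrations as the `__init__.py` hunk above indicates:

    # Sketch: with the op defined and a fake impl registered, shape/dtype inference
    # runs without any real kernel (e.g. on "meta" tensors or under torch.compile).
    import torch
    import bitsandbytes  # noqa: F401  (importing performs the op registrations)

    A = torch.empty(16, 64, dtype=torch.int8, device="meta")  # activations
    B = torch.empty(32, 64, dtype=torch.int8, device="meta")  # weights (out_features, in_features)

    C = torch.ops.bitsandbytes.int8_linear_matmul(A, B)
    # Per the fake impl: shapeC = (*A.shape[:-1], B.shape[0]), default dtype int32.
    print(C.shape, C.dtype)  # expected: torch.Size([16, 32]) torch.int32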
@@ -40,103 +35,71 @@ def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtyp
     return out
 
 
-# CPU implementation
-@register_kernel("bitsandbytes::int8_linear_matmul", "cpu")
-def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32):
-    # Naive implementation: perform matmul in fp32
-    result = torch.matmul(A.float(), B.float().t()).to(torch.int32)
-    if out is not None:
-        result = out.copy_(result)
-    return result
+torch.library.define(
+    "bitsandbytes::int8_vectorwise_quant",
+    "(Tensor A, Scalar threshold=0.0) -> (Tensor, Tensor, Tensor?)",
+)
 
 
-# MPS impl
-@register_kernel("bitsandbytes::int8_linear_matmul", "mps")
-def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32):
-    pass
+@register_fake("bitsandbytes::int8_vectorwise_quant")
+def _(A: torch.Tensor, threshold=0.0):
+    out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8)
+    row_stats = torch.empty(prod(A.shape[:-1]), device=A.device, dtype=torch.float32)
 
+    if threshold == 0.0:
+        return out_row, row_stats, None
 
-# XPU impl
-@register_kernel("bitsandbytes::int8_linear_matmul", "xpu")
-def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32):
-    pass
+    outlier_cols = torch.library.get_ctx().new_dynamic_size()
 
+    return out_row, row_stats, A.new_empty(outlier_cols, dtype=torch.int64)
 
-# Ascend NPU impl
-@register_kernel("bitsandbytes::int8_linear_matmul", "npu")
-def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32):
-    pass
 
+torch.library.define("bitsandbytes::int8_vectorwise_dequant", "(Tensor A, Tensor stats) -> Tensor")
 
-# CUDA/ROCm impl
-@register_kernel("bitsandbytes::int8_linear_matmul", "cuda")
-def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32):
-    A, B = B, A
-
-    shapeA = A.shape
-    shapeB = B.shape
-
-    assert A.dtype == torch.int8
-    assert B.dtype == torch.int8
-    assert A.ndim == 2, "Only two dimensional matrices are supported for argument B"
-    assert B.ndim in [2, 3], "Only two or three dimensional matrices are supported for argument A"
-    assert prod(shapeB) > 0, f"Input tensor dimensions need to be > 0: {shapeB}"
-    assert out is None or out.dtype == dtype
-
-    shapeC = (*shapeB[:-1], shapeA[0])
-
-    k, m = shapeA
-    n = prod(shapeB[:-1])
-    lda = shapeA[-1]  # Weights (outputs, inputs)
-    ldb = shapeB[-1]  # Activations (batch, tokens, inputs)
-    ldc = shapeC[-1]  # Output (batch, tokens, outputs)
-
-    assert (
-        lda == ldb
-    ), f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}"
-
-    # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4.
-    # We'll fall back to a slower fp32 calculation in this circumstance.
-    # Fortunately, this should not be very common.
-    if lda % 4 != 0:
-        result = torch.matmul(B.float(), A.float().t()).to(torch.int32)
-        if out is not None:
-            result = out.copy_(result)
-        return result
 
-    if out is None:
-        out = torch.empty(shapeC, device=A.device, dtype=dtype)
-
-    is_on_gpu([A, B, out])
-
-    with _cuda_device_of(A):
-        ctx = CUBLAS_Context.get_instance().get_context(A.device)
-        ptrA = get_ptr(A)
-        ptrB = get_ptr(B)
-        ptrC = get_ptr(out)
-        ptrRowScale = None
-        m = ct.c_int32(m)
-        n = ct.c_int32(n)
-        k = ct.c_int32(k)
-        lda = ct.c_int32(lda)
-        ldb = ct.c_int32(ldb)
-        ldc = ct.c_int32(ldc)
-        stream = _get_tensor_stream(A)
-
-        if dtype == torch.int32:
-            has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream)
-        else:
-            has_error = lib.cigemmlt_8(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream)
-
-        if has_error == 100:  # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`
-            raise NotImplementedError("int8_linear_matmul not implemented!")
-
-        if has_error:
-            raise RuntimeError(
-                f"cublasLt ran into an error!\n"
-                f"\t{shapeA=}, {shapeB=}, {shapeC=}\n"
-                f"\t{(lda, ldb, ldc)=}\n"
-                f"\t{(m, n, k)=}"
-            )
+@register_fake("bitsandbytes::int8_vectorwise_dequant")
+def _(A: torch.Tensor, stats: torch.Tensor) -> torch.Tensor:
+    torch._check(A.dtype == torch.int8, "A must be int8")
+    return torch.empty_like(A, dtype=torch.float32)
 
-    return out
+
+torch.library.define(
+    "bitsandbytes::int8_mm_dequant",
+    "(Tensor A, Tensor row_stats, Tensor col_stats, Tensor? out, Tensor? bias) -> Tensor",
+)
+
+
+@register_fake("bitsandbytes::int8_mm_dequant")
+def _(
+    A: torch.Tensor,
+    row_stats: torch.Tensor,
+    col_stats: torch.Tensor,
+    out: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    torch._check(A.dtype == torch.int32, "A must be int32")
+    return torch.empty_like(A, dtype=torch.float16)
+
+
+torch.library.define(
+    "bitsandbytes::int8_double_quant",
+    "(Tensor A, Tensor? col_stats, Tensor? row_stats, Tensor? out_col, Tensor? out_row, Scalar threshold=0.0) -> (Tensor, Tensor, Tensor, Tensor, Tensor?)",
+)
+
+
+@register_fake("bitsandbytes::int8_double_quant")
+def _(
+    A: torch.Tensor,
+    col_stats: Optional[torch.Tensor] = None,
+    row_stats: Optional[torch.Tensor] = None,
+    out_col: Optional[torch.Tensor] = None,
+    out_row: Optional[torch.Tensor] = None,
+    threshold=0.0,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    out_row = torch.empty_like(A, dtype=torch.int8)
+    out_col = torch.empty_like(A, dtype=torch.int8)
+    row_stats = torch.empty(prod(A.shape[:-1]), device=A.device, dtype=torch.float32)
+    col_stats = torch.empty(A.shape[-1], device=A.device, dtype=torch.float32)
+    outlier_n = torch.library.get_ctx().new_dynamic_size()
+    outlier_cols = A.new_empty(outlier_n, dtype=torch.int64)
+    return out_row, out_col, row_stats, col_stats, outlier_cols
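
Note: the new `int8_vectorwise_quant`, `int8_vectorwise_dequant`, `int8_mm_dequant`, and `int8_double_quant` entries only get schemas and fake implementations here; real kernels live in the backend modules. The data-dependent count of outlier columns is modeled with `torch.library.get_ctx().new_dynamic_size()` so symbolic tracing does not specialize on it. For orientation, a reference sketch of the row-wise absmax math these schemas describe — an assumption based on the LLM.int8() scheme, not the registered kernels:

    # Reference sketch of row-wise absmax int8 quantization / dequantization
    # (assumed LLM.int8() scaling; not the registered kernels).
    import torch


    def int8_vectorwise_quant_ref(A: torch.Tensor, threshold: float = 0.0):
        row_stats = A.float().abs().amax(dim=-1)  # per-row absmax
        out_row = torch.round(A.float() * (127.0 / row_stats.unsqueeze(-1))).to(torch.int8)
        outlier_cols = None
        if threshold > 0.0 and (A.abs() >= threshold).any():
            outlier_cols = torch.argwhere((A.abs() >= threshold).any(dim=0)).view(-1)
        return out_row, row_stats, outlier_cols


    def int8_vectorwise_dequant_ref(A: torch.Tensor, stats: torch.Tensor) -> torch.Tensor:
        # Undo the row-wise scaling: int8 codes back to float32.
        return A.float() * stats.unsqueeze(-1) / 127.0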

bitsandbytes/functional.py

Lines changed: 3 additions & 143 deletions

@@ -2291,88 +2291,7 @@ def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Ten
     Returns:
         `torch.Tensor`: The result of the operation.
     """
-
-    #
-    # To use the IMMA tensor core kernels without special Turing/Ampere layouts,
-    # cublasLt has some rules, namely: A must be transposed, B must not be transposed.
-    # The C++ API will calculate `C = A.T @ B` in with A, B, C in col-major.
-    # This will typically be used with row-major tensors to efficiently
-    # calculate the linear layer with `C = B @ A.T` without any transformations.
-    # We will swap A and B in the API invocation, so that we get `C = A @ B.T`.
-    #
-    # Quick explanation:
-    # With row-major A and B tensors, `C = A.T.T @ B.T = A @ B.T`.
-    # To get row-major output, `C.T = (A @ B.T).T = B @ A.T`.
-    #
-    A, B = B, A
-
-    shapeA = A.shape
-    shapeB = B.shape
-
-    assert A.dtype == torch.int8
-    assert B.dtype == torch.int8
-    assert A.ndim == 2, "Only two dimensional matrices are supported for argument B"
-    assert B.ndim in [2, 3], "Only two or three dimensional matrices are supported for argument A"
-    assert prod(shapeB) > 0, f"Input tensor dimensions need to be > 0: {shapeB}"
-    assert out is None or out.dtype == dtype
-
-    shapeC = (*shapeB[:-1], shapeA[0])
-
-    k, m = shapeA
-    n = prod(shapeB[:-1])
-    lda = shapeA[-1]  # Weights (outputs, inputs)
-    ldb = shapeB[-1]  # Activations (batch, tokens, inputs)
-    ldc = shapeC[-1]  # Output (batch, tokens, outputs)
-
-    assert (
-        lda == ldb
-    ), f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}"
-
-    # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4.
-    # We'll fall back to a slower fp32 calculation in this circumstance.
-    # Fortunately, this should not be very common.
-    if lda % 4 != 0:
-        result = torch.matmul(B.float(), A.float().t()).to(torch.int32)
-        if out is not None:
-            result = out.copy_(result)
-        return result
-
-    if out is None:
-        out = torch.empty(shapeC, device=A.device, dtype=dtype)
-
-    is_on_gpu([A, B, out])
-
-    with _cuda_device_of(A):
-        ctx = CUBLAS_Context.get_instance().get_context(A.device)
-        ptrA = get_ptr(A)
-        ptrB = get_ptr(B)
-        ptrC = get_ptr(out)
-        ptrRowScale = None
-        m = ct.c_int32(m)
-        n = ct.c_int32(n)
-        k = ct.c_int32(k)
-        lda = ct.c_int32(lda)
-        ldb = ct.c_int32(ldb)
-        ldc = ct.c_int32(ldc)
-        stream = _get_tensor_stream(A)
-
-        if dtype == torch.int32:
-            has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream)
-        else:
-            has_error = lib.cigemmlt_8(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream)
-
-        if has_error == 100:  # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`
-            raise NotImplementedError("int8_linear_matmul not implemented!")
-
-        if has_error:
-            raise RuntimeError(
-                f"cublasLt ran into an error!\n"
-                f"\t{shapeA=}, {shapeB=}, {shapeC=}\n"
-                f"\t{(lda, ldb, ldc)=}\n"
-                f"\t{(m, n, k)=}"
-            )
-
-    return out
+    return torch.ops.bitsandbytes.int8_linear_matmul(A, B, out, dtype)
 
 
 def int8_mm_dequant(
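
Note: the wrapper keeps its signature and docstring; only the body changes, so existing callers now go through the dispatcher and pick up whichever device kernel is registered. A usage sketch, assuming a CUDA device with the CUDA backend registered:

    # Usage sketch: the public wrapper now forwards to
    # torch.ops.bitsandbytes.int8_linear_matmul and the dispatcher picks the kernel.
    import torch
    import bitsandbytes.functional as F

    A = torch.randint(-128, 128, (16, 64), dtype=torch.int8, device="cuda")  # activations
    B = torch.randint(-128, 128, (32, 64), dtype=torch.int8, device="cuda")  # weights

    C = F.int8_linear_matmul(A, B)  # C = A @ B.T with an int32 accumulator
    assert C.shape == (16, 32) and C.dtype == torch.int32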
@@ -2394,31 +2313,7 @@ def int8_mm_dequant(
     Returns:
         `torch.Tensor`: The dequantized result with an optional bias, with dtype `torch.float16`.
     """
-
-    assert A.dtype == torch.int32
-
-    if bias is not None:
-        assert bias.dtype == torch.float16
-
-    if out is None:
-        out = torch.empty_like(A, dtype=torch.float16)
-
-    ptrA = get_ptr(A)
-    ptrOut = get_ptr(out)
-    ptrRowStats = get_ptr(row_stats)
-    ptrColStats = get_ptr(col_stats)
-    ptrBias = get_ptr(bias)
-    numRows = ct.c_int32(prod(A.shape[:-1]))
-    numCols = ct.c_int32(A.shape[-1])
-
-    is_on_gpu([A, row_stats, col_stats, out, bias])
-
-    with _cuda_device_of(A):
-        lib.cdequant_mm_int32_fp16(
-            ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A)
-        )
-
-    return out
+    return torch.ops.bitsandbytes.int8_mm_dequant(A, row_stats, col_stats, out, bias)
 
 
 @deprecated("mm_dequant is deprecated. Please use int8_mm_dequant() instead.", category=FutureWarning)
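
Note: `int8_mm_dequant` likewise becomes a thin wrapper; the `cdequant_mm_int32_fp16` kernel it used to call directly is now reached through the registered op. Roughly, the dequantization rescales the int32 accumulator by the row and column scales over 127² and adds the optional bias — a hedged reference of that math, assumed from the LLM.int8() scheme rather than taken from this diff:

    # Hedged reference of the dequant math (assumed LLM.int8() scaling, fp16 output);
    # the real path dispatches to the registered bitsandbytes::int8_mm_dequant kernel.
    from typing import Optional

    import torch


    def int8_mm_dequant_ref(
        A: torch.Tensor,          # int32 accumulator from int8_linear_matmul
        row_stats: torch.Tensor,  # per-row absmax of the activations
        col_stats: torch.Tensor,  # per-row absmax of the weights (output columns)
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        out = A.float() * row_stats.view(-1, 1) * col_stats.view(1, -1) / (127.0 * 127.0)
        if bias is not None:
            out += bias.float()
        return out.to(torch.float16)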
@@ -2766,42 +2661,7 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
         - `torch.Tensor` with dtype `torch.float32`: The quantization scales.
         - `torch.Tensor` with dtype `torch.int32`, *optional*: A list of column indices which contain outlier features.
     """
-
-    assert A.dtype == torch.half
-    is_on_gpu([A])
-
-    rows = prod(A.shape[:-1])
-    cols = A.shape[-1]
-
-    row_stats = torch.empty(rows, device=A.device, dtype=torch.float32)
-    out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8)
-
-    outlier_cols = None
-
-    if threshold > 0.0:
-        # TODO we could improve perf of this
-        outliers = A.abs() >= threshold
-
-        if outliers.any():
-            outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1)
-
-    with _cuda_device_of(A):
-        lib.cint8_vector_quant(
-            get_ptr(A),
-            get_ptr(out_row),
-            get_ptr(row_stats),
-            ct.c_float(threshold),
-            ct.c_int32(rows),
-            ct.c_int32(cols),
-            _get_tensor_stream(A),
-        )
-
-    # Zero out values from outlier columns across all rows.
-    # The kernel will handle this for outliers themselves, so we can optimize for rows=1.
-    if rows > 1 and outlier_cols is not None:
-        out_row[:, outlier_cols] = 0
-
-    return out_row, row_stats, outlier_cols
+    return torch.ops.bitsandbytes.int8_vectorwise_quant(A, threshold)
 
 
 @deprecated(
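
Note: putting the three wrappers together, the int8 inference path is: quantize activations and weights row-wise, run the int8 GEMM, then dequantize the int32 accumulator. An end-to-end sketch, assuming a CUDA device and fp16 inputs (as the removed `assert A.dtype == torch.half` suggests):

    # End-to-end sketch of the int8 matmul path exposed by these wrappers.
    import torch
    import bitsandbytes.functional as F

    x = torch.randn(16, 64, dtype=torch.float16, device="cuda")  # activations
    w = torch.randn(32, 64, dtype=torch.float16, device="cuda")  # weights (out_features, in_features)

    x_i8, x_stats, _ = F.int8_vectorwise_quant(x)  # row-wise absmax quantization
    w_i8, w_stats, _ = F.int8_vectorwise_quant(w)

    acc = F.int8_linear_matmul(x_i8, w_i8)        # int32 accumulator, shape (16, 32)
    y = F.int8_mm_dequant(acc, x_stats, w_stats)  # back to fp16
    assert y.shape == (16, 32) and y.dtype == torch.float16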
