
Commit 10b9d4c

Add simple op implementations for CPU (#1602)
* Additional 4bit CPU ops
* Additional 4bit CPU ops
* Implement additional device-agnostic ops and test updates
* More test fixes
* int8 tests passing
* Fix feature flag for multi_backend
1 parent b7e60ca commit 10b9d4c

12 files changed (+237, -93 lines)


bitsandbytes/__init__.py
Lines changed: 1 addition & 1 deletion

@@ -21,7 +21,7 @@
 
 # This is a signal for integrations with transformers/diffusers.
 # Eventually we may remove this but it is currently required for compatibility.
-features = {"multi-backend"}
+features = {"multi_backend"}
 supported_torch_devices = {
     "cpu",
     "cuda",  # NVIDIA/AMD GPU

bitsandbytes/backends/cpu/ops.py
Lines changed: 75 additions & 0 deletions

@@ -1,3 +1,4 @@
+from collections.abc import Sequence
 import ctypes as ct
 from typing import Optional
 
@@ -119,6 +120,10 @@ def _(
 ) -> tuple[torch.Tensor, torch.Tensor]:
     torch._check_is_size(blocksize)
     torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
+    torch._check(
+        A.dtype in [torch.bfloat16, torch.float16, torch.float32],
+        lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
+    )
 
     n = A.numel()
 
@@ -140,3 +145,73 @@ def _(
     packed = packed.squeeze().view(quant_storage).unsqueeze(1)
 
     return packed, absmax.float()
+
+
+@register_kernel("bitsandbytes::dequantize_4bit", "cpu")
+def _(
+    A: torch.Tensor,
+    absmax: torch.Tensor,
+    blocksize: int,
+    quant_type: str,
+    shape: Sequence[int],
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    torch._check_is_size(blocksize)
+    torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
+    torch._check(
+        dtype in [torch.bfloat16, torch.float16, torch.float32],
+        lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
+    )
+    torch._check(
+        A.dtype == torch.uint8,
+        lambda: f"Blockwise 4bit dequantization on CPU only supports uint8 storage, got {A.dtype}",
+    )
+
+    A = A.view(-1, 1)
+
+    # Grab upper and lower nibbles. Using int64 for indexing in the LUT.
+    upper = (A >> 4).to(torch.int64)
+    lower = (A & 0x0F).to(torch.int64)
+
+    # Expand to blocks
+    blocks = torch.cat((upper, lower), dim=1).reshape(-1, blocksize)
+
+    # Dequantize
+    blocks = _NF4_QUANT_TABLE[blocks] * absmax[:, None]
+
+    # Reshape to original shape
+    blocks = blocks.reshape(-1, *shape[1:])
+
+    return blocks.to(dtype)
+
+
+@register_kernel("bitsandbytes::gemv_4bit", "cpu")
+def _(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    shapeB: Sequence[int],
+    absmax: torch.Tensor,
+    code: torch.Tensor,
+    blocksize: int,
+) -> torch.Tensor:
+    # TODO: We need to determine whether `code` is NF4, FP4, or other.
+    # Right now we assume NF4, as this is the only one supported on CPU.
+
+    B_dq = torch.ops.bitsandbytes.dequantize_4bit.default(
+        B,
+        absmax,
+        blocksize,
+        "nf4",
+        shape=shapeB,
+        dtype=A.dtype,
+    )
+
+    # User called gemv with B.t(), so we need to transpose it back.
+    # if B.shape[0] == 1:
+    #     B_dq = B_dq.t()
+
+    return torch.nn.functional.linear(
+        A,
+        B_dq,
+        bias=None,
+    )
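
For context, a minimal round-trip sketch of how these CPU kernels are reached through the public API. It assumes F.quantize_4bit / F.dequantize_4bit dispatch to the ops registered above, as the updated tests in this commit do; the sizes are illustrative.

# Hedged sketch: NF4 round trip on CPU. Only quant_type="nf4" and uint8
# storage pass the torch._check guards added above.
import torch
import bitsandbytes.functional as F

A = torch.randn(1024, 1024, dtype=torch.float32, device="cpu")

packed, quant_state = F.quantize_4bit(A, blocksize=64, quant_type="nf4")
A_dq = F.dequantize_4bit(packed, quant_state, blocksize=64, quant_type="nf4")

# Blockwise NF4 is lossy; expect a small mean absolute error.
print((A - A_dq).abs().mean())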

bitsandbytes/backends/cuda/ops.py
Lines changed: 0 additions & 39 deletions

@@ -22,45 +22,6 @@ def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
     _int8_linear_matmul_impl(A, B, out)
 
 
-@register_kernel("bitsandbytes::int8_mixed_scaled_mm", "cuda")
-def _(
-    A: torch.Tensor,
-    CA: torch.Tensor,
-    CB: torch.Tensor,
-    SCA: torch.Tensor,
-    SCB: torch.Tensor,
-    outlier_cols: Optional[torch.Tensor] = None,
-    bias: Optional[torch.Tensor] = None,
-) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-    subB = None
-
-    if outlier_cols is not None and outlier_cols.numel():
-        # Extract the inputs with outliers in original precision
-        subA = A[:, outlier_cols].contiguous()
-
-        # Dequantize the corresponding weight columns
-        subB = (
-            torch.ops.bitsandbytes.int8_vectorwise_dequant.default(CB[:, outlier_cols].contiguous(), SCB)
-            .to(A.dtype)
-            .t()
-        )
-
-        # TODO: if state.has_fp16_weights: subB = B[:, outlier_cols].t()
-
-    else:
-        # Needed for torch.compile when there are no outliers.
-        subA = torch.empty(0, device=A.device, dtype=A.dtype)
-
-    # Int8 Matmul + Dequant + Bias
-    output = torch.ops.bitsandbytes.int8_scaled_mm.default(CA, CB, SCA, SCB, bias=bias, dtype=A.dtype)
-
-    if subB is not None:
-        # Add the outlier columns back to the output
-        output = output.addmm(subA, subB)
-
-    return output, subA
-
-
 def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
     A, B = B, A

bitsandbytes/backends/default/ops.py
Lines changed: 78 additions & 0 deletions

@@ -1,10 +1,50 @@
+from math import prod
 from typing import Optional
 
 import torch
 
 from ..._ops import register_kernel
 
 
+@register_kernel("bitsandbytes::int8_mixed_scaled_mm", "default")
+def _(
+    A: torch.Tensor,
+    CA: torch.Tensor,
+    CB: torch.Tensor,
+    SCA: torch.Tensor,
+    SCB: torch.Tensor,
+    outlier_cols: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+    subB = None
+
+    if outlier_cols is not None and outlier_cols.numel():
+        # Extract the inputs with outliers in original precision
+        subA = A[:, outlier_cols].contiguous()
+
+        # Dequantize the corresponding weight columns
+        subB = (
+            torch.ops.bitsandbytes.int8_vectorwise_dequant.default(CB[:, outlier_cols].contiguous(), SCB)
+            .to(A.dtype)
+            .t()
+        )
+
+        # TODO: if state.has_fp16_weights: subB = B[:, outlier_cols].t()
+
+    else:
+        # Needed for torch.compile when there are no outliers.
+        subA = torch.empty(0, device=A.device, dtype=A.dtype)
+
+    # Int8 Matmul + Dequant + Bias
+    output = torch.ops.bitsandbytes.int8_scaled_mm.default(CA, CB, SCA, SCB, bias=bias, dtype=A.dtype)
+
+    if subB is not None:
+        # Add the outlier columns back to the output
+        output = output.addmm(subA, subB)
+
+    return output, subA
+
+
 @register_kernel("bitsandbytes::int8_scaled_mm", "default")
 def _(
     A: torch.Tensor,
@@ -41,3 +81,41 @@ def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: Optional[tor
     if out is not None:
         result = out.copy_(result)
     return result
+
+
+@register_kernel("bitsandbytes::int8_vectorwise_quant", "default")
+def _(A: torch.Tensor, threshold=0.0):
+    rows = prod(A.shape[:-1])
+    outlier_cols = None
+
+    outlier_restore = None
+
+    if threshold > 0.0:
+        outliers = A.abs() >= threshold
+
+        if outliers.any():
+            # Determine which columns contain outliers, and zero out the
+            # outliers ahead of quantization. We need to keep a backup of these
+            # outliers to restore them after quantization.
+            outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1)
+            outlier_restore = A[outliers].clone()
+            A[outliers] = 0
+        else:
+            # Needed for torch.compile support.
+            outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64)
+
+    # Get absmax for each row.
+    row_stats = torch.max(A.abs(), dim=1).values.float()
+
+    # Quantize row-wise to int8.
+    out_row = torch.round(A * (127.0 / row_stats.unsqueeze(-1))).to(torch.int8)
+
+    # Zero out values from outlier columns across all rows.
+    if rows > 1 and outlier_cols is not None:
+        out_row[:, outlier_cols] = 0
+
+    # Restore outliers.
+    if outlier_restore is not None:
+        A[outliers] = outlier_restore
+
+    return out_row, row_stats, outlier_cols
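
As a quick illustration of the new device-agnostic int8_vectorwise_quant kernel, a hedged sketch through the public wrapper the tests in this commit use (F.int8_vectorwise_quant); the threshold value is arbitrary and the outlier handling mirrors the code above.

# Hedged sketch: row-wise int8 quantization with outlier tracking, on CPU.
import torch
import bitsandbytes.functional as F

A = torch.randn(4, 8, dtype=torch.float16, device="cpu")
A[1, 3] = 12.0  # exceeds the threshold, so column 3 is reported as an outlier

CA, row_stats, outlier_cols = F.int8_vectorwise_quant(A, threshold=6.0)

print(CA.dtype)         # torch.int8, with outlier columns zeroed
print(row_stats.shape)  # one absmax per row
print(outlier_cols)     # tensor([3])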

bitsandbytes/functional.py
Lines changed: 1 addition & 1 deletion

@@ -779,7 +779,7 @@ def quantize_blockwise(
             state2=state2,
         )
     else:
-        quant_state = QuantState(absmax=_absmax, code=code, blocksize=blocksize, dtype=A.dtype)
+        quant_state = QuantState(absmax=_absmax, code=code.to(A.device), blocksize=blocksize, dtype=A.dtype)
 
     # TODO(matthewdouglas): Deprecate out kwarg
     out = out.copy_(_out) if out is not None else _out
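
The one-line change above matters for non-CUDA inputs: the quantization code map now follows the input tensor's device. A hedged sanity check, assuming quantize_blockwise returns a (tensor, QuantState) pair as in the surrounding code:

# Hedged sketch: after this fix, quant_state.code lives on the same device as A.
import torch
import bitsandbytes.functional as F

A = torch.randn(4096, dtype=torch.float32, device="cpu")
quantized, quant_state = F.quantize_blockwise(A)

assert quant_state.code.device == A.device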

bitsandbytes/nn/modules.py
Lines changed: 20 additions & 11 deletions

@@ -592,19 +592,28 @@ def __new__(
         obj.has_fp16_weights = has_fp16_weights
         return obj
 
-    def cuda(self, device):
+    def _quantize(self, device):
         if self.has_fp16_weights:
-            return super().cuda(device)
-        else:
-            # We quantize the weight and store in 8bit row-major
-            B = self.data.contiguous().half().cuda(device)
-            CB, SCB, _ = bnb.functional.int8_vectorwise_quant(B)
-            self.data = CB
-            self.CB = CB
-            self.SCB = SCB
+            return super().to(device)
+
+        # We quantize the weight and store in 8bit row-major
+        B = self.data.contiguous().to(device=device, dtype=torch.float16)
+        CB, SCB, _ = bnb.functional.int8_vectorwise_quant(B)
+        self.data = CB
+        self.CB = CB
+        self.SCB = SCB
 
         return self
 
+    def cpu(self):
+        return self.to(device="cpu")
+
+    def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
+        return self.to(device="cuda" if device is None else device, non_blocking=non_blocking)
+
+    def xpu(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
+        return self.to(device="xpu" if device is None else device, non_blocking=non_blocking)
+
     def __deepcopy__(self, memo):
         # adjust this if new arguments are added to the constructor
         new_instance = type(self).__new__(
@@ -634,8 +643,8 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
     def to(self, *args, **kwargs):
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
 
-        if device is not None and device.type == "cuda" and self.data.device.type == "cpu":
-            return self.cuda(device)
+        if device is not None and device.type != "meta" and self.data.device.type == "cpu":
+            return self._quantize(device)
         else:
             new_param = Int8Params(
                 super().to(device=device, dtype=dtype, non_blocking=non_blocking),
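
A hedged usage sketch of the new behavior: with the device-agnostic _quantize path, moving an Int8Params-backed module to any non-meta device (not only CUDA) now triggers int8 quantization. This assumes Linear8bitLt wraps its weight in Int8Params with has_fp16_weights=False, as elsewhere in the library.

# Hedged sketch: int8 quantization now also happens on .to("cpu"), not just .cuda().
import torch
import bitsandbytes as bnb

linear = bnb.nn.Linear8bitLt(64, 64, has_fp16_weights=False)
linear = linear.to("cpu")  # previously only a CUDA move quantized the weight

print(linear.weight.dtype)      # torch.int8 (CB becomes the parameter data)
print(linear.weight.SCB.shape)  # per-output-row scale statistics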

tests/test_autograd.py
Lines changed: 10 additions & 4 deletions

@@ -32,9 +32,15 @@
 def test_matmullt(
     device, dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias
 ):
-    if device != "cuda" and funcs[1] == bnb.research.switchback_bnb:
-        # TODO: Deprecate/remove?
-        pytest.skip("switchback_bnb only works on CUDA.")
+    if device != "cuda":
+        if funcs[1] == bnb.research.switchback_bnb:
+            # TODO: Deprecate/remove?
+            pytest.skip("switchback_bnb only works on CUDA.")
+
+        if req_grad[1]:
+            # This will be deprecated for CUDA in the future. We don't expect
+            # this to work on any other device.
+            pytest.skip("Deprecated feature with CUDA support only.")
 
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
@@ -171,7 +177,7 @@ def test_matmul_4bit(
     quant_type,
 ):
     if device == "cpu" and quant_type == "fp4":
-        pytest.skip("Only nf4 is supported on CPU")
+        pytest.xfail("Only nf4 is supported on CPU")
 
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
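
For reference, the shape of the 4-bit matmul that test_matmul_4bit exercises on CPU; a hedged sketch assuming bnb.matmul_4bit(A, B, quant_state) is the entry point under test, with nf4 the only quant_type expected to succeed on CPU per the xfail above.

# Hedged sketch: NF4 matmul on CPU through the autograd-facing API.
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

A = torch.randn(1, 64, dtype=torch.float16, device="cpu")
B = torch.randn(128, 64, dtype=torch.float16, device="cpu")

qB, state = F.quantize_4bit(B, quant_type="nf4")
out = bnb.matmul_4bit(A, qB.t(), quant_state=state)

print(out.shape)  # expected: torch.Size([1, 128])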

tests/test_functional.py
Lines changed: 22 additions & 3 deletions

@@ -186,7 +186,7 @@ def test_few_bit_quant(self, device, bits, method):
             code = F.create_dynamic_map(True, bits - 0, bits).to(device)
         elif method == "quantile":
             if device != "cuda":
-                pytest.xfail("Quantile map only works on CUDA")
+                pytest.skip("Quantile map only works on CUDA")
             values = torch.randn(2048, 2048, device="cuda")
             code = F.create_quantile_map(values, bits).cuda()
         # for some data types we have no zero
@@ -593,7 +593,7 @@ def test_int8_linear_matmul_half(self, device, dim1, dim2, dim3, dim4, dims):
 
         A = A.view(-1, A.shape[-1])
 
-        CA, _, statsA, _, _ = F.int8_double_quant(A)
+        CA, statsA, _ = F.int8_vectorwise_quant(A)
         CB, statsB, _ = F.int8_vectorwise_quant(B)
         output = F.int8_mm_dequant(F.int8_linear_matmul(CA, CB), statsA, statsB)
 
@@ -1102,6 +1102,9 @@ class TestQuantize4BitFunctional:
     @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
     @pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096])
     def test_4bit_quant(self, device, dtype, quant_type, blocksize):
+        if device == "cpu" and quant_type != "nf4":
+            pytest.xfail("fp4 quantization is not supported on CPU")
+
         A1 = torch.randn(1024, 1024, device=device, dtype=dtype)
         qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
         A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
@@ -1134,6 +1137,9 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
     @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
    @pytest.mark.parametrize("blocksize", [64, 128], ids=id_formatter("blocksize"))
     def test_4bit_compressed_stats(self, device, quant_type, blocksize):
+        if device == "cpu" and quant_type != "nf4":
+            pytest.xfail("fp4 quantization is not supported on CPU")
+
         errs1 = []
         errs2 = []
         for i in range(10):
@@ -1206,6 +1212,12 @@ def test_bench_4bit_dequant(self, quant_type):
     )
     @pytest.mark.parametrize("dim", [128, 256, 512, 1024], ids=id_formatter("dim"))
     def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double_quant, kind):
+        if device == "cpu":
+            if storage_type != "nf4":
+                pytest.xfail("fp4 quantization is not supported on CPU")
+            if quant_storage != torch.uint8:
+                pytest.xfail("Only uint8 storage is supported on CPU")
+
         errs1 = []
         errs2 = []
         errs3 = []
@@ -1216,7 +1228,11 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double
         max_errs2 = []
         max_errs3 = []
 
-        for i in range(100):
+        # Large number of iterations is excessive and slow on CPU.
+        # Keep for CUDA for now.
+        iters = 100 if device == "cuda" else 10
+
+        for i in range(iters):
             if kind == "fc1":
                 A = torch.randn(1, dim, dtype=dtype, device=device)
                 B = torch.randn(dim * 4, dim, dtype=dtype, device=device) / math.sqrt(dim)
@@ -1337,6 +1353,9 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double
     @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
     @pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"])
     def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant):
+        if device == "cpu" and storage_type != "nf4":
+            pytest.xfail("fp4 quantization is not supported on CPU")
+
         dims = 10
         torch.random.manual_seed(np.random.randint(0, 412424242))
         dims = get_test_dims(0, 8192, n=dims)
