Commit 0410ec1

Implement additional device-agnostic ops and test updates
1 parent d8bd0b3 commit 0410ec1

10 files changed (+178, −91 lines)


bitsandbytes/backends/cpu/ops.py

Lines changed: 34 additions & 0 deletions
@@ -167,6 +167,8 @@ def _(
         lambda: f"Blockwise 4bit dequantization on CPU only supports uint8 storage, got {A.dtype}",
     )
 
+    A = A.view(-1, 1)
+
     # Grab upper and lower nibbles. Using int64 for indexing in the LUT.
     upper = (A >> 4).to(torch.int64)
     lower = (A & 0x0F).to(torch.int64)
@@ -181,3 +183,35 @@ def _(
     blocks = blocks.reshape(-1, *shape[1:])
 
     return blocks.to(dtype)
+
+
+@register_kernel("bitsandbytes::gemv_4bit", "cpu")
+def _(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    shapeB: Sequence[int],
+    absmax: torch.Tensor,
+    code: torch.Tensor,
+    blocksize: int,
+) -> torch.Tensor:
+    # TODO: We need to determine whether `code` is NF4, FP4, or other.
+    # Right now we assume NF4, as this is the only one supported on CPU.
+
+    B_dq = torch.ops.bitsandbytes.dequantize_4bit.default(
+        B,
+        absmax,
+        blocksize,
+        "nf4",
+        shape=shapeB,
+        dtype=A.dtype,
+    )
+
+    # User called gemv with B.t(), so we need to transpose it back.
+    # if B.shape[0] == 1:
+    #     B_dq = B_dq.t()
+
+    return torch.nn.functional.linear(
+        A,
+        B_dq,
+        bias=None,
+    )
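
For reference, the new CPU gemv_4bit kernel above simply composes the existing blockwise 4-bit dequantization with a standard torch.nn.functional.linear call. The snippet below is a minimal standalone sketch of the same "unpack two nibbles per byte via a lookup table, then matmul" pattern; toy_dequant_4bit and the linspace stand-in for the code table are illustrative only, not the library's NF4 implementation.

import torch

def toy_dequant_4bit(packed, code, absmax, blocksize, dtype):
    # Each uint8 stores two 4-bit indices: the upper and lower nibble.
    packed = packed.view(-1, 1)
    upper = (packed >> 4).to(torch.int64)
    lower = (packed & 0x0F).to(torch.int64)
    # Look up both nibbles in the code table, interleave, and rescale per block.
    vals = torch.cat([code[upper], code[lower]], dim=1).reshape(-1, blocksize)
    return (vals * absmax.view(-1, 1)).to(dtype)

blocksize, out_f, in_f = 64, 8, 128
code = torch.linspace(-1.0, 1.0, 16)                      # stand-in for the NF4 table
packed = torch.randint(0, 256, (out_f * in_f // 2,), dtype=torch.uint8)
absmax = torch.rand(out_f * in_f // blocksize)
W = toy_dequant_4bit(packed, code, absmax, blocksize, torch.float32).reshape(out_f, in_f)
x = torch.randn(1, in_f)
y = torch.nn.functional.linear(x, W)                      # same "dequantize, then linear" shape as the kernel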

bitsandbytes/backends/cuda/ops.py

Lines changed: 0 additions & 39 deletions
@@ -22,45 +22,6 @@ def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
     _int8_linear_matmul_impl(A, B, out)
 
 
-@register_kernel("bitsandbytes::int8_mixed_scaled_mm", "cuda")
-def _(
-    A: torch.Tensor,
-    CA: torch.Tensor,
-    CB: torch.Tensor,
-    SCA: torch.Tensor,
-    SCB: torch.Tensor,
-    outlier_cols: Optional[torch.Tensor] = None,
-    bias: Optional[torch.Tensor] = None,
-) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-    subB = None
-
-    if outlier_cols is not None and outlier_cols.numel():
-        # Extract the inputs with outliers in original precision
-        subA = A[:, outlier_cols].contiguous()
-
-        # Dequantize the corresponding weight columns
-        subB = (
-            torch.ops.bitsandbytes.int8_vectorwise_dequant.default(CB[:, outlier_cols].contiguous(), SCB)
-            .to(A.dtype)
-            .t()
-        )
-
-        # TODO: if state.has_fp16_weights: subB = B[:, outlier_cols].t()
-
-    else:
-        # Needed for torch.compile when there are no outliers.
-        subA = torch.empty(0, device=A.device, dtype=A.dtype)
-
-    # Int8 Matmul + Dequant + Bias
-    output = torch.ops.bitsandbytes.int8_scaled_mm.default(CA, CB, SCA, SCB, bias=bias, dtype=A.dtype)
-
-    if subB is not None:
-        # Add the outlier columns back to the output
-        output = output.addmm(subA, subB)
-
-    return output, subA
-
-
 def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
     A, B = B, A

bitsandbytes/backends/default/ops.py

Lines changed: 70 additions & 0 deletions
@@ -1,10 +1,50 @@
+from math import prod
 from typing import Optional
 
 import torch
 
 from ..._ops import register_kernel
 
 
+@register_kernel("bitsandbytes::int8_mixed_scaled_mm", "default")
+def _(
+    A: torch.Tensor,
+    CA: torch.Tensor,
+    CB: torch.Tensor,
+    SCA: torch.Tensor,
+    SCB: torch.Tensor,
+    outlier_cols: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+    subB = None
+
+    if outlier_cols is not None and outlier_cols.numel():
+        # Extract the inputs with outliers in original precision
+        subA = A[:, outlier_cols].contiguous()
+
+        # Dequantize the corresponding weight columns
+        subB = (
+            torch.ops.bitsandbytes.int8_vectorwise_dequant.default(CB[:, outlier_cols].contiguous(), SCB)
+            .to(A.dtype)
+            .t()
+        )
+
+        # TODO: if state.has_fp16_weights: subB = B[:, outlier_cols].t()
+
+    else:
+        # Needed for torch.compile when there are no outliers.
+        subA = torch.empty(0, device=A.device, dtype=A.dtype)
+
+    # Int8 Matmul + Dequant + Bias
+    output = torch.ops.bitsandbytes.int8_scaled_mm.default(CA, CB, SCA, SCB, bias=bias, dtype=A.dtype)
+
+    if subB is not None:
+        # Add the outlier columns back to the output
+        output = output.addmm(subA, subB)
+
+    return output, subA
+
+
 @register_kernel("bitsandbytes::int8_scaled_mm", "default")
 def _(
     A: torch.Tensor,
@@ -41,3 +81,33 @@ def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: Optional[tor
     if out is not None:
         result = out.copy_(result)
     return result
+
+
+@register_kernel("bitsandbytes::int8_vectorwise_quant", "default")
+def _(A: torch.Tensor, threshold=0.0):
+    rows = prod(A.shape[:-1])
+    outlier_cols = None
+
+    if threshold > 0.0:
+        outliers = A.abs() >= threshold
+
+        if outliers.any():
+            # Determine which columns contain outliers, and zero out the
+            # outliers ahead of quantization.
+            outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1)
+            A[outliers] = 0
+        else:
+            # Needed for torch.compile support.
+            outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64)
+
+    # Get absmax for each row.
+    row_stats = torch.max(A.abs(), dim=1).values.float()
+
+    # Quantize row-wise to int8.
+    out_row = torch.round(A * (127.0 / row_stats.unsqueeze(-1))).to(torch.int8)
+
+    # Zero out values from outlier columns across all rows.
+    if rows > 1 and outlier_cols is not None:
+        out_row[:, outlier_cols] = 0
+
+    return out_row, row_stats, outlier_cols
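
The default int8_vectorwise_quant kernel above quantizes each row symmetrically to int8 against the row's absolute maximum, optionally pulling out outlier columns whose magnitude exceeds threshold. A minimal round-trip sketch of that scaling convention in plain PyTorch (no bitsandbytes calls; purely illustrative):

import torch

A = torch.randn(4, 8)

# Per-row absmax, matching row_stats in the kernel above.
row_stats = A.abs().max(dim=1).values.float()

# Quantize: scale each row into [-127, 127] and round to int8.
A_int8 = torch.round(A * (127.0 / row_stats.unsqueeze(-1))).to(torch.int8)

# Dequantize to check the round trip; per-element error is at most ~absmax / 254.
A_deq = A_int8.float() * row_stats.unsqueeze(-1) / 127.0
print((A - A_deq).abs().max())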

bitsandbytes/nn/modules.py

Lines changed: 20 additions & 11 deletions
@@ -585,19 +585,28 @@ def __new__(
         obj.has_fp16_weights = has_fp16_weights
         return obj
 
-    def cuda(self, device):
+    def _quantize(self, device):
         if self.has_fp16_weights:
-            return super().cuda(device)
-        else:
-            # We quantize the weight and store in 8bit row-major
-            B = self.data.contiguous().half().cuda(device)
-            CB, SCB, _ = bnb.functional.int8_vectorwise_quant(B)
-            self.data = CB
-            self.CB = CB
-            self.SCB = SCB
+            return super().to(device)
+
+        # We quantize the weight and store in 8bit row-major
+        B = self.data.contiguous().to(device=device, dtype=torch.float16)
+        CB, SCB, _ = bnb.functional.int8_vectorwise_quant(B)
+        self.data = CB
+        self.CB = CB
+        self.SCB = SCB
 
         return self
 
+    def cpu(self):
+        return self.to(device="cpu")
+
+    def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
+        return self.to(device="cuda" if device is None else device, non_blocking=non_blocking)
+
+    def xpu(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
+        return self.to(device="xpu" if device is None else device, non_blocking=non_blocking)
+
     def __deepcopy__(self, memo):
         # adjust this if new arguments are added to the constructor
         new_instance = type(self).__new__(
@@ -627,8 +636,8 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
     def to(self, *args, **kwargs):
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
 
-        if device is not None and device.type == "cuda" and self.data.device.type == "cpu":
-            return self.cuda(device)
+        if device is not None and device.type != "meta" and self.data.device.type == "cpu":
+            return self._quantize(device)
         else:
             new_param = Int8Params(
                 super().to(device=device, dtype=dtype, non_blocking=non_blocking),
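
With quantization now keyed off to() for any non-meta device instead of only cuda(), an Int8Params-backed layer follows the same quantize-on-move path on CPU, CUDA, and XPU. A usage sketch (Linear8bitLt and Int8Params are the existing bitsandbytes classes; the shapes, device, and expected outputs below are illustrative assumptions):

import torch
import bitsandbytes as bnb

layer = bnb.nn.Linear8bitLt(64, 64, has_fp16_weights=False)

# Moving the layer to any non-meta device routes through Int8Params._quantize(),
# which replaces .data with the int8 weight (CB) and stores the row scales in SCB.
layer = layer.to("cpu")             # "cuda" or "xpu" take the same path after this change

print(layer.weight.data.dtype)      # expected: torch.int8
print(layer.weight.SCB.shape)       # expected: one scale per output row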

tests/test_autograd.py

Lines changed: 10 additions & 4 deletions
@@ -32,9 +32,15 @@
 def test_matmullt(
     device, dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias
 ):
-    if device != "cuda" and funcs[1] == bnb.research.switchback_bnb:
-        # TODO: Deprecate/remove?
-        pytest.skip("switchback_bnb only works on CUDA.")
+    if device != "cuda":
+        if funcs[1] == bnb.research.switchback_bnb:
+            # TODO: Deprecate/remove?
+            pytest.skip("switchback_bnb only works on CUDA.")
+
+        if req_grad[1]:
+            # This will be deprecated for CUDA in the future. We don't expect
+            # this to work on any other device.
+            pytest.skip("Deprecated feature with CUDA support only.")
 
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
@@ -171,7 +177,7 @@ def test_matmul_4bit(
     quant_type,
 ):
     if device == "cpu" and quant_type == "fp4":
-        pytest.skip("Only nf4 is supported on CPU")
+        pytest.xfail("Only nf4 is supported on CPU")
 
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)

tests/test_functional.py

Lines changed: 22 additions & 3 deletions
@@ -186,7 +186,7 @@ def test_few_bit_quant(self, device, bits, method):
             code = F.create_dynamic_map(True, bits - 0, bits).to(device)
         elif method == "quantile":
             if device != "cuda":
-                pytest.xfail("Quantile map only works on CUDA")
+                pytest.skip("Quantile map only works on CUDA")
             values = torch.randn(2048, 2048, device="cuda")
             code = F.create_quantile_map(values, bits).cuda()
         # for some data types we have no zero
@@ -593,7 +593,7 @@ def test_int8_linear_matmul_half(self, device, dim1, dim2, dim3, dim4, dims):
 
         A = A.view(-1, A.shape[-1])
 
-        CA, _, statsA, _, _ = F.int8_double_quant(A)
+        CA, statsA, _ = F.int8_vectorwise_quant(A)
         CB, statsB, _ = F.int8_vectorwise_quant(B)
         output = F.int8_mm_dequant(F.int8_linear_matmul(CA, CB), statsA, statsB)
 
@@ -1102,6 +1102,9 @@ class TestQuantize4BitFunctional:
     @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
     @pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096])
     def test_4bit_quant(self, device, dtype, quant_type, blocksize):
+        if device == "cpu" and quant_type != "nf4":
+            pytest.xfail("fp4 quantization is not supported on CPU")
+
         A1 = torch.randn(1024, 1024, device=device, dtype=dtype)
         qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
         A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
@@ -1134,6 +1137,9 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
     @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
     @pytest.mark.parametrize("blocksize", [64, 128], ids=id_formatter("blocksize"))
     def test_4bit_compressed_stats(self, device, quant_type, blocksize):
+        if device == "cpu" and quant_type != "nf4":
+            pytest.xfail("fp4 quantization is not supported on CPU")
+
         errs1 = []
         errs2 = []
         for i in range(10):
@@ -1206,6 +1212,12 @@ def test_bench_4bit_dequant(self, quant_type):
     )
     @pytest.mark.parametrize("dim", [128, 256, 512, 1024], ids=id_formatter("dim"))
     def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double_quant, kind):
+        if device == "cpu":
+            if storage_type != "nf4":
+                pytest.xfail("fp4 quantization is not supported on CPU")
+            if quant_storage != torch.uint8:
+                pytest.xfail("Only uint8 storage is supported on CPU")
+
         errs1 = []
         errs2 = []
         errs3 = []
@@ -1216,7 +1228,11 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double
         max_errs2 = []
         max_errs3 = []
 
-        for i in range(100):
+        # Large number of iterations is excessive and slow on CPU.
+        # Keep for CUDA for now.
+        iters = 100 if device == "cuda" else 10
+
+        for i in range(iters):
             if kind == "fc1":
                 A = torch.randn(1, dim, dtype=dtype, device=device)
                 B = torch.randn(dim * 4, dim, dtype=dtype, device=device) / math.sqrt(dim)
@@ -1337,6 +1353,9 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double
     @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
     @pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"])
     def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant):
+        if device == "cpu" and storage_type != "nf4":
+            pytest.xfail("fp4 quantization is not supported on CPU")
+
         dims = 10
         torch.random.manual_seed(np.random.randint(0, 412424242))
         dims = get_test_dims(0, 8192, n=dims)

tests/test_linear4bit.py

Lines changed: 2 additions & 2 deletions
@@ -24,8 +24,8 @@
 @pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
 @pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
 def test_linear_serialization(device, quant_type, compress_statistics, bias, quant_storage, save_before_forward):
-    if device == "cpu":
-        pytest.xfail("Dequantization is not yet implemented for CPU")
+    if device == "cpu" and quant_type == "fp4":
+        pytest.xfail("FP4 is not supported for CPU")
 
     original_dtype = torch.float16
     compute_dtype = None

tests/test_linear8bitlt.py

Lines changed: 5 additions & 8 deletions
@@ -22,9 +22,6 @@
 # https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
 @pytest.mark.parametrize("device", get_available_devices())
 def test_linear_no_igemmlt(device):
-    if device == "cpu":
-        pytest.xfail("Not yet implemented on CPU")
-
     linear = torch.nn.Linear(1024, 3072)
     x = torch.randn(3, 1024, dtype=torch.half)
     linear_custom = Linear8bitLt(
@@ -81,8 +78,8 @@ def test_linear_serialization(
     save_before_forward,
     load_before_cuda,
 ):
-    if device == "cpu":
-        pytest.xfail("Not yet implemented on CPU")
+    if device != "cuda" and has_fp16_weights:
+        pytest.skip("has_fp16_weights is only supported on CUDA and is deprecated")
 
     linear = torch.nn.Linear(32, 96)
     # TODO: Fallback for bad shapes
@@ -111,7 +108,7 @@ def test_linear_serialization(
     if save_before_forward:
         bytes_8bit = torch_save_to_buffer(linear_custom)
 
-    x_first = x.clone().cuda().requires_grad_(True)
+    x_first = x.clone().to(device).requires_grad_(True)
     fx_first = linear_custom(x_first).float()
     grad_proj = torch.randn_like(fx_first)
     (fx_first * grad_proj).mean().backward()
@@ -157,11 +154,11 @@ def test_linear_serialization(
     if not load_before_cuda:
        new_linear_custom2 = torch_load_from_buffer(bytes_8bit)
 
-    x_second = x.clone().cuda().requires_grad_(True)
+    x_second = x.clone().to(device).requires_grad_(True)
     fx_second = new_linear_custom(x_second).float()
     (fx_second * grad_proj).mean().backward()
 
-    x_third = x.clone().cuda().requires_grad_(True)
+    x_third = x.clone().to(device).requires_grad_(True)
     fx_third = new_linear_custom2(x_third).float()
     (fx_third * grad_proj).mean().backward()
