Commit 996a26f

Make test suite more device-agnostic
1 parent 90c65fb commit 996a26f

5 files changed: +157 −74 lines changed

bitsandbytes/functional.py

Lines changed: 1 addition & 1 deletion

@@ -341,7 +341,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
     for i in range(gap):
         values.append(0)
     values.sort()
-    code = torch.Tensor(values)
+    code = torch.tensor(values)
     code /= code.max()

     return code
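
Why this one-character change matters: torch.Tensor is the legacy constructor, which always produces float32 and treats a bare integer argument as a size, while torch.tensor is the recommended factory that always treats its argument as data, infers the dtype, and copies it. A minimal sketch of the difference (the sample values below are illustrative, not the actual fp8 code values from create_fp8_map):

    import torch

    # Legacy constructor: an int is interpreted as a size -> uninitialized float32 storage.
    a = torch.Tensor(3)            # shape (3,), arbitrary contents

    # Factory function: the argument is always data; dtype is inferred.
    b = torch.tensor(3)            # 0-dim int64 tensor holding the value 3
    c = torch.tensor([0.0, 0.5])   # float32 tensor built from a list, as in create_fp8_map

    print(a.shape, b.item(), c.dtype)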

tests/helpers.py

Lines changed: 29 additions & 0 deletions

@@ -1,3 +1,4 @@
+import functools
 from io import BytesIO
 from itertools import product
 import random
@@ -13,6 +14,34 @@
 BOOLEAN_TUPLES = list(product(TRUE_FALSE, repeat=2))  # all combinations of (bool, bool)


+@functools.cache
+def get_available_devices():
+    devices = ["cpu"]
+
+    if hasattr(torch, "accelerator"):
+        # PyTorch 2.6+ - determine accelerator using agnostic API.
+        if torch.accelerator.is_available():
+            devices += [str(torch.accelerator.current_accelerator())]
+    else:
+        if torch.cuda.is_available():
+            devices += ["cuda"]
+
+        if torch.backends.mps.is_available():
+            devices += ["mps"]
+
+        if hasattr(torch, "xpu") and torch.xpu.is_available():
+            devices += ["xpu"]
+
+        custom_backend_name = torch._C._get_privateuse1_backend_name()
+        custom_backend_module = getattr(torch, custom_backend_name, None)
+        custom_backend_is_available_fn = getattr(custom_backend_module, "is_available", None)
+
+        if custom_backend_is_available_fn and custom_backend_module.is_available():
+            devices += [custom_backend_name]
+
+    return devices
+
+
 def torch_save_to_buffer(obj):
     buffer = BytesIO()
     torch.save(obj, buffer)
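
get_available_devices() is consumed through pytest parametrization, as the test_autograd.py diff below shows, and the functools.cache decorator means device discovery runs only once per test session. A minimal sketch of how a device-agnostic test picks it up (the test body here is hypothetical, not part of the commit):

    import pytest
    import torch

    from tests.helpers import get_available_devices

    @pytest.mark.parametrize("device", get_available_devices())
    def test_runs_on_every_detected_device(device):
        # The same body is executed once per detected device ("cpu", "cuda", "mps", ...).
        x = torch.arange(8, dtype=torch.float32, device=device)
        assert x.sum().item() == 28.0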

tests/test_autograd.py

Lines changed: 27 additions & 12 deletions

@@ -6,12 +6,14 @@
     BOOLEAN_TRIPLES,
     TRUE_FALSE,
     describe_dtype,
+    get_available_devices,
     id_formatter,
 )

 TRANSPOSE_VALS = [(False, True), (False, False)]


+@pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("dim1", [40], ids=id_formatter("dim1"))
 @pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2"))
 @pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3"))
@@ -27,32 +29,38 @@
 @pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
 @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
 @pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
-def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias):
+def test_matmullt(
+    device, dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias
+):
+    if device != "cuda" and funcs[1] == bnb.research.switchback_bnb:
+        # TODO: Deprecate/remove?
+        pytest.skip("switchback_bnb only works on CUDA.")
+
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
-    outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")
+    outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device=device)
     if has_bias == False:
         req_grad = list(req_grad)
         req_grad[2] = False

     for i in range(3):
         # normal multiply
         if funcs[0] in [torch.mm, torch.matmul]:
-            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
+            A = torch.randn(size=dimA, device=device, requires_grad=req_grad[0], dtype=dtype)
             if decomp == 6.0:
                 with torch.no_grad():
                     A[:, outlier_dim] = 6.0
-            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
+            B = torch.randn(size=dimB, device=device, requires_grad=req_grad[1], dtype=dtype)
             target = torch.randn(
                 size=(dim2, dim4),
-                device="cuda",
+                device=device,
                 requires_grad=req_grad[1],
                 dtype=dtype,
             )
             bias = None
             bias2 = None
             if has_bias:
-                bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2])
+                bias = torch.randn(dim4, device=device, dtype=dtype, requires_grad=req_grad[2])
                 bias2 = bias.clone()
             torch.nn.init.xavier_uniform_(B)
             B2 = B.clone()
@@ -91,7 +99,8 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
         if has_fp16_weights:
             if any(req_grad):
                 out_bnb.data.copy_(out_torch)
-                torch.cuda.synchronize()
+                if device == "cuda":
+                    torch.cuda.synchronize()
                 loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                 loss_bnb.backward()
                 gradA1 = A.grad
@@ -135,6 +144,7 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
                 torch.testing.assert_close(gradBias1, gradBias2)


+@pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("dim1", [48], ids=id_formatter("dim1"))
 @pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2"))
 @pytest.mark.parametrize("dim3", [64], ids=id_formatter("dim3"))
@@ -147,6 +157,7 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
 @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
 @pytest.mark.parametrize("quant_type", ["fp4", "nf4"], ids=id_formatter("quant_type"))
 def test_matmul_4bit(
+    device,
     dim1,
     dim2,
     dim3,
@@ -159,6 +170,9 @@ def test_matmul_4bit(
     compress_statistics,
     quant_type,
 ):
+    if device == "cpu" and quant_type == "fp4":
+        pytest.skip("Only nf4 is supported on CPU")
+
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
     if has_bias == False:
@@ -168,13 +182,13 @@ def test_matmul_4bit(
     for i in range(3):
         # normal multiply
         if funcs[0] in [torch.mm, torch.matmul]:
-            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
-            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
-            target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
+            A = torch.randn(size=dimA, device=device, requires_grad=req_grad[0], dtype=dtype)
+            B = torch.randn(size=dimB, device=device, requires_grad=req_grad[1], dtype=dtype)
+            target = torch.randn(size=(dim2, dim4), device=device, requires_grad=req_grad[1], dtype=dtype)
             bias = None
             bias2 = None
             if has_bias:
-                bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2])
+                bias = torch.randn(dim4, device=device, dtype=dtype, requires_grad=req_grad[2])
                 bias2 = bias.clone()
             torch.nn.init.xavier_uniform_(B)

@@ -204,7 +218,8 @@ def test_matmul_4bit(
             # assert err < 0.20
             if any(req_grad):
                 out_bnb.data.copy_(out_torch)
-                torch.cuda.synchronize()
+                if device == "cuda":
+                    torch.cuda.synchronize()
                 loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                 loss_bnb.backward()
                 gradA1 = A.grad
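
The repeated `if device == "cuda": torch.cuda.synchronize()` guard keeps the old CUDA behavior without breaking CPU or MPS runs. If more backends ever need explicit synchronization, the same idea could be factored into a small helper; this is only a sketch assuming the torch.accelerator API (PyTorch 2.6+) or the per-backend synchronize calls are available, not something this commit adds:

    import torch

    def synchronize(device: str) -> None:
        """Block until queued work on `device` finishes; no-op on CPU."""
        if device == "cpu":
            return
        if hasattr(torch, "accelerator") and torch.accelerator.is_available():
            # PyTorch 2.6+: device-agnostic synchronization of the current accelerator.
            torch.accelerator.synchronize()
        elif device == "cuda":
            torch.cuda.synchronize()
        elif device == "mps":
            torch.mps.synchronize()
        elif device == "xpu":
            torch.xpu.synchronize()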
