
Commit 4fae73d

Additional device agnostic tests
1 parent 996a26f commit 4fae73d

3 files changed: +102 −279 lines

tests/test_linear4bit.py

Lines changed: 67 additions & 17 deletions
@@ -7,7 +7,7 @@
 import torch
 
 import bitsandbytes as bnb
-from tests.helpers import TRUE_FALSE, torch_load_from_buffer, torch_save_to_buffer
+from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, torch_load_from_buffer, torch_save_to_buffer
 
 storage = {
     "uint8": torch.uint8,
@@ -17,15 +17,18 @@
 }
 
 
+@pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"])
-@pytest.mark.parametrize("bias", TRUE_FALSE)
-@pytest.mark.parametrize("compress_statistics", TRUE_FALSE)
+@pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias"))
+@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
 @pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
-@pytest.mark.parametrize("save_before_forward", TRUE_FALSE)
-def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage, save_before_forward):
+@pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
+def test_linear_serialization(device, quant_type, compress_statistics, bias, quant_storage, save_before_forward):
+    if device == "cpu":
+        pytest.xfail("Dequantization is not yet implemented for CPU")
+
     original_dtype = torch.float16
     compute_dtype = None
-    device = "cuda"
     layer_shape = (300, 400)
 
     linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu")  # original layer
@@ -52,7 +55,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
     # restoring from state_dict:
     bias_data2 = sd.pop("bias", None)
     weight_data2 = sd.pop("weight")
-    weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2)
+    weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2, device=device)
 
     # creating new layer with same params:
     linear_q2 = bnb.nn.Linear4bit(
@@ -174,18 +177,50 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
     assert size_ratio < target_compression, ratio_error_msg
 
 
-def test_copy_param():
-    tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
-    param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
+@pytest.mark.parametrize("blocksize", [64, 128])
+@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
+def test_copy_param(device, quant_type, blocksize, compress_statistics):
+    if device == "cpu":
+        if compress_statistics:
+            pytest.skip("Currently segfaults on CPU")
+        if quant_type == "fp4":
+            pytest.xfail("FP4 not supported on CPU")
+
+    tensor = torch.linspace(1, blocksize, blocksize)
+    param = bnb.nn.Params4bit(
+        data=tensor,
+        quant_type=quant_type,
+        blocksize=blocksize,
+        compress_statistics=compress_statistics,
+        requires_grad=False,
+    ).to(device)
 
     shallow_copy_param = copy.copy(param)
     assert param.quant_state is shallow_copy_param.quant_state
     assert param.data.data_ptr() == shallow_copy_param.data.data_ptr()
 
 
-def test_deepcopy_param():
-    tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
-    param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
+@pytest.mark.parametrize("blocksize", [64, 128])
+@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
+def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):
+    if device == "cpu":
+        if compress_statistics:
+            pytest.skip("Currently segfaults on CPU")
+        if quant_type == "fp4":
+            pytest.xfail("FP4 not supported on CPU")
+
+    tensor = torch.linspace(1, blocksize, blocksize)
+    param = bnb.nn.Params4bit(
+        data=tensor,
+        quant_type=quant_type,
+        blocksize=blocksize,
+        compress_statistics=compress_statistics,
+        requires_grad=False,
+    ).to(device)
     dict_keys_before = set(param.__dict__.keys())
     copy_param = copy.deepcopy(param)
     dict_keys_after = set(param.__dict__.keys())
@@ -199,12 +234,27 @@ def test_deepcopy_param():
     assert dict_keys_before == dict_keys_copy
 
 
-def test_params4bit_real_serialization():
-    original_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32)
-    original_param = bnb.nn.Params4bit(data=original_tensor, quant_type="fp4")
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
+@pytest.mark.parametrize("blocksize", [64, 128])
+@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
+def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics):
+    if device == "cpu":
+        if compress_statistics:
+            pytest.skip("Currently segfaults on CPU")
+        if quant_type == "fp4":
+            pytest.xfail("FP4 not supported on CPU")
+
+    original_tensor = torch.linspace(1, blocksize, blocksize, dtype=torch.float32)
+    original_param = bnb.nn.Params4bit(
+        data=original_tensor,
+        quant_type=quant_type,
+        blocksize=blocksize,
+        compress_statistics=compress_statistics,
+    )
     dict_keys_before = set(original_param.__dict__.keys())
 
-    original_param.cuda(0)  # move to CUDA to trigger quantization
+    original_param.to(device)  # change device to trigger quantization
 
     serialized_param = pickle.dumps(original_param)
     deserialized_param = pickle.loads(serialized_param)
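
Both get_available_devices() and id_formatter() are imported from tests.helpers; their bodies are not visible in this commit view. A minimal sketch of what such helpers might look like, offered as an assumption for orientation rather than the actual implementation:

# Hypothetical sketch of the tests.helpers utilities used above; the real
# definitions live in tests/helpers.py and may differ.
import torch


def get_available_devices():
    # List every backend the test session can target.
    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.append("cuda")
    return devices


def id_formatter(label):
    # Produce readable pytest IDs such as "bias=True" instead of bare "True".
    return lambda value: f"{label}={value}"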

tests/test_linear8bitlt.py

Lines changed: 19 additions & 7 deletions
@@ -11,6 +11,7 @@
 from bitsandbytes.nn.modules import Linear8bitLt
 from tests.helpers import (
     TRUE_FALSE,
+    get_available_devices,
     id_formatter,
     torch_load_from_buffer,
     torch_save_to_buffer,
@@ -19,7 +20,11 @@
 
 # contributed by Alex Borzunov, see:
 # https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
-def test_linear_no_igemmlt():
+@pytest.mark.parametrize("device", get_available_devices())
+def test_linear_no_igemmlt(device):
+    if device == "cpu":
+        pytest.xfail("Not yet implemented on CPU")
+
     linear = torch.nn.Linear(1024, 3072)
     x = torch.randn(3, 1024, dtype=torch.half)
     linear_custom = Linear8bitLt(
@@ -29,6 +34,8 @@ def test_linear_no_igemmlt():
         has_fp16_weights=False,
         threshold=6.0,
     )
+
+    # TODO: Remove, this is no longer implemented
     linear_custom.state.force_no_igemmlt = True
 
     linear_custom.weight = bnb.nn.Int8Params(
@@ -37,11 +44,11 @@ def test_linear_no_igemmlt():
         has_fp16_weights=False,
     ).to(linear.weight.dtype)
     linear_custom.bias = linear.bias
-    linear_custom = linear_custom.cuda()
-    linear = linear.half().cuda()
+    linear_custom = linear_custom.to(device)
+    linear = linear.half().to(device)
 
-    x_ref = x.clone().cuda().requires_grad_(True)
-    x_ours = x.clone().cuda().requires_grad_(True)
+    x_ref = x.clone().to(device).requires_grad_(True)
+    x_ours = x.clone().to(device).requires_grad_(True)
     fx_ref = linear(x_ref).float()
     grad_proj = torch.randn_like(fx_ref)
     (fx_ref * grad_proj).mean().backward()
@@ -58,18 +65,23 @@ def test_linear_no_igemmlt():
     torch.testing.assert_close(x_ref.grad, x_ours.grad, atol=0.01, rtol=1e-5)
 
 
+@pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
 @pytest.mark.parametrize("serialize_before_forward", TRUE_FALSE, ids=id_formatter("serialize_before_forward"))
 @pytest.mark.parametrize("deserialize_before_cuda", TRUE_FALSE, ids=id_formatter("deserialize_before_cuda"))
 @pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
 @pytest.mark.parametrize("load_before_cuda", TRUE_FALSE, ids=id_formatter("load_before_cuda"))
 def test_linear_serialization(
+    device,
     has_fp16_weights,
     serialize_before_forward,
     deserialize_before_cuda,
     save_before_forward,
     load_before_cuda,
 ):
+    if device == "cpu":
+        pytest.xfail("Not yet implemented on CPU")
+
     linear = torch.nn.Linear(32, 96)
     # TODO: Fallback for bad shapes
     x = torch.randn(4, 32, dtype=torch.half)
@@ -89,7 +101,7 @@ def test_linear_serialization(
         has_fp16_weights=has_fp16_weights,
     )
     linear_custom.bias = linear.bias
-    linear_custom = linear_custom.cuda()
+    linear_custom = linear_custom.to(device)
 
     if serialize_before_forward:
         state_dict_8bit = linear_custom.state_dict()
@@ -135,7 +147,7 @@ def test_linear_serialization(
     if load_before_cuda:
         new_linear_custom2 = torch_load_from_buffer(bytes_8bit)
 
-    new_linear_custom = new_linear_custom.cuda()
+    new_linear_custom = new_linear_custom.to(device)
 
     if not deserialize_before_cuda:
         new_linear_custom.load_state_dict(new_state_dict, strict=True)
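
Across both files the commit applies the same device-agnostic pattern: hardcoded .cuda() / .cuda(0) calls become .to(device), with device supplied by a pytest parametrization over the available backends. A self-contained illustration of the pattern follows; the test below is illustrative only, not part of this commit:

import pytest
import torch


def get_available_devices():
    # Assumed stand-in for tests.helpers.get_available_devices (see sketch above).
    return ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])


@pytest.mark.parametrize("device", get_available_devices())
def test_matmul_device_agnostic(device):
    # The same test body runs on every available backend instead of assuming CUDA.
    a = torch.randn(4, 8, device=device)
    b = torch.randn(8, 2, device=device)
    out = a @ b
    assert out.shape == (4, 2)
    assert out.device.type == device

Because the device now appears in each test ID, a substring filter such as pytest tests/test_linear4bit.py -k cpu selects only the CPU variants of the new tests.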
