Skip to content

Commit a1c0844

Browse files
authored
adding whole Linear8bitLt/Linear4bit module save/load serialization (#1099)
1 parent f9eba9c commit a1c0844

File tree

3 files changed

+62
-3
lines changed

3 files changed

+62
-3
lines changed

bitsandbytes/nn/modules.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,9 @@ def __new__(
449449
cls.SCB = None
450450
if data is None:
451451
data = torch.empty(0)
452-
return torch.Tensor._make_subclass(cls, data, requires_grad)
452+
obj = torch.Tensor._make_subclass(cls, data, requires_grad)
453+
obj.CB, obj.SCB = cls.CB, cls.SCB
454+
return obj
453455

454456
def cuda(self, device):
455457
if self.has_fp16_weights:

tests/test_linear4bit.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import copy
2+
from io import BytesIO
23
import os
34
import pickle
45
from tempfile import TemporaryDirectory
@@ -16,12 +17,24 @@
1617
"float32": torch.float32,
1718
}
1819

20+
def torch_save_to_buffer(obj):
    """Serialize ``obj`` with ``torch.save`` into an in-memory buffer.

    Args:
        obj: Any object accepted by ``torch.save`` (tensor, state dict,
            or a whole module).

    Returns:
        A ``BytesIO`` holding the serialized bytes, rewound to position 0
        so it can be handed straight to a reader.
    """
    buffer = BytesIO()
    torch.save(obj, buffer)
    # Rewind: torch.save leaves the cursor at the end of the written data.
    buffer.seek(0)
    return buffer
25+
26+
def torch_load_from_buffer(buffer):
    """Deserialize an object previously written to ``buffer`` by ``torch.save``.

    The buffer is rewound before reading and again afterwards, so the same
    buffer can be loaded repeatedly.

    Args:
        buffer: A seekable binary file-like object (e.g. ``BytesIO``)
            containing ``torch.save`` output.

    Returns:
        The deserialized object.
    """
    buffer.seek(0)
    # weights_only=False is required to unpickle whole modules; PyTorch >= 2.6
    # defaults to weights_only=True, which rejects arbitrary classes.
    # NOTE: full unpickling can execute arbitrary code -- only use this on
    # buffers produced by the tests themselves, never on untrusted data.
    obj = torch.load(buffer, weights_only=False)
    buffer.seek(0)
    return obj
1931

2032
@pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"])
2133
@pytest.mark.parametrize("bias", TRUE_FALSE)
2234
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE)
2335
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
24-
def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage):
36+
@pytest.mark.parametrize("save_before_forward", TRUE_FALSE)
37+
def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage, save_before_forward):
2538
original_dtype = torch.float16
2639
compute_dtype = None
2740
device = "cuda"
@@ -124,6 +137,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
124137
assert a.dtype == b.dtype
125138
assert torch.equal(a, b)
126139

140+
if save_before_forward:
141+
bytes_4bit = torch_save_to_buffer(linear_q)
142+
127143
# Forward test
128144
x = torch.rand(42, layer_shape[0], device=device)
129145
a = linear_q(x)
@@ -136,6 +152,10 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
136152
assert torch.equal(a, b)
137153
assert torch.equal(a, c)
138154

155+
if not save_before_forward:
156+
bytes_4bit = torch_save_to_buffer(linear_q)
157+
linear_q3 = torch_load_from_buffer(bytes_4bit)
158+
139159
# Test moving to CPU and back to GPU
140160
linear_q2.to("cpu")
141161
linear_q2.to(device)
@@ -144,6 +164,11 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
144164
assert c.device == d.device
145165
assert torch.equal(c, d)
146166

167+
d = linear_q3(x)
168+
assert c.dtype == d.dtype
169+
assert c.device == d.device
170+
assert torch.equal(c, d)
171+
147172
# Saved size ratio test. Target set for layer_shape == (300, 400) w/ bias
148173
with TemporaryDirectory() as tmpdir:
149174
state_path_4bit = os.path.join(tmpdir, "state_4bit.pth")

tests/test_linear8bitlt.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from contextlib import nullcontext
2+
from io import BytesIO
23
import os
34
from tempfile import TemporaryDirectory
45

@@ -65,12 +66,25 @@ def test_linear_no_igemmlt():
6566
assert linear_custom.state.CB is not None
6667
assert linear_custom.state.CxB is None
6768

69+
def torch_save_to_buffer(obj):
    """Write ``obj`` to a fresh in-memory buffer using ``torch.save``.

    Args:
        obj: Any object accepted by ``torch.save`` (tensor, state dict,
            or a whole module).

    Returns:
        A ``BytesIO`` containing the serialized data, positioned at offset 0.
    """
    buffer = BytesIO()
    torch.save(obj, buffer)
    # torch.save leaves the cursor at end-of-data; rewind for the reader.
    buffer.seek(0)
    return buffer
74+
75+
def torch_load_from_buffer(buffer):
    """Read back an object stored in ``buffer`` by ``torch.save``.

    The buffer is rewound both before and after the read, leaving it ready
    for another load.

    Args:
        buffer: A seekable binary file-like object (e.g. ``BytesIO``)
            containing ``torch.save`` output.

    Returns:
        The deserialized object.
    """
    buffer.seek(0)
    # Whole-module pickles need weights_only=False; PyTorch >= 2.6 otherwise
    # defaults to weights_only=True and refuses to load arbitrary classes.
    # NOTE: full unpickling can execute arbitrary code -- only use this on
    # buffers produced by the tests themselves, never on untrusted data.
    obj = torch.load(buffer, weights_only=False)
    buffer.seek(0)
    return obj
6880

6981
@pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
7082
@pytest.mark.parametrize("serialize_before_forward", TRUE_FALSE, ids=id_formatter("serialize_before_forward"))
7183
@pytest.mark.parametrize("deserialize_before_cuda", TRUE_FALSE, ids=id_formatter("deserialize_before_cuda"))
7284
@pytest.mark.parametrize("force_no_igemmlt", TRUE_FALSE, ids=id_formatter("force_no_igemmlt"))
73-
def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt):
85+
@pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
86+
@pytest.mark.parametrize("load_before_cuda", TRUE_FALSE, ids=id_formatter("load_before_cuda"))
87+
def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt, save_before_forward, load_before_cuda):
7488
linear = torch.nn.Linear(32, 96)
7589
x = torch.randn(3, 32, dtype=torch.half)
7690

@@ -93,6 +107,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
93107
if serialize_before_forward:
94108
state_dict_8bit = linear_custom.state_dict()
95109

110+
if save_before_forward:
111+
bytes_8bit = torch_save_to_buffer(linear_custom)
112+
96113
x_first = x.clone().cuda().requires_grad_(True)
97114
fx_first = linear_custom(x_first).float()
98115
grad_proj = torch.randn_like(fx_first)
@@ -101,6 +118,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
101118
if not serialize_before_forward:
102119
state_dict_8bit = linear_custom.state_dict()
103120

121+
if not save_before_forward:
122+
bytes_8bit = torch_save_to_buffer(linear_custom)
123+
104124
with TemporaryDirectory() as tmpdir:
105125
state_path_8bit = os.path.join(tmpdir, "state_8bit.pth")
106126
state_path = os.path.join(tmpdir, "state.pth")
@@ -127,16 +147,28 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
127147
with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError):
128148
new_linear_custom.load_state_dict(new_state_dict, strict=True)
129149

150+
if load_before_cuda:
151+
new_linear_custom2 = torch_load_from_buffer(bytes_8bit)
152+
130153
new_linear_custom = new_linear_custom.cuda()
131154

132155
if not deserialize_before_cuda:
133156
new_linear_custom.load_state_dict(new_state_dict, strict=True)
134157

158+
if not load_before_cuda:
159+
new_linear_custom2 = torch_load_from_buffer(bytes_8bit)
160+
135161
x_second = x.clone().cuda().requires_grad_(True)
136162
fx_second = new_linear_custom(x_second).float()
137163
(fx_second * grad_proj).mean().backward()
138164

165+
x_third = x.clone().cuda().requires_grad_(True)
166+
fx_third = new_linear_custom2(x_third).float()
167+
(fx_third * grad_proj).mean().backward()
168+
139169
# if 8-bit weights were loaded before .cuda, state is incorrect anyway and RuntimeError was raised
140170
if has_fp16_weights or not deserialize_before_cuda:
141171
assert torch.allclose(fx_first, fx_second, atol=1e-5)
142172
assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5)
173+
assert torch.allclose(fx_first, fx_third, atol=1e-5)
174+
assert torch.allclose(x_first.grad, x_third.grad, atol=1e-5)

0 commit comments

Comments
 (0)