
Commit ed6f3eb

Merge pull request #159 from TimDettmers/serialize_8bit
Implement proper serialization of Linear8bitLt
2 parents: b0ec20c + dcecbb2

3 files changed: +144 additions, -11 deletions

bitsandbytes/autograd/_functions.py

Lines changed: 10 additions & 8 deletions
@@ -234,7 +234,7 @@ def supports_igemmlt(device: torch.device) -> bool:
 
 @dataclass
 class MatmulLtState:
-    tile_indices: Optional[torch.Tensor] = None
+    _tile_indices: Optional[torch.Tensor] = None
     force_no_igemmlt: bool = False
     CB = None
     CxB = None
@@ -274,6 +274,15 @@ def get_tile_size(self):
         ), f"please find this assert and manually enter tile size for {self.formatB}"
         return (8, 32) if self.formatB == "col_turing" else (32, 32)
 
+    @property
+    def tile_indices(self):
+        if self._tile_indices is None:
+            device = self.CxB.device
+            transform = lambda x: F.transform(x.to(device), from_order="row", to_order=self.formatB)[0].to(x.device)
+            with torch.no_grad():
+                self._tile_indices = get_inverse_transform_indices(transform, self.get_tile_size()).to(device)
+        return self._tile_indices
+
 
 class MatMul8bitLt(torch.autograd.Function):
     # forward is the same, but we added the fallback for pre-turing GPUs
@@ -466,13 +475,6 @@ def backward(ctx, grad_output):
                 CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
                 grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
             elif state.CxB is not None:
-
-                if state.tile_indices is None:
-                    order, tile_size = state.formatB, state.get_tile_size()
-                    transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device)
-                    with torch.no_grad():
-                        state.tile_indices = get_inverse_transform_indices(transform, tile_size).to(state.CxB.device)
-
                 CB = (
                     undo_layout(state.CxB, state.tile_indices)
                     .to(ctx.dtype_A)

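Aside (not from the commit itself): the change above turns tile_indices from a plain field into a lazily computed, cached property on MatmulLtState, so the backward pass and the new serialization hooks can share one code path instead of re-deriving the indices inside backward(). A minimal sketch of that caching pattern, using a stand-in class and a stand-in computation rather than get_inverse_transform_indices:

    from typing import Optional

    import torch

    class TiledState:
        # toy stand-in for MatmulLtState: compute the index tensor once, reuse it afterwards
        _tile_indices: Optional[torch.Tensor] = None

        def get_tile_size(self):
            return (8, 32)  # e.g. the "col_turing" tile shape

        @property
        def tile_indices(self) -> torch.Tensor:
            if self._tile_indices is None:
                rows, cols = self.get_tile_size()
                # placeholder for get_inverse_transform_indices(transform, tile_size)
                self._tile_indices = torch.arange(rows * cols)
            return self._tile_indices
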
bitsandbytes/nn/modules.py

Lines changed: 49 additions & 0 deletions
@@ -9,6 +9,8 @@
 from torch import Tensor, device, dtype, nn
 
 import bitsandbytes as bnb
+import bitsandbytes.functional
+from bitsandbytes.autograd._functions import get_inverse_transform_indices, undo_layout
 from bitsandbytes.optim import GlobalOptimManager
 
 T = TypeVar("T", bound="torch.nn.Module")
@@ -224,6 +226,53 @@ def __init__(self, input_features, output_features, bias=True, has_fp16_weights=
 
         self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
 
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        if not self.state.has_fp16_weights and self.state.CB is None and self.state.CxB is not None:
+            # reorder weight layout back from ampere/turing to row
+            reorder_layout = True
+            weight_clone = self.weight.data.clone()
+        else:
+            reorder_layout = False
+
+        try:
+            if reorder_layout:
+                self.weight.data = undo_layout(self.state.CxB, self.state.tile_indices)
+
+            super()._save_to_state_dict(destination, prefix, keep_vars)
+
+            # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
+            weight_name = "SCB"
+
+            # case 1: .cuda was called, SCB is in self.weight
+            param_from_weight = getattr(self.weight, weight_name)
+            # case 2: self.init_8bit_state was called, SCB is in self.state
+            param_from_state = getattr(self.state, weight_name)
+
+            key_name = prefix + f"{weight_name}"
+            if param_from_weight is not None:
+                destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
+            elif not self.state.has_fp16_weights and param_from_state is not None:
+                destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
+        finally:
+            if reorder_layout:
+                self.weight.data = weight_clone
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
+                                      error_msgs)
+        for key in unexpected_keys:
+            input_name = key[len(prefix):]
+            if input_name == "SCB":
+                if self.weight.SCB is None:
+                    # buffers not yet initialized, can't call them directly without
+                    raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is "
+                                       "not supported. Please call module.cuda() before module.load_state_dict()")
+
+                input_param = state_dict[key]
+                self.weight.SCB.copy_(input_param)
+                unexpected_keys.remove(key)
+
     def init_8bit_state(self):
         self.state.CB = self.weight.CB
         self.state.SCB = self.weight.SCB

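For orientation, a usage sketch of what these hooks enable (illustrative sizes and file name, not code from this PR): a quantized Linear8bitLt can round-trip through an ordinary PyTorch state_dict, provided the destination module is moved to the GPU, which quantizes its weight, before load_state_dict is called.

    import torch
    import bitsandbytes as bnb

    layer = bnb.nn.Linear8bitLt(1024, 4096, bias=True, has_fp16_weights=False, threshold=6.0)
    layer = layer.cuda()  # quantizes the fp16 weight into int8 data plus SCB scales
    torch.save(layer.state_dict(), "layer_8bit.pth")  # saves the int8 weight and SCB

    restored = bnb.nn.Linear8bitLt(1024, 4096, bias=True, has_fp16_weights=False, threshold=6.0)
    restored = restored.cuda()  # must happen before loading; otherwise a RuntimeError is raised
    restored.load_state_dict(torch.load("layer_8bit.pth"))
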
tests/test_linear8bitlt.py

Lines changed: 85 additions & 3 deletions
@@ -1,11 +1,17 @@
-import bitsandbytes as bnb
+import os
+from contextlib import nullcontext
+from itertools import product
+from tempfile import TemporaryDirectory
+
 import pytest
 import torch
-from bitsandbytes import functional as F
 
+import bitsandbytes as bnb
+from bitsandbytes import functional as F
 from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout
 from bitsandbytes.nn.modules import Linear8bitLt
 
+
 # contributed by Alex Borzunov, see:
 # https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
 
@@ -26,6 +32,7 @@ def test_layout_exact_match():
     assert restored_x.is_contiguous()
     assert torch.all(torch.eq(restored_x, x))
 
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
 def test_linear_no_igemmlt():
     linear = torch.nn.Linear(1024, 3072)
@@ -43,7 +50,7 @@ def test_linear_no_igemmlt():
         linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
     ).to(linear.weight.dtype)
     linear_custom.bias = linear.bias
-    linear = linear_custom.cuda()
+    linear_custom = linear_custom.cuda()
     linear = linear.half().cuda()
 
     x_ref = x.clone().cuda().requires_grad_(True)
@@ -59,3 +66,78 @@
     assert not linear_custom.state.has_fp16_weights
     assert linear_custom.state.CB is not None
     assert linear_custom.state.CxB is None
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
+@pytest.mark.parametrize("has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt",
+                         list(product([False, True], [False, True], [False, True], [False, True])))
+def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt):
+    linear = torch.nn.Linear(32, 96)
+    x = torch.randn(3, 32, dtype=torch.half)
+
+    linear_custom = Linear8bitLt(
+        linear.in_features,
+        linear.out_features,
+        linear.bias is not None,
+        has_fp16_weights=has_fp16_weights,
+        threshold=6.0,
+    )
+    if force_no_igemmlt:
+        linear_custom.state.force_no_igemmlt = True
+
+    linear_custom.weight = bnb.nn.Int8Params(
+        linear.weight.data.clone(), requires_grad=has_fp16_weights, has_fp16_weights=has_fp16_weights
+    )
+    linear_custom.bias = linear.bias
+    linear_custom = linear_custom.cuda()
+
+    if serialize_before_forward:
+        state_dict_8bit = linear_custom.state_dict()
+
+    x_first = x.clone().cuda().requires_grad_(True)
+    fx_first = linear_custom(x_first).float()
+    grad_proj = torch.randn_like(fx_first)
+    (fx_first * grad_proj).mean().backward()
+
+    if not serialize_before_forward:
+        state_dict_8bit = linear_custom.state_dict()
+
+    with TemporaryDirectory() as tmpdir:
+        state_path_8bit = os.path.join(tmpdir, "state_8bit.pth")
+        state_path = os.path.join(tmpdir, "state.pth")
+
+        torch.save(linear.state_dict(), state_path)
+        torch.save(state_dict_8bit, state_path_8bit)
+
+        if not has_fp16_weights:
+            assert os.path.getsize(state_path_8bit) < 0.5 * os.path.getsize(state_path)
+
+        new_state_dict = torch.load(state_path_8bit)
+
+    new_linear_custom = Linear8bitLt(
+        linear.in_features,
+        linear.out_features,
+        linear.bias is not None,
+        has_fp16_weights=has_fp16_weights,
+        threshold=6.0,
+    )
+    if force_no_igemmlt:
+        new_linear_custom.state.force_no_igemmlt = True
+
+    if deserialize_before_cuda:
+        with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError):
+            new_linear_custom.load_state_dict(new_state_dict, strict=True)
+
+    new_linear_custom = new_linear_custom.cuda()
+
+    if not deserialize_before_cuda:
+        new_linear_custom.load_state_dict(new_state_dict, strict=True)
+
+    x_second = x.clone().cuda().requires_grad_(True)
+    fx_second = new_linear_custom(x_second).float()
+    (fx_second * grad_proj).mean().backward()
+
+    # if 8-bit weights were loaded before .cuda, state is incorrect anyway and RuntimeError was raised
+    if has_fp16_weights or not deserialize_before_cuda:
+        assert torch.allclose(fx_first, fx_second, atol=1e-5)
+        assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5)

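One pattern in the new test worth noting: the line "with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError):" expects the load to fail only in the parametrizations where it should. A generic sketch of that conditional-expectation trick (illustrative names, not from the diff):

    from contextlib import nullcontext

    import pytest

    def maybe_fail(should_fail: bool):
        if should_fail:
            raise RuntimeError("boom")

    @pytest.mark.parametrize("should_fail", [False, True])
    def test_conditional_expectation(should_fail):
        # expect an exception only when this parametrization says so
        with pytest.raises(RuntimeError) if should_fail else nullcontext():
            maybe_fail(should_fail)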