
Commit 5bcc1dd

save/load 4bit squashed
1 parent 61a4a20 commit 5bcc1dd

File tree

3 files changed: +192 -1 lines changed

bitsandbytes/functional.py
bitsandbytes/nn/modules.py
tests/test_linear4bit.py

bitsandbytes/functional.py

Lines changed: 30 additions & 0 deletions
@@ -578,6 +578,36 @@ def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=Non
         self.state2 = state2
         self.nested = state2 is not None
 
+    @classmethod
+    def from_kwargs(cls, kwargs, device):
+
+        tensor2str = lambda xx: ''.join([chr(x) for x in xx]).strip('.')
+
+        kwargs = {k.split('.')[-1]: v for k, v in kwargs.items()}
+
+        if 'nested_absmax' in kwargs:
+            offset = kwargs['nested_offset']
+            state2 = cls(
+                absmax=kwargs['nested_absmax'].to(device),
+                code=kwargs['nested_code'].to(device),
+                blocksize=kwargs['nested_blocksize'].item(),
+                dtype=getattr(torch, tensor2str(kwargs['nested_dtype'])),
+            )
+        else:
+            offset, state2 = None, None
+
+        quant_state = cls(
+            absmax=kwargs['absmax'].to(device),
+            shape=torch.Size(kwargs['shape']),
+            dtype=getattr(torch, tensor2str(kwargs['dtype'])),
+            blocksize=kwargs['blocksize'].item(),
+            offset=offset,
+            state2=state2,
+            quant_type=tensor2str(kwargs['quant_type']),
+            code=kwargs['code'].to(device),
+        )
+        return quant_state
+
     def to(self, device):
         # make sure the quantization state is on the right device
         self.absmax = self.absmax.to(device)
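from_kwargs is the deserialization half of the new save/load path: it rebuilds a QuantState from the buffers that Linear4bit registers below, where strings such as dtype and quant_type travel as int16 tensors of character codes and scalars travel as 0-d tensors. A minimal sketch of that string round trip, kept outside the library (the helper names mirror string_to_tensor from modules.py and the tensor2str lambda above):

import torch

def string_to_tensor(s):
    # encode a short string as int16 code points so it can live in a state_dict buffer
    return torch.tensor([ord(x) for x in s], dtype=torch.int16)

def tensor2str(xx):
    # decode back to a string, dropping the leading '.' left over from str(torch.float16)
    return ''.join([chr(int(x)) for x in xx]).strip('.')

dtype_buf = string_to_tensor(str(torch.float16).strip('torch'))   # stores '.float16'
assert getattr(torch, tensor2str(dtype_buf)) is torch.float16

quant_type_buf = string_to_tensor('nf4')
assert tensor2str(quant_type_buf) == 'nf4'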

bitsandbytes/nn/modules.py

Lines changed: 46 additions & 1 deletion
@@ -10,7 +10,7 @@
 from torch import Tensor, device, dtype, nn
 
 import bitsandbytes as bnb
-import bitsandbytes.functional
+from bitsandbytes.functional import QuantState
 from bitsandbytes.autograd._functions import undo_layout, get_tile_inds
 from bitsandbytes.optim import GlobalOptimManager
 from bitsandbytes.utils import OutlierTracer, find_outlier_dims

@@ -140,6 +140,7 @@ def forward(self, input: Tensor) -> Tensor:
         return emb
 
 
 class Params4bit(torch.nn.Parameter):
+
     def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, compress_statistics=True, quant_type='fp4'):
         if data is None:
             data = torch.empty(0)

@@ -151,6 +152,18 @@ def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64,
         self.quant_state = quant_state
         self.data = data
         return self
+
+    @classmethod
+    def from_prequantized(cls, quantized_stats, data=None, requires_grad=False, device='cuda', **kwargs):
+        if data is None:
+            data = quantized_stats.pop('weight')
+        self = torch.Tensor._make_subclass(cls, data.to(device))
+        self.requires_grad = requires_grad
+        self.quant_state = QuantState.from_kwargs(kwargs=quantized_stats, device=device)
+        self.blocksize = self.quant_state.blocksize
+        self.compress_statistics = self.quant_state.nested
+        self.quant_type = self.quant_state.quant_type
+        return self
 
     def cuda(self, device):
         w = self.data.contiguous().half().cuda(device)

@@ -211,6 +224,38 @@ def set_compute_type(self, x):
             warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_type=torch.float32 (default). This will lead to slow inference or training speed.')
             warnings.filterwarnings('ignore', message='.*inference or training')
 
+
+    def _update_buffers(self):
+
+        def string_to_tensor(s):
+            """stores string as ints for serialization. assumes codes fit int16"""
+            return torch.tensor([ord(x) for x in s], dtype=torch.int16)
+
+        if getattr(self.weight, 'quant_state', None) is not None:
+            weight_quant_state = self.weight.quant_state
+            self.register_buffer('absmax', weight_quant_state.absmax)
+            self.register_buffer('shape', torch.tensor(weight_quant_state.shape))
+            self.register_buffer('dtype', string_to_tensor(str(weight_quant_state.dtype).strip('torch')))
+            self.register_buffer('blocksize', torch.tensor(weight_quant_state.blocksize))
+            self.register_buffer('quant_type', string_to_tensor(weight_quant_state.quant_type))
+            self.register_buffer('code', weight_quant_state.code)
+
+            if weight_quant_state.nested:
+                self.register_buffer('nested_offset', weight_quant_state.offset)
+                self.register_buffer('nested_absmax', weight_quant_state.state2.absmax)
+                self.register_buffer('nested_code', weight_quant_state.state2.code)
+                self.register_buffer('nested_blocksize', torch.tensor(weight_quant_state.state2.blocksize))
+                self.register_buffer('nested_dtype', string_to_tensor(str(weight_quant_state.state2.dtype).strip('torch')))
+
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        """
+        fill state_dict with components of nf4
+        TODO: test with other 4-bit Q-types
+        """
+        self._update_buffers()  # link the quant_state items with _buffers
+        super()._save_to_state_dict(destination, prefix, keep_vars)  # saving weight and bias
+
     def forward(self, x: torch.Tensor):
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
         if self.bias is not None and self.bias.dtype != x.dtype:
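Putting the pieces together, the intended flow is: quantize a Linear4bit layer, call state_dict() (which now serializes the quantization buffers), and later rebuild the weight with Params4bit.from_prequantized. A condensed usage sketch assembled from the new test below; the layer sizes and the linear_nf4.pth file name are arbitrary placeholders, and a CUDA device is required:

import torch
import bitsandbytes as bnb

device = "cuda"  # 4-bit quantization in bitsandbytes runs on a CUDA device
linear = torch.nn.Linear(300, 400, dtype=torch.float16)

# quantize: assigning a Params4bit weight and moving it to the GPU packs it to 4 bits
linear_q = bnb.nn.Linear4bit(300, 400, bias=True, quant_type="nf4", device=device)
linear_q.weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False).to(device)
linear_q.bias.data = linear.bias.data.to(device)

# save: _save_to_state_dict adds absmax, code, dtype, blocksize, quant_type, ... buffers
torch.save(linear_q.state_dict(), "linear_nf4.pth")

# load: rebuild the 4-bit weight from the stored buffers
sd = torch.load("linear_nf4.pth")
bias_data = sd.pop("bias", None)
weight_data = sd.pop("weight")
weight = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data)

linear_q2 = bnb.nn.Linear4bit(300, 400, bias=True, quant_type="nf4", device=device)
linear_q2.weight = weight.to(device)
linear_q2.bias.data = bias_data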

tests/test_linear4bit.py

Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
+import os
+from contextlib import nullcontext
+from itertools import product
+from tempfile import TemporaryDirectory
+
+import pytest
+import torch
+
+import bitsandbytes as bnb
+from bitsandbytes import functional as F
+from bitsandbytes.nn.modules import Linear4bit
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
+@pytest.mark.parametrize(
+    "quant_type, compress_statistics, bias",
+    list(product(["nf4", "fp4"], [False, True], [False, True])),
+)
+def test_linear4_state_dict(quant_type, compress_statistics, bias):
+    original_dtype = torch.float16
+    compute_dtype = None
+    device = "cuda"
+    layer_shape = (300, 400)
+
+    linear = torch.nn.Linear(*layer_shape, dtype=original_dtype)  # original layer
+
+    # Quantizing original layer
+    linear_q = bnb.nn.Linear4bit(
+        linear.in_features,
+        linear.out_features,
+        bias=bias,
+        compute_dtype=compute_dtype,
+        compress_statistics=compress_statistics,
+        quant_type=quant_type,
+        device=device,
+    )
+    new_weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False)
+    linear_q.weight = new_weight.to(device)
+    if bias:
+        linear_q.bias.data = linear.bias.data.to(device)
+
+    sd = linear_q.state_dict()
+
+    # restoring from state_dict:
+
+    sd = linear_q.state_dict()
+    bias_data2 = sd.pop("bias", None)
+    weight_data2 = sd.pop("weight")
+
+    weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2)
+
+    linear_q2 = bnb.nn.Linear4bit(
+        linear.in_features,
+        linear.out_features,
+        bias=bias,
+        compute_dtype=compute_dtype,
+        compress_statistics=compress_statistics,
+        quant_type=quant_type,
+        device=device,
+    )
+    linear_q2.weight = weight2.to(device)
+    if bias:
+        linear_q2.bias.data = bias_data2
+
+    # matching
+    a, b = linear_q.weight, linear_q2.weight
+
+    assert a.device == b.device
+    assert a.dtype == b.dtype
+    assert torch.equal(a, b)
+
+    q0 = a.quant_state
+    q1 = b.quant_state
+    for attr in ('code', 'dtype', 'blocksize', 'absmax'):
+        c, d = getattr(q0, attr), getattr(q1, attr)
+        if isinstance(c, torch.Tensor):
+            assert torch.equal(c, d)
+        else:
+            assert c == d, f"{c} != {d}"
+
+    if q0.state2 is not None:
+        for attr in ('code', 'dtype', 'blocksize', 'absmax'):
+            c, d = getattr(q0.state2, attr), getattr(q1.state2, attr)
+            if isinstance(c, torch.Tensor):
+                assert torch.equal(c, d)
+            else:
+                assert c == d, f"{c} != {d}"
+
+    if bias:
+        a, b = linear_q.bias, linear_q2.bias
+        assert a.device == b.device
+        assert a.dtype == b.dtype
+        assert torch.equal(a, b)
+
+    # Forward test
+    x = torch.rand(42, linear_q.shape[-1], device=device)
+    a = linear_q(x)
+    b = linear_q2(x)
+    assert a.device == b.device
+    assert a.dtype == b.dtype
+    assert torch.equal(a, b)
+
+    # Saved size ratio test. Target set for layer_shape == (300, 400) w/ bias
+    with TemporaryDirectory() as tmpdir:
+        state_path_4bit = os.path.join(tmpdir, "state_4bit.pth")
+        state_path = os.path.join(tmpdir, "state.pth")
+        torch.save(linear.state_dict(), state_path)
+        torch.save(linear_q.state_dict(), state_path_4bit)
+
+        size_orig, size_4 = os.path.getsize(state_path), os.path.getsize(
+            state_path_4bit
+        )
+        size_ratio = size_4 / size_orig
+        target_compression = 0.143 if original_dtype == torch.float32 else 0.285
+        ratio_error_msg = f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
+        assert size_ratio < target_compression, ratio_error_msg
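The 0.143 and 0.285 targets sit just above the raw bit-width ratios (4/32 and 4/16), leaving room for the extra quantization buffers. A rough weight-only estimate, assuming one fp32 absmax per 64-element block and a 256-entry code table; the saved files also contain the bias, small metadata buffers, and torch.save framing, and compressed statistics shrink the absmax term further:

n = 300 * 400                    # weight elements for layer_shape == (300, 400)
payload = n // 2                 # packed 4-bit weight: 60,000 bytes
absmax = (n // 64) * 4           # per-block fp32 scales: 7,500 bytes
code = 256 * 4                   # quantization lookup table: 1,024 bytes

print((payload + absmax + code) / (n * 2))   # ~0.286 of the fp16 weight bytes
print((payload + absmax + code) / (n * 4))   # ~0.143 of the fp32 weight bytes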
