Skip to content

Commit 726f147

Browse files
authored
Merge pull request #864 from poedator/save4_fixes
fixes to recent PR #753
2 parents f1ef74f + 5486053 commit 726f147

File tree

3 files changed

+47
-59
lines changed

3 files changed

+47
-59
lines changed

bitsandbytes/functional.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -567,14 +567,14 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
567567

568568
return out
569569

570+
570571
class QuantState:
571572
"""container for quantization state components to work with Params4bit and similar clases"""
572573
valid_quant_types = ('fp4', 'nf4')
573-
valid_qs_type_keys = [f"quant_state.bitsandbytes__{x}" for x in valid_quant_types]
574-
valid_qs_keys = ['absmax', 'quant_map', 'nested_absmax', 'nested_quant_map', 'quant_state',
575-
'quant_type', 'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset']
574+
valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types]
575+
valid_qs_keys = ['absmax', 'quant_map', 'nested_absmax', 'nested_quant_map', 'quant_state', 'quant_type',
576+
'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset']
576577

577-
578578
def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None):
579579
self.absmax = absmax
580580
self.shape = shape
@@ -585,7 +585,7 @@ def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=Non
585585
self.offset = offset
586586
self.state2 = state2
587587
self.nested = state2 is not None
588-
588+
589589
def __get_item__(self, idx):
590590
"""
591591
ensures compatibility with older quant state scheme with nested lists.
@@ -598,29 +598,32 @@ def __get_item__(self, idx):
598598
else:
599599
list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, None, self.quant_type]
600600
return list_repr[idx]
601-
601+
602602
@classmethod
603603
def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> 'QuantState':
604604
"""
605605
unpacks components of state_dict into QuantState
606606
where necessary, convert into strings, torch.dtype, ints, etc.
607607
608608
qs_dict: based on state_dict, with only relevant keys, striped of prefixes.
609-
609+
610610
item with key `quant_state.bitsandbytes__[nf4/fp4]` may contain minor and non-tensor quant state items.
611611
"""
612612

613613
# unpacking tensor with non-tensor components
614-
qs_key = [k for k, v in qs_dict.items() if k in cls.valid_qs_type_keys and isinstance(v, torch.Tensor)]
614+
qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)]
615615
if not len(qs_key) and 'quant_type' not in qs_dict:
616-
raise ValueError("Expected packed or unpacked quant_state items, found neither")
617-
elif len(qs_key) != 1:
618-
raise ValueError(f"There should be exaclly one quant_state item with key from {self.valid_qs_type_keys}. Detected {len(qs_ley)} such items")
619-
616+
raise ValueError("Expected packed or unpacked quant_state items, found neither")
617+
elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys:
618+
raise ValueError(f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.")
619+
620620
# unpacking minor and non-tensor quant state items if necessary
621621
if len(qs_key) == 1:
622622
qs_key = qs_key[0]
623-
qs_dict |= unpack_tensor_to_dict(qs_dict.pop(qs_key))
623+
qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(qs_key)))
624+
625+
qs_dict = {k.split('.')[-1]: v for k, v in qs_dict.items()} # strip prefixes
626+
assert set(qs_dict.keys()).issubset(cls.valid_qs_keys)
624627

625628
if 'nested_absmax' in qs_dict:
626629
offset = torch.tensor(float(qs_dict['nested_offset'])).to(device)
@@ -654,7 +657,7 @@ def as_dict(self, packed=False):
654657
'quant_type': self.quant_type,
655658
'absmax': self.absmax,
656659
'blocksize': self.blocksize,
657-
'quant_map': self.code,
660+
'quant_map': self.code,
658661
'dtype': str(self.dtype).strip('torch.'),
659662
'shape': tuple(self.shape) if self.nested else None,
660663
}
@@ -673,7 +676,7 @@ def as_dict(self, packed=False):
673676
non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)}
674677
qs_packed_dict["quant_state." + "bitsandbytes__" + self.quant_type] = pack_dict_to_tensor(non_tensor_dict)
675678
return qs_packed_dict
676-
679+
677680
def to(self, device):
678681
# make sure the quantization state is on the right device
679682
self.absmax = self.absmax.to(device)
@@ -682,6 +685,7 @@ def to(self, device):
682685
self.state2.absmax = self.state2.absmax.to(device)
683686
self.state2.code = self.state2.code.to(device)
684687

688+
685689
def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, out: Tensor = None, blocksize=4096, nested=False) -> Tensor:
686690
"""
687691
Quantize tensor A in blocks of size 4096 values.

bitsandbytes/nn/modules.py

Lines changed: 20 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5-
from typing import Optional, TypeVar, Union, overload
5+
from typing import Any, Dict, Optional, TypeVar, Union, overload
66

77
import warnings
88
import torch
@@ -139,9 +139,10 @@ def forward(self, input: Tensor) -> Tensor:
139139

140140
return emb
141141

142+
142143
class Params4bit(torch.nn.Parameter):
143144

144-
def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, compress_statistics=True, quant_type='fp4'):
145+
def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad=True, quant_state: QuantState = None, blocksize: int = 64, compress_statistics: bool = True, quant_type: str = 'fp4') -> "Params4bit":
145146
if data is None:
146147
data = torch.empty(0)
147148

@@ -152,27 +153,16 @@ def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64,
152153
self.quant_state = quant_state
153154
self.data = data
154155
return self
155-
156+
156157
@classmethod
157-
def from_state_dict(cls, state_dict, prefix="", requires_grad=False):
158-
data = state_dict.pop(prefix.rstrip('.'))
159-
160-
# extracting components for QuantState from state_dict
161-
qs_dict = {}
162-
for k, v in state_dict.items():
163-
if k.replace(prefix, '').split('.')[0] in QuantState.valid_qs_keys:
164-
qs_dict[k] = v
165-
state_dict = {k: v for k, v in state_dict.items() if k not in qs_dict}
166-
qs_dict = {k.replace(prefix, ''): v for k, v in qs_dict.items()}
167-
168-
if data.device.type != "cuda":
169-
raise ValueError(f"`data.device.type` must be 'cuda', detected {data.device.type}")
170-
171-
cls.requires_grad = requires_grad,
172-
cls.quant_state = QuantState.from_dict(qs_dict=qs_dict, device=data.device)
173-
174-
self = torch.Tensor._make_subclass(cls, data=data.to(data.device))
175-
return self, state_dict
158+
def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], requires_grad: bool = False, device='cuda', **kwargs) -> "Params4bit":
159+
self = torch.Tensor._make_subclass(cls, data.to(device))
160+
self.requires_grad = requires_grad
161+
self.quant_state = QuantState.from_dict(qs_dict=quantized_stats, device=device)
162+
self.blocksize = self.quant_state.blocksize
163+
self.compress_statistics = self.quant_state.nested
164+
self.quant_type = self.quant_state.quant_type
165+
return self
176166

177167
def cuda(self, device):
178168
w = self.data.contiguous().half().cuda(device)
@@ -204,15 +194,16 @@ def to(self, *args, **kwargs):
204194
self.quant_state.to(device)
205195

206196
new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking),
207-
requires_grad=self.requires_grad, quant_state=self.quant_state,
197+
requires_grad=self.requires_grad, quant_state=self.quant_state,
208198
blocksize=self.blocksize, compress_statistics=self.compress_statistics,
209199
quant_type=self.quant_type)
210200

211201
return new_param
212202

203+
213204
class Linear4bit(nn.Linear):
214-
215-
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4',device=None):
205+
206+
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', device=None):
216207
super().__init__(input_features, output_features, bias, device)
217208
self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
218209
# self.persistent_buffers = [] # TODO consider as way to save quant state
@@ -246,18 +237,6 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
246237
for k, v in self.weight.quant_state.as_dict(packed=True).items():
247238
destination[prefix + "weight." + k] = v if keep_vars else v.detach()
248239

249-
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
250-
missing_keys, unexpected_keys, error_msgs):
251-
# Note: super()._load_from_state_dict() is not called here intentionally.
252-
if self.bias is not None:
253-
bias_data = state_dict.pop(prefix + "bias", None)
254-
self.bias.data = bias_data.to(self.bias.data.device)
255-
256-
self.weight, state_dict = bnb.nn.Params4bit.from_state_dict(
257-
state_dict, prefix=prefix + "weight" + ".", requires_grad=False
258-
)
259-
unexpected_keys.extend(state_dict.keys())
260-
261240
def forward(self, x: torch.Tensor):
262241
# weights are cast automatically as Int8Params, but the bias has to be cast manually
263242
if self.bias is not None and self.bias.dtype != x.dtype:
@@ -280,10 +259,12 @@ def forward(self, x: torch.Tensor):
280259

281260
return out
282261

262+
283263
class LinearFP4(Linear4bit):
284-
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True,device=None):
264+
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
285265
super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', device)
286266

267+
287268
class LinearNF4(Linear4bit):
288269
''' Implements the NF4 data type.
289270
@@ -295,7 +276,7 @@ class LinearNF4(Linear4bit):
295276
Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in
296277
the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
297278
'''
298-
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True,device=None):
279+
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
299280
super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', device)
300281

301282

tests/test_linear4bit.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
import torch
88

99
import bitsandbytes as bnb
10-
from bitsandbytes import functional as F
11-
from bitsandbytes.nn.modules import Linear4bit
1210

1311

1412
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@@ -41,7 +39,10 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
4139

4240
# saving to state_dict:
4341
sd = linear_q.state_dict()
44-
42+
# restoring from state_dict:
43+
bias_data2 = sd.pop("bias", None)
44+
weight_data2 = sd.pop("weight")
45+
weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2)
4546
# creating new layer with same params:
4647
linear_q2 = bnb.nn.Linear4bit(
4748
linear.in_features,
@@ -53,15 +54,17 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
5354
device=device, # TODO create on meta device to save loading time
5455
)
5556
# loading weights from state_dict:
56-
linear_q2.load_state_dict(sd)
57+
linear_q2.weight = weight2.to(device)
58+
if bias:
59+
linear_q2.bias = torch.nn.Parameter(bias_data2)
5760

5861
# MATCHING
5962
a, b = linear_q.weight, linear_q2.weight
6063

6164
assert a.device == b.device
6265
assert a.dtype == b.dtype
6366
assert torch.equal(a, b)
64-
67+
6568
q0 = a.quant_state
6669
q1 = b.quant_state
6770
for attr in ('code', 'dtype', 'blocksize', 'absmax'):

0 commit comments

Comments
 (0)