
Commit 76b40a5

save/load via state_dict now
1 parent 965fd5d commit 76b40a5

File tree: 4 files changed, +60 −36 lines

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -133,3 +133,4 @@ dmypy.json
 
 dependencies
 cuda_build
+.vscode/*

bitsandbytes/functional.py

Lines changed: 19 additions & 14 deletions
@@ -568,11 +568,17 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
     return out
 
 class QuantState:
-    """container for quantizationstate components to work with Params4bit and similar clases"""
+    """container for quantization state components to work with Params4bit and similar classes"""
+    valid_quant_types = ('fp4', 'nf4')
+    valid_qs_type_keys = [f"quant_state.bitsandbytes__{x}" for x in valid_quant_types]
+    valid_qs_keys = ['absmax', 'code', 'nested_absmax', 'nested_code', 'quant_state',
+                     'quant_type', 'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset']
+
     def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None):
         self.absmax = absmax
         self.shape = shape
-        self.code = code
+        self.code = code  # TODO consider renaming to `buckets / centroids / scale`
         self.dtype = dtype
         self.blocksize = blocksize
         self.quant_type = quant_type
@@ -596,26 +602,26 @@ def __get_item__(self, idx):
     @classmethod
     def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> 'QuantState':
         """
-        unpacks dict of tensors into QuantState
+        unpacks components of state_dict into QuantState
         where necessary, convert into strings, torch.dtype, ints, etc.
 
-        quant_state_dict may contain item with non-tensor components with key like
-        `...weight.quant_state.bitsandbytes__[nf4/fp4]`
-        it is detected with key strored in qs_key, and then unpacked
+        qs_dict: based on state_dict, with only relevant keys, stripped of prefixes.
+
+        item with key `quant_state.bitsandbytes__[nf4/fp4]` may contain minor and non-tensor quant state items.
         """
 
         # unpacking tensor with non-tensor components
-        qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)]
-        assert len(qs_key) == 1 or not qs_key and 'quant_type' in qs_dict, \
-            f"`qs_dict` must contain packed quant_state items, or be unpacked. Found keys: {tuple(qs_dict.keys())}"
+        qs_key = [k for k, v in qs_dict.items() if k in cls.valid_qs_type_keys and isinstance(v, torch.Tensor)]
+        if not qs_key and 'quant_type' not in qs_dict:
+            raise ValueError("Expected packed or unpacked quant_state items, found neither")
+        elif len(qs_key) > 1:
+            raise ValueError(f"There should be exactly one quant_state item with key from {cls.valid_qs_type_keys}. Detected {len(qs_key)} such items")
+
+        # unpacking minor and non-tensor quant state items if necessary
         if len(qs_key) == 1:
             qs_key = qs_key[0]
-            assert 'bitsandbytes__nf4' in qs_key or 'bitsandbytes__fp4' in qs_key, \
-                f"invalid qs_key value {qs_key}"
             qs_dict |= unpack_tensor_to_dict(qs_dict.pop(qs_key))
 
-        qs_dict = {k.split('.')[-1]:v for k, v in qs_dict.items()}  # strip prefixes
-
         if 'nested_absmax' in qs_dict:
             offset = torch.tensor(float(qs_dict['nested_offset'])).to(device)
             state2 = cls(
@@ -873,7 +879,6 @@ def get_4bit_type(typename, device=None, blocksize=64):
     return data.to(device)
 
 
-
 def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False):
     return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4')
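For orientation, `from_dict` reverses the `as_dict(packed=True)` layout that the save path in modules.py relies on. A minimal round-trip sketch of that pairing, assuming a CUDA device; the 64x64 fp16 tensor and nf4 choice are illustrative, not a documented public API:

    # Quantize, pack the quant state, rebuild it, and dequantize again.
    import torch
    import bitsandbytes.functional as F

    A = torch.randn(64, 64, dtype=torch.float16, device='cuda')
    q, state = F.quantize_4bit(A, blocksize=64, quant_type='nf4')

    qs_dict = state.as_dict(packed=True)  # tensor components plus one packed metadata tensor
    state2 = F.QuantState.from_dict(qs_dict=qs_dict, device=torch.device('cuda'))

    A2 = F.dequantize_4bit(q, state2)
    print((A - A2).abs().mean())  # small but nonzero: 4-bit quantization is lossy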

bitsandbytes/nn/modules.py

Lines changed: 33 additions & 8 deletions
@@ -154,14 +154,25 @@ def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64,
         return self
 
     @classmethod
-    def from_prequantized(cls, data, quantized_stats, requires_grad=False, device='cuda', **kwargs):
-        self = torch.Tensor._make_subclass(cls, data.to(device))
-        self.requires_grad = requires_grad
-        self.quant_state = QuantState.from_dict(qs_dict=quantized_stats, device=device)
-        self.blocksize = self.quant_state.blocksize
-        self.compress_statistics = self.quant_state.nested
-        self.quant_type = self.quant_state.quant_type
-        return self
+    def from_state_dict(cls, state_dict, prefix="", requires_grad=False):
+        data = state_dict.pop(prefix.rstrip('.'))
+
+        # extracting components for QuantState from state_dict
+        qs_dict = {}
+        for k, v in state_dict.items():
+            if k.replace(prefix, '').split('.')[0] in QuantState.valid_qs_keys:
+                qs_dict[k] = v
+        state_dict = {k: v for k, v in state_dict.items() if k not in qs_dict}
+        qs_dict = {k.replace(prefix, ''): v for k, v in qs_dict.items()}
+
+        if data.device.type != "cuda":
+            raise ValueError(f"`data.device.type` must be 'cuda', detected {data.device.type}")
+
+        self = torch.Tensor._make_subclass(cls, data)
+        self.requires_grad = requires_grad
+        self.quant_state = QuantState.from_dict(qs_dict=qs_dict, device=data.device)
+        return self, state_dict
 
     def cuda(self, device):
         w = self.data.contiguous().half().cuda(device)
@@ -200,9 +211,11 @@ def to(self, *args, **kwargs):
         return new_param
 
 class Linear4bit(nn.Linear):
+
     def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', device=None):
         super().__init__(input_features, output_features, bias, device)
         self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
+        # self.persistent_buffers = []  # TODO consider as way to save quant state
         self.compute_dtype = compute_dtype
         self.compute_type_is_set = False

@@ -233,6 +246,18 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
         for k, v in self.weight.quant_state.as_dict(packed=True).items():
             destination[prefix + "weight." + k] = v if keep_vars else v.detach()
 
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        # Note: super()._load_from_state_dict() is not called here intentionally.
+        if self.bias is not None:
+            bias_data = state_dict.pop(prefix + "bias", None)
+            if bias_data is not None:
+                self.bias.data = bias_data.to(self.bias.data.device)
+
+        self.weight, state_dict = bnb.nn.Params4bit.from_state_dict(
+            state_dict, prefix=prefix + "weight" + ".", requires_grad=False
+        )
+        unexpected_keys.extend(state_dict.keys())
+
     def forward(self, x: torch.Tensor):
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
         if self.bias is not None and self.bias.dtype != x.dtype:
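Together, `_save_to_state_dict` and `_load_from_state_dict` let the quantization state travel inside the regular state_dict. A hedged sketch of the resulting key layout, assuming a CUDA device; the exact component set depends on quant_type and compress_statistics, and the 64/32 feature sizes are illustrative:

    # Inspect the packed layout the new save path emits for one layer.
    import torch
    import bitsandbytes as bnb

    layer = bnb.nn.Linear4bit(64, 32, quant_type='nf4').cuda()  # .cuda() quantizes the weight
    sd = layer.state_dict()
    print(sorted(sd.keys()))
    # expect 'bias', 'weight', and 'weight.<component>' entries drawn from
    # QuantState.valid_qs_keys, e.g. 'weight.absmax' and the packed
    # 'weight.quant_state.bitsandbytes__nf4'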

tests/test_linear4bit.py

Lines changed: 7 additions & 14 deletions
@@ -16,7 +16,7 @@
     "quant_type, compress_statistics, bias",
     list(product(["nf4", "fp4"], [False, True], [False, True])),
 )
-def test_linear4_state_dict(quant_type, compress_statistics, bias):
+def test_linear_serialization(quant_type, compress_statistics, bias):
     original_dtype = torch.float16
     compute_dtype = None
     device = "cuda"
@@ -39,30 +39,23 @@ def test_linear4_state_dict(quant_type, compress_statistics, bias):
     if bias:
         linear_q.bias.data = linear.bias.data.to(device)
 
+    # saving to state_dict:
     sd = linear_q.state_dict()
 
-    # restoring from state_dict:
-
-    sd = linear_q.state_dict()
-    bias_data2 = sd.pop("bias", None)
-    weight_data2 = sd.pop("weight")
-
-    weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2)
-
+    # creating new layer with same params:
     linear_q2 = bnb.nn.Linear4bit(
         linear.in_features,
         linear.out_features,
         bias=bias,
         compute_dtype=compute_dtype,
         compress_statistics=compress_statistics,
         quant_type=quant_type,
-        device='meta',
+        device=device,  # TODO create on meta device to save loading time
     )
-    linear_q2.weight = weight2.to(device)
-    if bias:
-        linear_q2.bias = torch.nn.Parameter(bias_data2)
+    # loading weights from state_dict:
+    linear_q2.load_state_dict(sd)
 
     # matching
     a, b = linear_q.weight, linear_q2.weight
 
     assert a.device == b.device
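Stripped of the parametrization, the test reduces to the following round trip. A sketch under the same assumptions as the test (CUDA available; the feature sizes and file name are illustrative):

    # Save a quantized layer and restore it into a fresh one.
    import torch
    import bitsandbytes as bnb

    linear_q = bnb.nn.Linear4bit(64, 32, bias=True, quant_type='nf4').cuda()
    torch.save(linear_q.state_dict(), 'linear4bit.pt')

    linear_q2 = bnb.nn.Linear4bit(64, 32, bias=True, quant_type='nf4', device='cuda')
    linear_q2.load_state_dict(torch.load('linear4bit.pt'))  # routed through _load_from_state_dict

    x = torch.randn(3, 64, dtype=torch.float16, device='cuda')
    torch.testing.assert_close(linear_q(x), linear_q2(x))  # identical quantized weights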
