
Commit c8f564d

Merge pull request #868 from poedator/fix_1108
Fix for 4bit without compress_statistics.
2 parents bbbed83 + 079d7af commit c8f564d

2 files changed: +16 -11 lines

bitsandbytes/functional.py

Lines changed: 5 additions & 4 deletions
@@ -642,7 +642,7 @@ def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> 'QuantState
             blocksize=qs_dict['blocksize'],
             code=qs_dict['quant_map'].to(device),
             dtype=getattr(torch, qs_dict['dtype']),
-            shape=torch.Size(qs_dict['shape']),
+            shape=torch.Size(qs_dict['shape']) if qs_dict['shape'] is not None else None,
             offset=offset,
             state2=state2,
         )
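
The guard above matters because torch.Size(None) raises a TypeError: before this change, a quant state saved with compress_statistics=False stored shape=None (see the as_dict hunk below) and could not be restored. A minimal sketch of the behaviour, assuming only PyTorch:

# Minimal sketch of the failure mode the new guard avoids (assumes only PyTorch):
import torch

saved_shape = None  # what the old as_dict() emitted when compress_statistics=False
# torch.Size(None) raises "TypeError: 'NoneType' object is not iterable",
# so the value is only wrapped when it is actually present:
shape = torch.Size(saved_shape) if saved_shape is not None else None
print(shape)  # -> None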
@@ -651,27 +651,28 @@ def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> 'QuantState
     def as_dict(self, packed=False):
         """
         returns dict of tensors and strings to use in serialization via _save_to_state_dict()
-        param: packed -- returns dict[str, torch.Tensor] for state_dict
+        param: packed -- returns dict[str, torch.Tensor] for state_dict fit for safetensors saving
         """
         qs_dict = {
             'quant_type': self.quant_type,
             'absmax': self.absmax,
             'blocksize': self.blocksize,
             'quant_map': self.code,
             'dtype': str(self.dtype).strip('torch.'),
-            'shape': tuple(self.shape) if self.nested else None,
+            'shape': tuple(self.shape),
         }
         if self.nested:
             qs_dict.update({
                 'nested_absmax': self.state2.absmax,
                 'nested_blocksize': self.state2.blocksize,
-                'nested_quant_map': self.state2.code,
+                'nested_quant_map': self.state2.code.clone(),  # un-shared to avoid restoring it after shared tensors are removed by safetensors
                 'nested_dtype': str(self.state2.dtype).strip('torch.'),
                 'nested_offset': self.offset.item(),
             })
         if not packed:
             return qs_dict

+        # packed format allows serialization of non-tensor components, critical for saving in safetensors format
         qs_packed_dict = {k: v for k, v in qs_dict.items() if isinstance(v, torch.Tensor)}
         non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)}
         qs_packed_dict["quant_state." + "bitsandbytes__" + self.quant_type] = pack_dict_to_tensor(non_tensor_dict)

tests/test_linear4bit.py

Lines changed: 11 additions & 7 deletions
@@ -20,7 +20,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
     device = "cuda"
     layer_shape = (300, 400)

-    linear = torch.nn.Linear(*layer_shape, dtype=original_dtype)  # original layer
+    linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu")  # original layer

     # Quantizing original layer
     linear_q = bnb.nn.Linear4bit(
@@ -30,19 +30,22 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
         compute_dtype=compute_dtype,
         compress_statistics=compress_statistics,
         quant_type=quant_type,
-        device=device,
+        device="meta",
     )
     new_weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False)
-    linear_q.weight = new_weight.to(device)
+    linear_q.weight = new_weight
     if bias:
-        linear_q.bias.data = linear.bias.data.to(device)
+        linear_q.bias = torch.nn.Parameter(linear.bias)
+    linear_q = linear_q.to(device)

     # saving to state_dict:
     sd = linear_q.state_dict()
+
     # restoring from state_dict:
     bias_data2 = sd.pop("bias", None)
     weight_data2 = sd.pop("weight")
     weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2)
+
     # creating new layer with same params:
     linear_q2 = bnb.nn.Linear4bit(
@@ -51,12 +54,13 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
         compute_dtype=compute_dtype,
         compress_statistics=compress_statistics,
         quant_type=quant_type,
-        device=device,  # TODO create on meta device to save loading time
+        device="meta",
     )
     # loading weights from state_dict:
-    linear_q2.weight = weight2.to(device)
+    linear_q2.weight = weight2
     if bias:
         linear_q2.bias = torch.nn.Parameter(bias_data2)
+    linear_q2 = linear_q2.to(device)

     # MATCHING
     a, b = linear_q.weight, linear_q2.weight
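
Pulling the fragments above together, the save/load round trip the test now exercises looks like this. A condensed sketch, not the test verbatim: layer sizes, bias=False and quant_type are illustrative, and a CUDA device with bitsandbytes installed is assumed.

import torch
import bitsandbytes as bnb

# quantize a CPU fp16 layer by assigning its weight and moving the module to the GPU
linear = torch.nn.Linear(300, 400, bias=False, dtype=torch.float16, device="cpu")
linear_q = bnb.nn.Linear4bit(300, 400, bias=False, quant_type="nf4", device="meta")
linear_q.weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False)
linear_q = linear_q.to("cuda")  # quantization happens on the move to the CUDA device

sd = linear_q.state_dict()  # contains the 4-bit weight plus the packed quant_state entries

# rebuild a fresh layer from the state dict without touching the original
weight_data2 = sd.pop("weight")
weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2)
linear_q2 = bnb.nn.Linear4bit(300, 400, bias=False, quant_type="nf4", device="meta")
linear_q2.weight = weight2
linear_q2 = linear_q2.to("cuda")

Creating the layers on the "meta" device avoids allocating throwaway weights that would immediately be replaced, which is what the TODO removed above was asking for.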
@@ -107,6 +111,6 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
         state_path_4bit
     )
     size_ratio = size_4 / size_orig
-    target_compression = 0.143 if original_dtype == torch.float32 else 0.285
+    target_compression = 0.143 if original_dtype == torch.float32 else 0.29  # these numbers get lower as weight shape increases
     ratio_error_msg = f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
     assert size_ratio < target_compression, ratio_error_msg
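
The thresholds are roughly what block-wise 4-bit storage predicts. A back-of-envelope sketch, assuming blocksize 64 and one float32 absmax per block (compress_statistics=False); the real files add a small constant overhead (quantization map, packed metadata, file headers), which is why the targets sit slightly above these numbers and, as the comment says, shrink in relative terms as the weight matrix grows:

# bits stored per original weight: 4 quantized bits plus the per-block absmax share
bits_per_weight = 4
absmax_bits_per_weight = 32 / 64  # one float32 absmax per 64-element block

print((bits_per_weight + absmax_bits_per_weight) / 32)  # 0.140625 vs the 0.143 target (fp32 source)
print((bits_per_weight + absmax_bits_per_weight) / 16)  # 0.28125  vs the 0.29 target (fp16 source)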
