
Commit ffd46ce

fixes for init and tests
Parent: 726f147

2 files changed (+12 -8 lines)

bitsandbytes/functional.py

Lines changed: 1 addition & 1 deletion
@@ -659,7 +659,7 @@ def as_dict(self, packed=False):
             'blocksize': self.blocksize,
             'quant_map': self.code,
             'dtype': str(self.dtype).strip('torch.'),
-            'shape': tuple(self.shape) if self.nested else None,
+            'shape': tuple(self.shape),
         }
         if self.nested:
             qs_dict.update({
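
The one-line change above makes the serialized quant state carry the weight shape unconditionally, rather than only when nested (double) quantization is enabled; presumably this is what Params4bit.from_prequantized in the test below relies on to rebuild the weight tensor from the packed 4-bit data. A minimal sketch of the behavioural difference, using a hypothetical stand-in object with the attributes the diff references:

import torch

class FakeQuantState:
    # Hypothetical stand-in; the real class lives in bitsandbytes/functional.py.
    blocksize = 64
    code = torch.zeros(16)           # quantization map
    dtype = torch.float16
    shape = torch.Size([300, 400])
    nested = False                   # no double quantization

qs = FakeQuantState()

old_entry = tuple(qs.shape) if qs.nested else None   # old code: None
new_entry = tuple(qs.shape)                          # new code: (300, 400)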

tests/test_linear4bit.py

Lines changed: 11 additions & 7 deletions
@@ -20,7 +20,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
     device = "cuda"
     layer_shape = (300, 400)
 
-    linear = torch.nn.Linear(*layer_shape, dtype=original_dtype)  # original layer
+    linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu")  # original layer
 
     # Quantizing original layer
     linear_q = bnb.nn.Linear4bit(
@@ -30,19 +30,22 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
         compute_dtype=compute_dtype,
         compress_statistics=compress_statistics,
         quant_type=quant_type,
-        device=device,
+        device="meta",  # TODO: consider both CPU, meta and CUDA creation
     )
     new_weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False)
-    linear_q.weight = new_weight.to(device)
+    linear_q.weight = new_weight
     if bias:
-        linear_q.bias.data = linear.bias.data.to(device)
+        linear_q.bias = torch.nn.Parameter(linear.bias)
+    linear_q = linear_q.to(device)
 
     # saving to state_dict:
     sd = linear_q.state_dict()
+
     # restoring from state_dict:
     bias_data2 = sd.pop("bias", None)
     weight_data2 = sd.pop("weight")
     weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2)
+
     # creating new layer with same params:
     linear_q2 = bnb.nn.Linear4bit(
         linear.in_features,
@@ -51,12 +54,13 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
         compute_dtype=compute_dtype,
         compress_statistics=compress_statistics,
         quant_type=quant_type,
-        device=device,  # TODO create on meta device to save loading time
+        device="meta",
     )
     # loading weights from state_dict:
-    linear_q2.weight = weight2.to(device)
+    linear_q2.weight = weight2
     if bias:
         linear_q2.bias = torch.nn.Parameter(bias_data2)
+    linear_q2 = linear_q2.to(device)
 
     # MATCHING
     a, b = linear_q.weight, linear_q2.weight
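
The load path the updated test now exercises — construct Linear4bit on the meta device, attach prequantized parameters, then move the whole module to the target device in one step — can be summarized as a standalone sketch. This is a rough illustration built on the same calls used in the test; the helper name and the bias/constructor details are assumptions:

import torch
import bitsandbytes as bnb

def load_linear4bit_from_state_dict(sd, in_features, out_features, device="cuda"):
    # Illustrative helper mirroring the test's restore path (hypothetical name).
    bias_data = sd.pop("bias", None)
    weight_data = sd.pop("weight")
    weight = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data)

    # Build on "meta" so no real weight memory is allocated before loading.
    # (quant_type / compute_dtype kwargs omitted; they should match the saved layer.)
    layer = bnb.nn.Linear4bit(in_features, out_features, bias=bias_data is not None, device="meta")
    layer.weight = weight
    if bias_data is not None:
        layer.bias = torch.nn.Parameter(bias_data)
    return layer.to(device)   # materialize everything on the target device

Creating the module on the meta device avoids allocating, and then immediately discarding, a full-precision weight that the assigned Params4bit replaces — which is what the removed "TODO create on meta device to save loading time" comment was pointing at.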
@@ -107,6 +111,6 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
         state_path_4bit
     )
     size_ratio = size_4 / size_orig
-    target_compression = 0.143 if original_dtype == torch.float32 else 0.285
+    target_compression = 0.143 if original_dtype == torch.float32 else 0.29  # these numbers get lower as weight shape increases
     ratio_error_msg = f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
     assert size_ratio < target_compression, ratio_error_msg
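
The relaxed fp16 target (0.285 → 0.29) lines up with a back-of-the-envelope storage model: packed 4-bit weights, per-block statistics, and a small fixed chunk of metadata that amortizes as the weight grows — which is what the new inline comment alludes to. The block size, absmax width, and overhead figure below are assumptions for illustration, not values taken from the diff:

def approx_ratio(orig_bits, n_elems, blocksize=64, absmax_bits=32, fixed_overhead_bytes=1000):
    # Rough on-disk size model: 4-bit packed weights + one absmax per block
    # + a fixed chunk for the quant map and serialization metadata.
    packed_bytes = n_elems * 4 / 8
    stats_bytes = n_elems / blocksize * absmax_bits / 8
    original_bytes = n_elems * orig_bits / 8
    return (packed_bytes + stats_bytes + fixed_overhead_bytes) / original_bytes

n = 300 * 400                 # the test's layer_shape
print(approx_ratio(32, n))    # ~0.14, close to the fp32 target of 0.143
print(approx_ratio(16, n))    # ~0.29, close to the fp16 target of 0.29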
