@@ -20,7 +20,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
     device = "cuda"
     layer_shape = (300, 400)

-    linear = torch.nn.Linear(*layer_shape, dtype=original_dtype)  # original layer
+    linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu")  # original layer

     # Quantizing original layer
     linear_q = bnb.nn.Linear4bit(
@@ -30,19 +30,22 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
         compute_dtype=compute_dtype,
         compress_statistics=compress_statistics,
         quant_type=quant_type,
-        device=device,
+        device="meta",  # TODO: consider both CPU, meta and CUDA creation
     )
     new_weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False)
-    linear_q.weight = new_weight.to(device)
+    linear_q.weight = new_weight
     if bias:
-        linear_q.bias.data = linear.bias.data.to(device)
+        linear_q.bias = torch.nn.Parameter(linear.bias)
+    linear_q = linear_q.to(device)

     # saving to state_dict:
     sd = linear_q.state_dict()
+
     # restoring from state_dict:
     bias_data2 = sd.pop("bias", None)
     weight_data2 = sd.pop("weight")
     weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2)
+
     # creating new layer with same params:
     linear_q2 = bnb.nn.Linear4bit(
         linear.in_features,
@@ -51,12 +54,13 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
         compute_dtype=compute_dtype,
         compress_statistics=compress_statistics,
         quant_type=quant_type,
-        device=device,  # TODO create on meta device to save loading time
+        device="meta",
     )
     # loading weights from state_dict:
-    linear_q2.weight = weight2.to(device)
+    linear_q2.weight = weight2
     if bias:
         linear_q2.bias = torch.nn.Parameter(bias_data2)
+    linear_q2 = linear_q2.to(device)

     # MATCHING
     a, b = linear_q.weight, linear_q2.weight
@@ -107,6 +111,6 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
         state_path_4bit
     )
     size_ratio = size_4 / size_orig
-    target_compression = 0.143 if original_dtype == torch.float32 else 0.285
+    target_compression = 0.143 if original_dtype == torch.float32 else 0.29  # these numbers get lower as weight shape increases
     ratio_error_msg = f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
     assert size_ratio < target_compression, ratio_error_msg
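
For reference, a minimal sketch of the save/load round-trip this test exercises, assuming bitsandbytes with a CUDA device available; the dtypes, shapes, and `quant_type` are illustrative stand-ins for the test's parametrized values:

```python
import torch
import bitsandbytes as bnb

# Quantize an fp32 layer: params start on the meta device and are
# materialized/quantized when the module is moved to CUDA.
linear = torch.nn.Linear(300, 400, dtype=torch.float32, device="cpu")
linear_q = bnb.nn.Linear4bit(
    linear.in_features,
    linear.out_features,
    bias=True,
    compute_dtype=torch.bfloat16,  # illustrative choice
    compress_statistics=True,
    quant_type="nf4",  # illustrative choice
    device="meta",
)
linear_q.weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False)
linear_q.bias = torch.nn.Parameter(linear.bias)
linear_q = linear_q.to("cuda")

# Round-trip: the quantization statistics ride along in the state_dict,
# so a fresh layer can be rebuilt without the original fp32 weights.
sd = linear_q.state_dict()
bias_data = sd.pop("bias", None)
weight_data = sd.pop("weight")
restored_weight = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data)

# Rebuild a fresh layer on the meta device and load the restored params.
linear_q2 = bnb.nn.Linear4bit(
    linear.in_features,
    linear.out_features,
    bias=True,
    compute_dtype=torch.bfloat16,
    compress_statistics=True,
    quant_type="nf4",
    device="meta",
)
linear_q2.weight = restored_weight
if bias_data is not None:
    linear_q2.bias = torch.nn.Parameter(bias_data)
linear_q2 = linear_q2.to("cuda")
```

Creating both layers on `device="meta"` avoids allocating throwaway tensors that the weight assignment would immediately overwrite, which is what the in-diff TODO about saving loading time refers to.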