+import copy
 import os
+import pickle
 from tempfile import TemporaryDirectory

 import pytest
 import torch

 import bitsandbytes as bnb
 from tests.helpers import TRUE_FALSE

 storage = {
-    'uint8': torch.uint8,
-    'float16': torch.float16,
-    'bfloat16': torch.bfloat16,
-    'float32': torch.float32
+    "uint8": torch.uint8,
+    "float16": torch.float16,
+    "bfloat16": torch.bfloat16,
+    "float32": torch.float32,
 }

-@pytest.mark.parametrize("quant_storage", ['uint8', 'float16', 'bfloat16', 'float32'])
+
+@pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"])
 @pytest.mark.parametrize("bias", TRUE_FALSE)
 @pytest.mark.parametrize("compress_statistics", TRUE_FALSE)
 @pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@@ -24,7 +27,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage):
     device = "cuda"
     layer_shape = (300, 400)

-    linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu")  # original layer
+    linear = torch.nn.Linear(
+        *layer_shape, dtype=original_dtype, device="cpu"
+    )  # original layer

     # Quantizing original layer
     linear_q = bnb.nn.Linear4bit(
@@ -36,7 +41,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage):
         quant_type=quant_type,
         device="meta",
     )
-    new_weight = bnb.nn.Params4bit(data=linear.weight, quant_type=quant_type, requires_grad=False)
+    new_weight = bnb.nn.Params4bit(
+        data=linear.weight, quant_type=quant_type, requires_grad=False
+    )
     linear_q.weight = new_weight
     if bias:
         linear_q.bias = torch.nn.Parameter(linear.bias)
@@ -80,7 +87,12 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage):
         quant_storage=storage[quant_storage],
         device="meta",
     )
-    linear_qs.weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False, quant_type=quant_type, quant_storage=storage[quant_storage])
+    linear_qs.weight = bnb.nn.Params4bit(
+        data=linear.weight,
+        requires_grad=False,
+        quant_type=quant_type,
+        quant_storage=storage[quant_storage],
+    )
     if bias:
         linear_qs.bias = torch.nn.Parameter(linear.bias)
     linear_qs = linear_qs.to(device)
@@ -91,15 +103,15 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage):

     q0 = a.quant_state
     q1 = b.quant_state
-    for attr in ('code', 'dtype', 'blocksize', 'absmax'):
+    for attr in ("code", "dtype", "blocksize", "absmax"):
         c, d = getattr(q0, attr), getattr(q1, attr)
         if isinstance(c, torch.Tensor):
             assert torch.equal(c, d)
         else:
             assert c == d, f"{c} != {d}"

     if q0.state2 is not None:
-        for attr in ('code', 'dtype', 'blocksize', 'absmax'):
+        for attr in ("code", "dtype", "blocksize", "absmax"):
             c, d = getattr(q0.state2, attr), getattr(q1.state2, attr)
             if isinstance(c, torch.Tensor):
                 assert torch.equal(c, d)
@@ -125,7 +137,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage):
     assert torch.equal(a, c)

     # Test moving to CPU and back to GPU
-    linear_q2.to('cpu')
+    linear_q2.to("cpu")
     linear_q2.to(device)
     d = linear_qs(x)
     assert c.dtype == d.dtype
@@ -139,10 +151,47 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage):
         torch.save(linear.state_dict(), state_path)
         torch.save(linear_q.state_dict(), state_path_4bit)

-        size_orig, size_4 = os.path.getsize(state_path), os.path.getsize(
-            state_path_4bit
+        size_orig, size_4 = (
+            os.path.getsize(state_path),
+            os.path.getsize(state_path_4bit),
         )
         size_ratio = size_4 / size_orig
-        target_compression = 0.143 if original_dtype == torch.float32 else 0.29  # these numbers get lower as weight shape increases
+        target_compression = (
+            0.143 if original_dtype == torch.float32 else 0.29
+        )  # these numbers get lower as weight shape increases
         ratio_error_msg = f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
         assert size_ratio < target_compression, ratio_error_msg
+
+
+def test_copy_param():
+    tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
+    param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
+
+    shallow_copy_param = copy.copy(param)
+    assert param.quant_state is shallow_copy_param.quant_state
+    assert param.data.data_ptr() == shallow_copy_param.data.data_ptr()
+
+
+def test_deepcopy_param():
+    tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
+    param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
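+    # copy.deepcopy should clone both the quantized storage and the
+    # quant_state, so neither is shared with the original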
+    copy_param = copy.deepcopy(param)
+    assert param.quant_state is not copy_param.quant_state
+    assert param.data.data_ptr() != copy_param.data.data_ptr()
+
+
+def test_params4bit_real_serialization():
+    original_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32)
+    original_param = bnb.nn.Params4bit(data=original_tensor, quant_type="fp4")
+
+    original_param.cuda(0)  # move to CUDA to trigger quantization
+
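+    # round-trip the quantized parameter through pickle and verify that the
+    # data and all quantization attributes survive unchanged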
+    serialized_param = pickle.dumps(original_param)
+    deserialized_param = pickle.loads(serialized_param)
+
+    assert torch.equal(original_param.data, deserialized_param.data)
+    assert original_param.requires_grad == deserialized_param.requires_grad == False
+    assert original_param.quant_type == deserialized_param.quant_type
+    assert original_param.blocksize == deserialized_param.blocksize
+    assert original_param.compress_statistics == deserialized_param.compress_statistics
+    assert original_param.quant_state == deserialized_param.quant_state