 import torch
 
 import bitsandbytes as bnb
-from tests.helpers import TRUE_FALSE, torch_load_from_buffer, torch_save_to_buffer
+from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, torch_load_from_buffer, torch_save_to_buffer
 
 storage = {
     "uint8": torch.uint8,
     "float16": torch.float16,
     "bfloat16": torch.bfloat16,
     "float32": torch.float32,
 }
 
 
+@pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"])
-@pytest.mark.parametrize("bias", TRUE_FALSE)
-@pytest.mark.parametrize("compress_statistics", TRUE_FALSE)
+@pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias"))
+@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
 @pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
-@pytest.mark.parametrize("save_before_forward", TRUE_FALSE)
-def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage, save_before_forward):
+@pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
+def test_linear_serialization(device, quant_type, compress_statistics, bias, quant_storage, save_before_forward):
+    if device == "cpu":
+        pytest.xfail("Dequantization is not yet implemented for CPU")
+
     original_dtype = torch.float16
     compute_dtype = None
-    device = "cuda"
     layer_shape = (300, 400)
 
     linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu")  # original layer
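For orientation, the tests.helpers utilities imported above are assumed to behave roughly like this sketch (the real definitions live in tests/helpers.py):

    import torch

    TRUE_FALSE = (True, False)

    def get_available_devices():
        # "cpu" is always offered; accelerator devices are appended when visible.
        devices = ["cpu"]
        if torch.cuda.is_available():
            devices.append("cuda")
        return devices

    def id_formatter(label):
        # Produces readable pytest ids such as "bias=True" / "bias=False".
        return lambda value: f"{label}={value}"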
@@ -52,7 +55,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
     # restoring from state_dict:
     bias_data2 = sd.pop("bias", None)
     weight_data2 = sd.pop("weight")
-    weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2)
+    weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2, device=device)
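+    # device= restores the prequantized weights directly onto the parametrized test device.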
 
     # creating new layer with same params:
     linear_q2 = bnb.nn.Linear4bit(
@@ -174,18 +177,50 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
     assert size_ratio < target_compression, ratio_error_msg
 
 
-def test_copy_param():
-    tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
-    param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
+@pytest.mark.parametrize("blocksize", [64, 128])
+@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
+def test_copy_param(device, quant_type, blocksize, compress_statistics):
+    if device == "cpu":
+        if compress_statistics:
+            pytest.skip("Currently segfaults on CPU")
+        if quant_type == "fp4":
+            pytest.xfail("FP4 not supported on CPU")
+
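+    # linspace fills exactly one quantization block with distinct values, exercising each blocksize end to end.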
+    tensor = torch.linspace(1, blocksize, blocksize)
+    param = bnb.nn.Params4bit(
+        data=tensor,
+        quant_type=quant_type,
+        blocksize=blocksize,
+        compress_statistics=compress_statistics,
+        requires_grad=False,
+    ).to(device)
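+    # The .to(device) call above is what triggers the actual 4-bit quantization of the data.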
 
     shallow_copy_param = copy.copy(param)
     assert param.quant_state is shallow_copy_param.quant_state
     assert param.data.data_ptr() == shallow_copy_param.data.data_ptr()
 
 
-def test_deepcopy_param():
-    tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
-    param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
+@pytest.mark.parametrize("blocksize", [64, 128])
+@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
+def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):
+    if device == "cpu":
+        if compress_statistics:
+            pytest.skip("Currently segfaults on CPU")
+        if quant_type == "fp4":
+            pytest.xfail("FP4 not supported on CPU")
+
+    tensor = torch.linspace(1, blocksize, blocksize)
+    param = bnb.nn.Params4bit(
+        data=tensor,
+        quant_type=quant_type,
+        blocksize=blocksize,
+        compress_statistics=compress_statistics,
+        requires_grad=False,
+    ).to(device)
     dict_keys_before = set(param.__dict__.keys())
     copy_param = copy.deepcopy(param)
     dict_keys_after = set(param.__dict__.keys())
@@ -199,12 +234,27 @@ def test_deepcopy_param():
     assert dict_keys_before == dict_keys_copy
 
 
-def test_params4bit_real_serialization():
-    original_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32)
-    original_param = bnb.nn.Params4bit(data=original_tensor, quant_type="fp4")
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
+@pytest.mark.parametrize("blocksize", [64, 128])
+@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
+def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics):
+    if device == "cpu":
+        if compress_statistics:
+            pytest.skip("Currently segfaults on CPU")
+        if quant_type == "fp4":
+            pytest.xfail("FP4 not supported on CPU")
+
+    original_tensor = torch.linspace(1, blocksize, blocksize, dtype=torch.float32)
+    original_param = bnb.nn.Params4bit(
+        data=original_tensor,
+        quant_type=quant_type,
+        blocksize=blocksize,
+        compress_statistics=compress_statistics,
+    )
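+    # Constructed on CPU, the data stays unquantized until the device move below.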
     dict_keys_before = set(original_param.__dict__.keys())
 
-    original_param.cuda(0)  # move to CUDA to trigger quantization
+    original_param.to(device)  # change device to trigger quantization
 
     serialized_param = pickle.dumps(original_param)
     deserialized_param = pickle.loads(serialized_param)
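Taken together, these tests boil down to a quantize/serialize round trip along these lines (a minimal sketch, assuming a CUDA device and illustrative sizes):

    import pickle
    import torch
    import bitsandbytes as bnb

    # One quantization block of distinct values, as in the tests above.
    tensor = torch.linspace(1, 64, 64, dtype=torch.float32)
    param = bnb.nn.Params4bit(data=tensor, quant_type="nf4", blocksize=64, requires_grad=False)
    param = param.to("cuda")  # the device move triggers the 4-bit quantization

    # The pickle round trip should preserve both the packed data and the quant state.
    restored = pickle.loads(pickle.dumps(param))
    assert restored.quant_state is not None
    assert torch.equal(restored.data.cpu(), param.data.cpu())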