@@ -67,13 +67,15 @@ def test_linear_no_igemmlt(device):
6767
6868@pytest .mark .parametrize ("device" , get_available_devices ())
6969@pytest .mark .parametrize ("has_fp16_weights" , TRUE_FALSE , ids = id_formatter ("has_fp16_weights" ))
70+ @pytest .mark .parametrize ("threshold" , [0.0 , 6.0 ], ids = id_formatter ("threshold" ))
7071@pytest .mark .parametrize ("serialize_before_forward" , TRUE_FALSE , ids = id_formatter ("serialize_before_forward" ))
7172@pytest .mark .parametrize ("deserialize_before_cuda" , TRUE_FALSE , ids = id_formatter ("deserialize_before_cuda" ))
7273@pytest .mark .parametrize ("save_before_forward" , TRUE_FALSE , ids = id_formatter ("save_before_forward" ))
7374@pytest .mark .parametrize ("load_before_cuda" , TRUE_FALSE , ids = id_formatter ("load_before_cuda" ))
7475def test_linear_serialization (
7576 device ,
7677 has_fp16_weights ,
78+ threshold ,
7779 serialize_before_forward ,
7880 deserialize_before_cuda ,
7981 save_before_forward ,
@@ -92,7 +94,7 @@ def test_linear_serialization(
9294 linear .out_features ,
9395 linear .bias is not None ,
9496 has_fp16_weights = has_fp16_weights ,
95- threshold = 6.0 ,
97+ threshold = threshold ,
9698 )
9799
98100 linear_custom .weight = bnb .nn .Int8Params (
@@ -137,7 +139,7 @@ def test_linear_serialization(
137139 linear .out_features ,
138140 linear .bias is not None ,
139141 has_fp16_weights = has_fp16_weights ,
140- threshold = 6.0 ,
142+ threshold = threshold ,
141143 )
142144
143145 if deserialize_before_cuda :
0 commit comments