26
26
27
27
28
28
@pytest .mark .parametrize ("type" , ("hadamard" , "random-hadamard" ))
29
- @pytest .mark .parametrize ("randomized " , (True , False ))
30
- def test_correctness_linear (type , randomized ):
29
+ @pytest .mark .parametrize ("randomize " , (True , False ))
30
+ def test_correctness_linear (type , randomize ):
31
31
size = (4 , 8 )
32
32
module = torch .nn .Linear (* size , bias = True )
33
- scheme = TransformScheme (type = type , randomized = randomized )
33
+ scheme = TransformScheme (type = type , randomize = randomize )
34
34
factory = TransformFactory .from_scheme (scheme , name = "" )
35
35
36
36
input_tfm = factory .create_transform (
@@ -55,8 +55,8 @@ def test_correctness_linear(type, randomized):
55
55
56
56
57
57
@pytest .mark .parametrize ("type" , ("hadamard" , "random-hadamard" ))
58
- @pytest .mark .parametrize ("randomized " , (True , False ))
59
- def test_correctness_model (type , randomized , model_apply , offload = False ):
58
+ @pytest .mark .parametrize ("randomize " , (True , False ))
59
+ def test_correctness_model (type , randomize , model_apply , offload = False ):
60
60
# load model
61
61
model = model_apply [0 ]
62
62
if offload :
@@ -71,7 +71,7 @@ def test_correctness_model(type, randomized, model_apply, offload=False):
71
71
# apply transforms
72
72
config = TransformConfig (
73
73
config_groups = {
74
- "" : TransformScheme (type = type , randomized = randomized , apply = model_apply [1 ])
74
+ "" : TransformScheme (type = type , randomize = randomize , apply = model_apply [1 ])
75
75
}
76
76
)
77
77
apply_transform_config (model , config )
@@ -84,6 +84,6 @@ def test_correctness_model(type, randomized, model_apply, offload=False):
84
84
@requires_gpu
85
85
@requires_accelerate ()
86
86
@pytest .mark .parametrize ("type" , ("hadamard" , "random-hadamard" ))
87
- @pytest .mark .parametrize ("randomized " , (True , False ))
88
- def test_correctness_model_offload (type , randomized , model_apply ):
89
- test_correctness_model (type , randomized , model_apply , offload = True )
87
+ @pytest .mark .parametrize ("randomize " , (True , False ))
88
+ def test_correctness_model_offload (type , randomize , model_apply ):
89
+ test_correctness_model (type , randomize , model_apply , offload = True )
0 commit comments