27
27
28
28
29
29
@pytest .mark .parametrize ("type" , ("hadamard" , "random-hadamard" ))
30
- @pytest .mark .parametrize ("randomized " , (True , False ))
30
+ @pytest .mark .parametrize ("randomize " , (True , False ))
31
31
@pytest .mark .parametrize ("head_dim" , (None , 2 , 4 ))
32
32
@pytest .mark .parametrize ("input_batch_size" , (1 , 5 , 17 ))
33
- def test_correctness_linear (type , randomized , head_dim , input_batch_size ):
33
+ def test_correctness_linear (type , randomize , head_dim , input_batch_size ):
34
34
size = (4 , 8 )
35
35
module = torch .nn .Linear (* size , bias = False )
36
- scheme = TransformScheme (type = type , randomized = randomized , head_dim = head_dim )
36
+ scheme = TransformScheme (type = type , randomize = randomize , head_dim = head_dim )
37
37
factory = TransformFactory .from_scheme (scheme , name = "" )
38
38
39
39
input_tfm = factory .create_transform (
@@ -58,10 +58,10 @@ def test_correctness_linear(type, randomized, head_dim, input_batch_size):
58
58
59
59
60
60
@pytest .mark .parametrize ("type" , ("hadamard" , "random-hadamard" ))
61
- @pytest .mark .parametrize ("randomized " , (True , False ))
61
+ @pytest .mark .parametrize ("randomize " , (True , False ))
62
62
@pytest .mark .parametrize ("embed_loc" , ("weight_output" , "output" ))
63
63
@pytest .mark .parametrize ("linear_loc" , ("input" , "weight_input" ))
64
- def test_correctness_embedding (type , randomized , embed_loc , linear_loc ):
64
+ def test_correctness_embedding (type , randomize , embed_loc , linear_loc ):
65
65
model = torch .nn .Sequential (
66
66
torch .nn .Embedding (2 , 4 ),
67
67
torch .nn .Linear (4 , 8 , bias = False ),
@@ -74,7 +74,7 @@ def test_correctness_embedding(type, randomized, embed_loc, linear_loc):
74
74
config_groups = {
75
75
"" : TransformScheme (
76
76
type = type ,
77
- randomized = randomized ,
77
+ randomize = randomize ,
78
78
apply = [
79
79
TransformArgs (targets = "Embedding" , location = embed_loc ),
80
80
TransformArgs (targets = "Linear" , location = linear_loc , inverse = True ),
@@ -90,10 +90,10 @@ def test_correctness_embedding(type, randomized, embed_loc, linear_loc):
90
90
91
91
92
92
@pytest .mark .parametrize ("type" , ("hadamard" , "random-hadamard" ))
93
- @pytest .mark .parametrize ("randomized " , (True , False ))
93
+ @pytest .mark .parametrize ("randomize " , (True , False ))
94
94
@pytest .mark .parametrize ("input_batch_size" , (1 , 5 , 17 ))
95
95
def test_correctness_model (
96
- type , randomized , input_batch_size , model_apply , offload = False
96
+ type , randomize , input_batch_size , model_apply , offload = False
97
97
):
98
98
# load model
99
99
model = model_apply [0 ]
@@ -109,7 +109,7 @@ def test_correctness_model(
109
109
# apply transforms
110
110
config = TransformConfig (
111
111
config_groups = {
112
- "" : TransformScheme (type = type , randomized = randomized , apply = model_apply [1 ])
112
+ "" : TransformScheme (type = type , randomize = randomize , apply = model_apply [1 ])
113
113
}
114
114
)
115
115
apply_transform_config (model , config )
@@ -122,19 +122,17 @@ def test_correctness_model(
122
122
@requires_gpu
123
123
@requires_accelerate ()
124
124
@pytest .mark .parametrize ("type" , ("hadamard" , "random-hadamard" ))
125
- @pytest .mark .parametrize ("randomized " , (True , False ))
125
+ @pytest .mark .parametrize ("randomize " , (True , False ))
126
126
@pytest .mark .parametrize ("input_batch_size" , (1 , 5 , 17 ))
127
- def test_correctness_model_offload (type , randomized , input_batch_size , model_apply ):
128
- test_correctness_model (
129
- type , randomized , input_batch_size , model_apply , offload = True
130
- )
127
+ def test_correctness_model_offload (type , randomize , input_batch_size , model_apply ):
128
+ test_correctness_model (type , randomize , input_batch_size , model_apply , offload = True )
131
129
132
130
133
131
@pytest .mark .parametrize ("type" , ("hadamard" , "random-hadamard" ))
134
- @pytest .mark .parametrize ("randomized " , (True , False ))
132
+ @pytest .mark .parametrize ("randomize " , (True , False ))
135
133
@pytest .mark .parametrize ("head_dim" , (4 , 8 ))
136
134
@pytest .mark .parametrize ("input_batch_size" , (1 , 5 , 17 ))
137
- def test_correctness_attention_heads (type , randomized , head_dim , input_batch_size ):
135
+ def test_correctness_attention_heads (type , randomize , head_dim , input_batch_size ):
138
136
hidden_size = 64
139
137
num_attention_heads = 8
140
138
@@ -151,7 +149,7 @@ def test_correctness_attention_heads(type, randomized, head_dim, input_batch_siz
151
149
config_groups = {
152
150
"" : TransformScheme (
153
151
type = type ,
154
- randomized = randomized ,
152
+ randomize = randomize ,
155
153
head_dim = head_dim ,
156
154
apply = [
157
155
TransformArgs (targets = "v_proj" , location = "weight_output" ),
0 commit comments