@@ -119,43 +119,6 @@ def test_correctness_model(
     assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
 
 
-@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
-@pytest.mark.parametrize("randomize", (True, False))
-@pytest.mark.parametrize("head_dim", (4, 8))
-def test_correctness_attention_heads(type, randomize, head_dim):
-    hidden_size = 64
-    num_attention_heads = 8
-
-    attention = MockAttention(
-        hidden_size=hidden_size,
-        num_attention_heads=num_attention_heads,
-        num_key_value_heads=head_dim,
-    )
-
-    input = torch.rand(17, 5, hidden_size)
-    true_output = attention(input)
-
-    config = TransformConfig(
-        config_groups={
-            "": TransformScheme(
-                type=type,
-                randomize=randomize,
-                head_dim=head_dim,
-                apply=[
-                    TransformArgs(targets="v_proj", location="weight_output"),
-                    TransformArgs(
-                        targets="o_proj", location="weight_input", inverse=True
-                    ),
-                ],
-            )
-        }
-    )
-    apply_transform_config(attention, config)
-
-    output = attention(input)
-    assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
-
-
 @requires_gpu
 @requires_accelerate()
 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))