Commit c3ce35a

drop redundant test
Signed-off-by: Brian Dellabetta <[email protected]>
1 parent: 58257eb

1 file changed: 0 additions, 37 deletions

tests/test_transform/factory/test_correctness.py

Lines changed: 0 additions & 37 deletions
@@ -119,43 +119,6 @@ def test_correctness_model(
     assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
 
 
-@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
-@pytest.mark.parametrize("randomize", (True, False))
-@pytest.mark.parametrize("head_dim", (4, 8))
-def test_correctness_attention_heads(type, randomize, head_dim):
-    hidden_size = 64
-    num_attention_heads = 8
-
-    attention = MockAttention(
-        hidden_size=hidden_size,
-        num_attention_heads=num_attention_heads,
-        num_key_value_heads=head_dim,
-    )
-
-    input = torch.rand(17, 5, hidden_size)
-    true_output = attention(input)
-
-    config = TransformConfig(
-        config_groups={
-            "": TransformScheme(
-                type=type,
-                randomize=randomize,
-                head_dim=head_dim,
-                apply=[
-                    TransformArgs(targets="v_proj", location="weight_output"),
-                    TransformArgs(
-                        targets="o_proj", location="weight_input", inverse=True
-                    ),
-                ],
-            )
-        }
-    )
-    apply_transform_config(attention, config)
-
-    output = attention(input)
-    assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
-
-
 @requires_gpu
 @requires_accelerate()
 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
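For context, the deleted test_correctness_attention_heads checked that a transform fused into v_proj's output weights, together with its inverse fused into o_proj's input weights, leaves the attention output unchanged within tolerance, which is why it is considered redundant with the remaining correctness tests. Below is a minimal standalone sketch of that invariant using plain torch only; MockAttention and the repo's transform factories are not used, and a QR-derived orthogonal matrix stands in for the Hadamard transforms:

import torch

hidden_size = 64
v_proj = torch.nn.Linear(hidden_size, hidden_size, bias=False)
o_proj = torch.nn.Linear(hidden_size, hidden_size, bias=False)

x = torch.rand(17, 5, hidden_size)
true_output = o_proj(v_proj(x))

# random orthogonal matrix standing in for a (random-)hadamard transform
q, _ = torch.linalg.qr(torch.randn(hidden_size, hidden_size))

with torch.no_grad():
    v_proj.weight.copy_(q @ v_proj.weight)    # transform fused at v_proj's weight_output
    o_proj.weight.copy_(o_proj.weight @ q.T)  # inverse fused at o_proj's weight_input

output = o_proj(v_proj(x))
assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)

Because q is orthogonal, the fused q.T @ q between the two projections cancels to the identity, so the composed output matches the original within float32 tolerance.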
