
Commit 7334234

simplify tests
Signed-off-by: Kyle Sayers <[email protected]>
1 parent: ac1eece

1 file changed: 3 additions, 39 deletions


tests/test_transform/factory/test_correctness.py

Lines changed: 3 additions & 39 deletions
@@ -28,10 +28,11 @@

 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
 @pytest.mark.parametrize("randomized", (True, False))
-def test_correctness_linear(type, randomized):
+@pytest.mark.parametrize("head_dim", (None, 2, 4))
+def test_correctness_linear(type, randomized, head_dim):
     size = (4, 8)
     module = torch.nn.Linear(*size, bias=True)
-    scheme = TransformScheme(type=type, randomized=randomized)
+    scheme = TransformScheme(type=type, randomized=randomized, head_dim=head_dim)
     factory = TransformFactory.from_scheme(scheme, name="")

     input_tfm = factory.create_transform(
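Because the parametrize decorators stack, this single test now collects the full cartesian product of arguments, 2 types x 2 randomized flags x 3 head_dim values = 12 cases, which appears to be what lets the separate heads test below be dropped. A minimal, self-contained illustration of the pattern, using a hypothetical test name rather than the repo's code:

import pytest


@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
@pytest.mark.parametrize("randomized", (True, False))
@pytest.mark.parametrize("head_dim", (None, 2, 4))
def test_parametrize_product(type, randomized, head_dim):
    # pytest collects one test per (type, randomized, head_dim) combination,
    # 2 * 2 * 3 = 12 in total, so no separate per-head test function is needed.
    assert head_dim in (None, 2, 4)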
@@ -90,43 +91,6 @@ def test_correctness_model_offload(type, randomized, model_apply):
     test_correctness_model(type, randomized, model_apply, offload=True)


-@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
-@pytest.mark.parametrize("randomized", (True, False))
-@pytest.mark.parametrize("head_dim", (16, 32))
-def test_correctness_heads(type, randomized, head_dim):
-    hidden_size = 64
-
-    model = torch.nn.ModuleDict(
-        {
-            "v_proj": torch.nn.Linear(hidden_size, hidden_size, bias=False),
-            "o_proj": torch.nn.Linear(hidden_size, hidden_size, bias=False),
-        }
-    )
-
-    input = torch.rand(17, 5, hidden_size)
-    true_output = model.o_proj(model.v_proj(input))
-
-    config = TransformConfig(
-        config_groups={
-            "": TransformScheme(
-                type=type,
-                randomized=randomized,
-                head_dim=head_dim,
-                apply=[
-                    TransformArgs(targets="v_proj", location="weight_output"),
-                    TransformArgs(
-                        targets="o_proj", location="weight_input", inverse=True
-                    ),
-                ],
-            )
-        }
-    )
-    apply_transform_config(model, config)
-
-    output = model.o_proj(model.v_proj(input))
-    assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
-
-
 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
 @pytest.mark.parametrize("randomized", (True, False))
 @pytest.mark.parametrize("head_dim", (4, 8))
