@@ -28,10 +28,11 @@
 
 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
 @pytest.mark.parametrize("randomized", (True, False))
-def test_correctness_linear(type, randomized):
+@pytest.mark.parametrize("head_dim", (None, 2, 4))
+def test_correctness_linear(type, randomized, head_dim):
     size = (4, 8)
     module = torch.nn.Linear(*size, bias=True)
-    scheme = TransformScheme(type=type, randomized=randomized)
+    scheme = TransformScheme(type=type, randomized=randomized, head_dim=head_dim)
     factory = TransformFactory.from_scheme(scheme, name="")
 
     input_tfm = factory.create_transform(
@@ -90,43 +91,6 @@ def test_correctness_model_offload(type, randomized, model_apply):
     test_correctness_model(type, randomized, model_apply, offload=True)
 
 
-@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
-@pytest.mark.parametrize("randomized", (True, False))
-@pytest.mark.parametrize("head_dim", (16, 32))
-def test_correctness_heads(type, randomized, head_dim):
-    hidden_size = 64
-
-    model = torch.nn.ModuleDict(
-        {
-            "v_proj": torch.nn.Linear(hidden_size, hidden_size, bias=False),
-            "o_proj": torch.nn.Linear(hidden_size, hidden_size, bias=False),
-        }
-    )
-
-    input = torch.rand(17, 5, hidden_size)
-    true_output = model.o_proj(model.v_proj(input))
-
-    config = TransformConfig(
-        config_groups={
-            "": TransformScheme(
-                type=type,
-                randomized=randomized,
-                head_dim=head_dim,
-                apply=[
-                    TransformArgs(targets="v_proj", location="weight_output"),
-                    TransformArgs(
-                        targets="o_proj", location="weight_input", inverse=True
-                    ),
-                ],
-            )
-        }
-    )
-    apply_transform_config(model, config)
-
-    output = model.o_proj(model.v_proj(input))
-    assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
-
-
 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
 @pytest.mark.parametrize("randomized", (True, False))
 @pytest.mark.parametrize("head_dim", (4, 8))
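Note on the change: with `head_dim` set, the transform presumably acts block-diagonally on each `head_dim`-sized slice of the weight rather than across the full hidden dimension, which is why the dedicated `test_correctness_heads` can be deleted once the remaining tests are parametrized over `head_dim`. A minimal plain-torch sketch of that per-head idea follows; the `apply_per_head` helper and its reshape semantics are assumptions for illustration, not the library's implementation:

```python
import torch


def apply_per_head(weight: torch.Tensor, transform: torch.Tensor, head_dim: int) -> torch.Tensor:
    # Hypothetical helper: reshape the output dimension into (num_heads, head_dim)
    # and apply the (head_dim, head_dim) transform within each head block.
    out_features, in_features = weight.shape
    assert out_features % head_dim == 0
    heads = weight.reshape(-1, head_dim, in_features)  # (num_heads, head_dim, in)
    return (transform @ heads).reshape(out_features, in_features)


# Applying a transform and then its inverse on the matching location is a no-op;
# this is the end-to-end invariant the deleted test asserted via apply_transform_config.
head_dim = 4
q, _ = torch.linalg.qr(torch.randn(head_dim, head_dim))  # stand-in orthogonal transform
weight = torch.randn(8, 8)
restored = apply_per_head(apply_per_head(weight, q, head_dim), q.T, head_dim)
assert torch.allclose(weight, restored, atol=1e-5)
```

The round trip above mirrors the deleted test's `weight_output` / inverse `weight_input` pairing on `v_proj` and `o_proj`, just collapsed onto a single weight for brevity.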