|
6 | 6 |
|
7 | 7 | import unittest |
8 | 8 |
|
9 | | -from typing import Tuple |
| 9 | +from typing import Callable, Tuple |
10 | 10 |
|
11 | 11 | import pytest |
12 | 12 |
|
|
16 | 16 | from executorch.exir.backend.compile_spec_schema import CompileSpec |
17 | 17 | from parameterized import parameterized |
18 | 18 |
|
19 | | -torch.manual_seed(1) |
20 | | - |
21 | 19 |
|
22 | 20 | class TestBMM(unittest.TestCase): |
23 | 21 | """Tests Batch MatMul""" |
24 | 22 |
|
25 | 23 | class BMM(torch.nn.Module): |
26 | | - test_parameters = [ |
27 | | - (torch.rand(2, 1, 1), torch.rand(2, 1, 1)), |
28 | | - (torch.rand(5, 3, 5), torch.rand(5, 5, 2)), |
29 | | - (torch.ones(1, 55, 3), torch.ones(1, 3, 44)), |
30 | | - (10000 * torch.randn(10, 1, 10), torch.randn(10, 10, 5)), |
31 | | - (-10 * torch.randn(2, 32, 64), 5 + 5 * torch.randn(2, 64, 32)), |
| 24 | + test_data_generators = [ |
| 25 | + lambda: (torch.rand(2, 1, 1), torch.rand(2, 1, 1)), |
| 26 | + lambda: (torch.rand(5, 3, 5), torch.rand(5, 5, 2)), |
| 27 | + lambda: (torch.ones(1, 55, 3), torch.ones(1, 3, 44)), |
| 28 | + lambda: (10000 * torch.randn(10, 1, 10), torch.randn(10, 10, 5)), |
| 29 | + lambda: (-10 * torch.randn(2, 32, 64), 5 + 5 * torch.randn(2, 64, 32)), |
32 | 30 | ] |
33 | 31 |
|
34 | 32 | def forward(self, x, y): |
35 | 33 | return torch.bmm(x, y) |
36 | 34 |
|
37 | 35 | class MatMul(torch.nn.Module): |
38 | | - test_parameters = [ |
39 | | - (torch.rand(2, 3, 5), torch.rand(2, 5, 2)), |
40 | | - (torch.rand(1, 2, 3, 5), torch.rand(1, 2, 5, 2)), |
| 36 | + test_data_generators = [ |
| 37 | + lambda: (torch.rand(2, 3, 5), torch.rand(2, 5, 2)), |
| 38 | + lambda: (torch.rand(1, 2, 3, 5), torch.rand(1, 2, 5, 2)), |
41 | 39 | ] |
42 | 40 |
|
43 | 41 | def forward(self, x, y): |
44 | 42 | return torch.matmul(x, y) |
45 | 43 |
|
46 | 44 | class BMMSingleInput(torch.nn.Module): |
47 | | - test_parameters = [ |
48 | | - (torch.rand(20, 3, 3),), |
49 | | - (torch.rand(2, 128, 128),), |
50 | | - (10000 * torch.randn(4, 25, 25),), |
51 | | - (5 + 5 * torch.randn(3, 64, 64),), |
| 45 | + test_data_generators = [ |
| 46 | + lambda: (torch.rand(20, 3, 3),), |
| 47 | + lambda: (torch.rand(2, 128, 128),), |
| 48 | + lambda: (10000 * torch.randn(4, 25, 25),), |
| 49 | + lambda: (5 + 5 * torch.randn(3, 64, 64),), |
52 | 50 | ] |
53 | 51 |
|
54 | 52 | def forward(self, x): |
@@ -120,67 +118,74 @@ def _test_bmm_ethosu_BI_pipeline( |
120 | 118 | if conftest.is_option_enabled("corstone_fvp"): |
121 | 119 | tester.run_method_and_compare_outputs(inputs=test_data, qtol=1) |
122 | 120 |
|
123 | | - @parameterized.expand(BMM.test_parameters) |
124 | | - def test_bmm_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor): |
125 | | - test_data = (operand1, operand2) |
| 121 | + @parameterized.expand(BMM.test_data_generators) |
| 122 | + def test_bmm_tosa_MI(self, test_data_generator: Callable[[], Tuple]): |
| 123 | + test_data = test_data_generator() |
126 | 124 | self._test_bmm_tosa_MI_pipeline(self.BMM(), test_data) |
127 | 125 |
|
128 | | - @parameterized.expand(BMMSingleInput.test_parameters) |
129 | | - def test_bmm_single_input_tosa_MI(self, operand1: torch.Tensor): |
130 | | - test_data = (operand1,) |
| 126 | + @parameterized.expand(BMMSingleInput.test_data_generators) |
| 127 | + def test_bmm_single_input_tosa_MI(self, test_data_generator: Callable[[], Tuple]): |
| 128 | + test_data = test_data_generator() |
131 | 129 | self._test_bmm_tosa_MI_pipeline(self.BMMSingleInput(), test_data) |
132 | 130 |
|
133 | | - @parameterized.expand(MatMul.test_parameters) |
134 | | - def test_matmul_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor): |
135 | | - test_data = (operand1, operand2) |
| 131 | + @parameterized.expand(MatMul.test_data_generators) |
| 132 | + def test_matmul_tosa_MI(self, test_data_generator: Callable[[], Tuple]): |
| 133 | + test_data = test_data_generator() |
136 | 134 | self._test_bmm_tosa_MI_pipeline(self.MatMul(), test_data) |
137 | 135 |
|
138 | | - @parameterized.expand(MatMul.test_parameters) |
139 | | - def test_matmul_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): |
140 | | - test_data = (operand1, operand2) |
| 136 | + @parameterized.expand(MatMul.test_data_generators) |
| 137 | + @pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534) |
| 138 | + def test_matmul_tosa_BI(self, test_data_generator: Callable[[], Tuple]): |
| 139 | + test_data = test_data_generator() |
141 | 140 | self._test_bmm_tosa_BI_pipeline(self.MatMul(), test_data) |
142 | 141 |
|
143 | | - @parameterized.expand(BMM.test_parameters) |
144 | | - def test_bmm_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): |
145 | | - test_data = (operand1, operand2) |
| 142 | + @parameterized.expand(BMM.test_data_generators) |
| 143 | + @pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534) |
| 144 | + def test_bmm_tosa_BI(self, test_data_generator: Callable[[], Tuple]): |
| 145 | + test_data = test_data_generator() |
146 | 146 | self._test_bmm_tosa_BI_pipeline(self.BMM(), test_data) |
147 | 147 |
|
148 | | - @parameterized.expand(BMMSingleInput.test_parameters) |
149 | | - def test_bmm_single_input_tosa_BI(self, operand1: torch.Tensor): |
150 | | - test_data = (operand1,) |
| 148 | + @parameterized.expand(BMMSingleInput.test_data_generators) |
| 149 | + @pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534) |
| 150 | + def test_bmm_single_input_tosa_BI(self, test_data_generator: Callable[[], Tuple]): |
| 151 | + test_data = test_data_generator() |
151 | 152 | self._test_bmm_tosa_BI_pipeline(self.BMMSingleInput(), test_data) |
152 | 153 |
|
153 | | - @parameterized.expand(BMM.test_parameters) |
| 154 | + @parameterized.expand(BMM.test_data_generators) |
154 | 155 | @pytest.mark.corstone_fvp |
155 | 156 | @unittest.expectedFailure |
156 | | - def test_bmm_u55_BI_xfails(self, operand1: torch.Tensor, operand2: torch.Tensor): |
157 | | - test_data = (operand1, operand2) |
| 157 | + def test_bmm_u55_BI_xfails(self, test_data_generator: Callable[[], Tuple]): |
| 158 | + test_data = test_data_generator() |
158 | 159 | self._test_bmm_ethosu_BI_pipeline( |
159 | 160 | self.BMM(), common.get_u55_compile_spec(), test_data |
160 | 161 | ) |
161 | 162 |
|
162 | | - @parameterized.expand(BMM.test_parameters) |
| 163 | + @parameterized.expand(BMM.test_data_generators) |
163 | 164 | @pytest.mark.corstone_fvp |
164 | | - def test_bmm_u85_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): |
165 | | - test_data = (operand1, operand2) |
| 165 | + @pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534) |
| 166 | + def test_bmm_u85_BI(self, test_data_generator: Callable[[], Tuple]): |
| 167 | + test_data = test_data_generator() |
166 | 168 | self._test_bmm_ethosu_BI_pipeline( |
167 | 169 | self.BMM(), common.get_u85_compile_spec(), test_data |
168 | 170 | ) |
169 | 171 |
|
170 | 172 | # Expected to fail with error: Warning, unsupported fusing of TOSA Rescale previous operator is of type: Memcpy |
171 | | - @parameterized.expand(BMMSingleInput.test_parameters) |
| 173 | + @parameterized.expand(BMMSingleInput.test_data_generators) |
172 | 174 | @pytest.mark.corstone_fvp |
173 | 175 | @unittest.expectedFailure |
174 | | - def test_bmm_single_input_u55_BI_xfails(self, operand1: torch.Tensor): |
175 | | - test_data = (operand1,) |
| 176 | + def test_bmm_single_input_u55_BI_xfails( |
| 177 | + self, test_data_generator: Callable[[], Tuple] |
| 178 | + ): |
| 179 | + test_data = test_data_generator() |
176 | 180 | self._test_bmm_ethosu_BI_pipeline( |
177 | 181 | self.BMMSingleInput(), common.get_u55_compile_spec(), test_data |
178 | 182 | ) |
179 | 183 |
|
180 | | - @parameterized.expand(BMMSingleInput.test_parameters) |
| 184 | + @parameterized.expand(BMMSingleInput.test_data_generators) |
181 | 185 | @pytest.mark.corstone_fvp |
182 | | - def test_bmm_single_input_u85_BI(self, operand1: torch.Tensor): |
183 | | - test_data = (operand1,) |
| 186 | + @pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534) |
| 187 | + def test_bmm_single_input_u85_BI(self, test_data_generator: Callable[[], Tuple]): |
| 188 | + test_data = test_data_generator() |
184 | 189 | self._test_bmm_ethosu_BI_pipeline( |
185 | 190 | self.BMMSingleInput(), common.get_u85_compile_spec(), test_data |
186 | 191 | ) |
0 commit comments