|
12 | 12 |
|
13 | 13 | import executorch.backends.cadence.aot.ops_registrations # noqa |
14 | 14 | import torch |
15 | | -from executorch.backends.cadence.aot import compiler |
16 | 15 | from executorch.backends.cadence.aot.fuse_ops import ( |
17 | 16 | FuseCascadedTransposeOrPermuteOps, |
18 | 17 | FuseCascadedViewOps, |
|
30 | 29 | from executorch.exir.dialects._ops import ops as exir_ops |
31 | 30 | from executorch.exir.dialects.edge._ops import EdgeOpOverload |
32 | 31 | from executorch.exir.pass_base import PassResult, ProxyValue |
33 | | -from torch import nn |
34 | 32 |
|
35 | 33 |
|
36 | 34 | class TestFusionPassesBase(unittest.TestCase): |
@@ -178,43 +176,6 @@ def test_keep_mm_add_with_multiple_users(self) -> None: |
178 | 176 | self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 1) |
179 | 177 | self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 3) |
180 | 178 |
|
181 | | - # TODO(matthiascremon) -> None: enable that pass with new flow |
182 | | - @torch.no_grad() |
183 | | - @unittest.expectedFailure |
184 | | - def test_legacy_conv_bn_fusion(self) -> None: |
185 | | - class ModelConvBN(torch.nn.Module): |
186 | | - def __init__( |
187 | | - self, in_features: int, out_features: int, kernel_size: int |
188 | | - ) -> None: |
189 | | - super().__init__() |
190 | | - self.conv1d = nn.Conv1d(in_features, out_features, kernel_size) |
191 | | - self.bn = nn.BatchNorm1d(out_features) |
192 | | - |
193 | | - def forward(self, x: torch.Tensor) -> torch.Tensor: |
194 | | - y = self.conv1d(x) |
195 | | - return self.bn(y) |
196 | | - |
197 | | - model = ModelConvBN(64, 1, 2) |
198 | | - x = torch.randn(1, 64, 4) |
199 | | - |
200 | | - graph_module = ( |
201 | | - compiler.export_to_executorch_gen_etrecord(model.eval(), (x,)) |
202 | | - .exported_program() |
203 | | - .graph_module |
204 | | - ) |
205 | | - # Assert that after running the fusion passes, batchnorm was fused with conv1d |
206 | | - self.assertEqual( |
207 | | - count_node(graph_module, torch.ops.aten.linear.out) |
208 | | - + count_node(graph_module, torch.ops.cadence.convolution.out), |
209 | | - 1, |
210 | | - ) |
211 | | - self.assertEqual( |
212 | | - count_node( |
213 | | - graph_module, torch.ops.aten._native_batch_norm_legit_no_training.out |
214 | | - ), |
215 | | - 0, |
216 | | - ) |
217 | | - |
218 | 179 | def test_permute_transpose_fusion(self) -> None: |
219 | 180 | builder = GraphBuilder() |
220 | 181 | x = builder.placeholder("x", torch.randn(3, 1, 3, 1, 4, dtype=torch.float32)) |
|
0 commit comments