Skip to content

Commit 526c818

Browse files
mcremon-metafacebook-github-bot
authored andcommitted
Remove legacy tests
Summary: If any of those ops need to be added back in the PT2 flow, we can write new tests using graph builder. Reviewed By: hsharma35 Differential Revision: D78417828
1 parent 80da097 commit 526c818

File tree

1 file changed

+0
-37
lines changed

1 file changed

+0
-37
lines changed

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -178,43 +178,6 @@ def test_keep_mm_add_with_multiple_users(self) -> None:
178178
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 1)
179179
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 3)
180180

181-
# TODO(matthiascremon) -> None: enable that pass with new flow
182-
@torch.no_grad()
183-
@unittest.expectedFailure
184-
def test_legacy_conv_bn_fusion(self) -> None:
185-
class ModelConvBN(torch.nn.Module):
186-
def __init__(
187-
self, in_features: int, out_features: int, kernel_size: int
188-
) -> None:
189-
super().__init__()
190-
self.conv1d = nn.Conv1d(in_features, out_features, kernel_size)
191-
self.bn = nn.BatchNorm1d(out_features)
192-
193-
def forward(self, x: torch.Tensor) -> torch.Tensor:
194-
y = self.conv1d(x)
195-
return self.bn(y)
196-
197-
model = ModelConvBN(64, 1, 2)
198-
x = torch.randn(1, 64, 4)
199-
200-
graph_module = (
201-
compiler.export_to_executorch_gen_etrecord(model.eval(), (x,))
202-
.exported_program()
203-
.graph_module
204-
)
205-
# Assert that after running the fusion passes, batchnorm was fused with conv1d
206-
self.assertEqual(
207-
count_node(graph_module, torch.ops.aten.linear.out)
208-
+ count_node(graph_module, torch.ops.cadence.convolution.out),
209-
1,
210-
)
211-
self.assertEqual(
212-
count_node(
213-
graph_module, torch.ops.aten._native_batch_norm_legit_no_training.out
214-
),
215-
0,
216-
)
217-
218181
def test_permute_transpose_fusion(self) -> None:
219182
builder = GraphBuilder()
220183
x = builder.placeholder("x", torch.randn(3, 1, 3, 1, 4, dtype=torch.float32))

0 commit comments

Comments
 (0)