|
7 | 7 | # pyre-strict |
8 | 8 |
|
9 | 9 |
|
| 10 | +import logging |
10 | 11 | import unittest |
11 | 12 | from typing import cast, Final, List, Tuple |
12 | 13 |
|
@@ -40,7 +41,30 @@ def check_op_counts( |
40 | 41 | self.assertTrue(op_counts_match(graph_module, expected_op_counts)) |
41 | 42 |
|
42 | 43 |
|
43 | | -class TestFusionPasses(TestFusionPassesBase): |
| 44 | +class TestFuseMMWithAddPass(TestFusionPassesBase): |
| 45 | + def test_no_fuse_for_3d_bias(self) -> None: |
| 46 | + builder = GraphBuilder() |
| 47 | + x = builder.placeholder("x", torch.randn(4, 3, dtype=torch.float32)) |
| 48 | + y = builder.placeholder("y", torch.randn(3, 5, dtype=torch.float32)) |
| 49 | + z = builder.placeholder("z", torch.randn(1, 4, 5, dtype=torch.float32)) |
| 50 | + mm = builder.call_operator( |
| 51 | + op=exir_ops.edge.aten.mm.default, |
| 52 | + args=(x, y), |
| 53 | + ) |
| 54 | + output = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(mm, z)) |
| 55 | + builder.output([output]) |
| 56 | + original_graph = builder.get_graph_module() |
| 57 | + logging.error(original_graph.print_readable(print_output=False)) |
| 58 | + |
| 59 | + p = FuseMMWithAdd() |
| 60 | + converted_graph = cast(PassResult, p(original_graph)).graph_module |
| 61 | + converted_graph.graph.eliminate_dead_code() |
| 62 | + self.assertEqual( |
| 63 | + count_node(converted_graph, exir_ops.edge.aten.addmm.default), 0 |
| 64 | + ) |
| 65 | + self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 1) |
| 66 | + self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 1) |
| 67 | + |
44 | 68 | def test_fuse_mm_with_add(self) -> None: |
45 | 69 | builder = GraphBuilder() |
46 | 70 | x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32)) |
@@ -176,6 +200,9 @@ def test_keep_mm_add_with_multiple_users(self) -> None: |
176 | 200 | self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 1) |
177 | 201 | self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 3) |
178 | 202 |
|
| 203 | + |
| 204 | +class TestFusionPasses(TestFusionPassesBase): |
| 205 | + |
179 | 206 | def test_permute_transpose_fusion(self) -> None: |
180 | 207 | builder = GraphBuilder() |
181 | 208 | x = builder.placeholder("x", torch.randn(3, 1, 3, 1, 4, dtype=torch.float32)) |
|
0 commit comments