Skip to content

Commit 732aff9

Browse files
authored
Disable mm + add -> addmm fusion if added tensor rank > 2
Differential Revision: D80906791 Pull Request resolved: #13632
1 parent 7dab7c1 commit 732aff9

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed

backends/cadence/aot/fuse_ops.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,13 @@ def fuse_mm_with_add(self, graph_module: torch.fx.GraphModule):
7272
fuse it with mm.
7373
"""
7474
graph = graph_module.graph
75-
for node in graph.nodes:
75+
for node in graph.find_nodes(
76+
op="call_function", target=exir_ops.edge.aten.mm.default
77+
):
7678
# We want to discover a chain of mm -> add, or mm -> view -> add.
7779
# Only proceed if the current node is an mm node, and has only one
7880
# user/successor.
79-
if node.target != exir_ops.edge.aten.mm.default or len(node.users) != 1:
81+
if len(node.users) != 1:
8082
continue
8183

8284
# Our addmm implementation computes (mat1 * mat2 + bias). So the
@@ -128,6 +130,7 @@ def fuse_mm_with_add(self, graph_module: torch.fx.GraphModule):
128130
mm_arg_shape is None
129131
or bias_arg_shape is None
130132
or not broadcastable(mm_arg_shape, bias_arg_shape)
133+
or len(bias_arg_shape) > 2
131134
):
132135
continue
133136

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,29 @@ def check_op_counts(
4040
self.assertTrue(op_counts_match(graph_module, expected_op_counts))
4141

4242

43-
class TestFusionPasses(TestFusionPassesBase):
43+
class TestFuseMMWithAddPass(TestFusionPassesBase):
44+
def test_no_fuse_for_3d_bias(self) -> None:
45+
builder = GraphBuilder()
46+
x = builder.placeholder("x", torch.randn(4, 3, dtype=torch.float32))
47+
y = builder.placeholder("y", torch.randn(3, 5, dtype=torch.float32))
48+
z = builder.placeholder("z", torch.randn(1, 4, 5, dtype=torch.float32))
49+
mm = builder.call_operator(
50+
op=exir_ops.edge.aten.mm.default,
51+
args=(x, y),
52+
)
53+
output = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(mm, z))
54+
builder.output([output])
55+
original_graph = builder.get_graph_module()
56+
57+
p = FuseMMWithAdd()
58+
converted_graph = cast(PassResult, p(original_graph)).graph_module
59+
converted_graph.graph.eliminate_dead_code()
60+
self.assertEqual(
61+
count_node(converted_graph, exir_ops.edge.aten.addmm.default), 0
62+
)
63+
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 1)
64+
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 1)
65+
4466
def test_fuse_mm_with_add(self) -> None:
4567
builder = GraphBuilder()
4668
x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
@@ -176,6 +198,8 @@ def test_keep_mm_add_with_multiple_users(self) -> None:
176198
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 1)
177199
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 3)
178200

201+
202+
class TestFusionPasses(TestFusionPassesBase):
179203
def test_permute_transpose_fusion(self) -> None:
180204
builder = GraphBuilder()
181205
x = builder.placeholder("x", torch.randn(3, 1, 3, 1, 4, dtype=torch.float32))

0 commit comments

Comments (0)