
Commit 6298f99

hsharma35 authored and facebook-github-bot committed
Disable mm + add -> addmm fusion if added tensor rank >2
Summary: The addmm meta kernel allows the added (bias) tensor to have rank > 2, but the implementation does not. This diff disables fusion of mm + add into addmm in such cases. Differential Revision: D80906791
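
For context, a minimal eager-mode sketch of the rank gap this commit guards against (shapes taken from the new test below; the snippet itself is not part of the diff): mm + add broadcasts a rank-3 bias, while addmm, whose output is two-dimensional, rejects the same bias at runtime even though the meta kernel accepts it during tracing.

    import torch

    x = torch.randn(4, 3)
    y = torch.randn(3, 5)
    z = torch.randn(1, 4, 5)  # rank-3 bias

    # mm + add broadcasts the bias; the result has shape (1, 4, 5).
    print((torch.mm(x, y) + z).shape)

    # addmm's output is 2-D, so the same bias fails at runtime
    # (the exact error text varies across PyTorch versions).
    try:
        torch.addmm(z, x, y)
    except RuntimeError as err:
        print(err)

Fusing the pattern above into addmm would therefore change semantics (or fail outright), which is why the pass now bails out when the bias rank exceeds 2.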
Parent: def4c64

2 files changed: 33 additions, 3 deletions

backends/cadence/aot/fuse_ops.py

Lines changed: 5 additions & 2 deletions

@@ -72,11 +72,13 @@ def fuse_mm_with_add(self, graph_module: torch.fx.GraphModule):
         fuse it with mm.
         """
         graph = graph_module.graph
-        for node in graph.nodes:
+        for node in graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.mm.default
+        ):
             # We want to discover a chain of mm -> add, or mm -> view -> add.
             # Only proceed if the current node is an mm node, and has only one
             # user/successor.
-            if node.target != exir_ops.edge.aten.mm.default or len(node.users) != 1:
+            if len(node.users) != 1:
                 continue
 
             # Our addmm implementation computes (mat1 * mat2 + bias). So the
@@ -128,6 +130,7 @@ def fuse_mm_with_add(self, graph_module: torch.fx.GraphModule):
                 mm_arg_shape is None
                 or bias_arg_shape is None
                 or not broadcastable(mm_arg_shape, bias_arg_shape)
+                or len(bias_arg_shape) > 2
             ):
                 continue
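Aside: graph.find_nodes (available on torch.fx.Graph in recent PyTorch releases) restricts iteration to nodes matching an (op, target) pair, which is why the explicit target check could be dropped from the loop body above. A small, self-contained torch.fx sketch of the same pattern, using torch.mm instead of the edge-dialect op purely for illustration:

    import torch
    from torch.fx import symbolic_trace

    def f(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        return torch.mm(a, b) + c

    gm = symbolic_trace(f)

    # find_nodes yields only call_function nodes whose target is torch.mm,
    # so no per-node target filtering is needed inside the loop.
    for node in gm.graph.find_nodes(op="call_function", target=torch.mm):
        print(node.name, [user.name for user in node.users])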
backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 28 additions & 1 deletion

@@ -7,6 +7,7 @@
 # pyre-strict
 
 
+import logging
 import unittest
 from typing import cast, Final, List, Tuple
 
@@ -40,7 +41,30 @@ def check_op_counts(
         self.assertTrue(op_counts_match(graph_module, expected_op_counts))
 
 
-class TestFusionPasses(TestFusionPassesBase):
+class TestFuseMMWithAddPass(TestFusionPassesBase):
+    def test_no_fuse_for_3d_bias(self) -> None:
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(4, 3, dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn(3, 5, dtype=torch.float32))
+        z = builder.placeholder("z", torch.randn(1, 4, 5, dtype=torch.float32))
+        mm = builder.call_operator(
+            op=exir_ops.edge.aten.mm.default,
+            args=(x, y),
+        )
+        output = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(mm, z))
+        builder.output([output])
+        original_graph = builder.get_graph_module()
+        logging.error(original_graph.print_readable(print_output=False))
+
+        p = FuseMMWithAdd()
+        converted_graph = cast(PassResult, p(original_graph)).graph_module
+        converted_graph.graph.eliminate_dead_code()
+        self.assertEqual(
+            count_node(converted_graph, exir_ops.edge.aten.addmm.default), 0
+        )
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 1)
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 1)
+
     def test_fuse_mm_with_add(self) -> None:
         builder = GraphBuilder()
         x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
@@ -176,6 +200,9 @@ def test_keep_mm_add_with_multiple_users(self) -> None:
         self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 1)
         self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 3)
 
+
+class TestFusionPasses(TestFusionPassesBase):
+
     def test_permute_transpose_fusion(self) -> None:
         builder = GraphBuilder()
         x = builder.placeholder("x", torch.randn(3, 1, 3, 1, 4, dtype=torch.float32))
