Skip to content

Commit 17a87a5

Browse files
chunnienccopybara-github
authored andcommitted
fix fx pattern matcher for bilinear upsample
PiperOrigin-RevId: 713313271
1 parent 6d285bb commit 17a87a5

File tree

4 files changed

+63
-17
lines changed

4 files changed

+63
-17
lines changed

ai_edge_torch/hlfb/mark_pattern/__init__.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import uuid
1818

1919
from ai_edge_torch import lowertools
20-
from ai_edge_torch.hlfb.mark_pattern import passes
20+
from ai_edge_torch.hlfb.mark_pattern import fx_utils
2121
from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
2222
import torch
2323

@@ -87,7 +87,7 @@ def mark_pattern(
8787
m.meta["ORIGINAL_NODE"] = n
8888

8989
# Sanitize graph_module to match in the same way as pattern's graph_module.
90-
graph_module_to_match = passes.remove_clone_ops(graph_module_to_match)
90+
graph_module_to_match = fx_utils.remove_clone_ops(graph_module_to_match)
9191

9292
match_with_attrs = pattern.match(graph_module_to_match)
9393

@@ -111,13 +111,25 @@ def mark_pattern(
111111
is_input=True,
112112
)
113113

114-
# Only replace input by the marker node for those nodes used in the pattern.
114+
# Only replace input by the marker node for those nodes used in the
115+
# pattern.
115116
in_pattern_nodes = set(match.nodes_map.values())
116117
for user in input_node.users.keys():
117-
if user in in_pattern_nodes:
118-
user.meta["ORIGINAL_NODE"].replace_input_with(
119-
input_node.meta["ORIGINAL_NODE"], new_input_node
120-
)
118+
if user not in in_pattern_nodes:
119+
continue
120+
121+
user.meta["ORIGINAL_NODE"].replace_input_with(
122+
input_node.meta["ORIGINAL_NODE"], new_input_node
123+
)
124+
# Pattern matching graph sanitization may remove clone ops, which means
125+
# the user's input in the original graph may be a clone op. When
126+
# replacing the input with the marker node, we need to further try
127+
# replacing the input of the clone op that connects to the user.
128+
for original_user_input in user.meta["ORIGINAL_NODE"].all_input_nodes:
129+
if fx_utils.is_clone_op(original_user_input):
130+
original_user_input.replace_input_with(
131+
input_node.meta["ORIGINAL_NODE"], new_input_node
132+
)
121133

122134
for i, pattern_output_node in enumerate(pattern.output_nodes):
123135
output_node = match.nodes_map[pattern_output_node]

ai_edge_torch/hlfb/mark_pattern/passes.py renamed to ai_edge_torch/hlfb/mark_pattern/fx_utils.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,18 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15-
"""Passes to clean up the model graph for pattern matching."""
15+
"""FX graph utilities for pattern matching clean ups."""
1616

1717
import torch
1818

1919

20+
def is_clone_op(node: torch.fx.Node) -> bool:
21+
"""Checks if the node is a clone op."""
22+
return (
23+
node.op == "call_function" and node.target == torch.ops.aten.clone.default
24+
)
25+
26+
2027
def remove_clone_ops(gm: torch.fx.GraphModule):
2128
"""Removes clone ops from the graph.
2229
@@ -32,7 +39,7 @@ def remove_clone_ops(gm: torch.fx.GraphModule):
3239
The graph module with clone ops removed.
3340
"""
3441
for node in gm.graph.nodes:
35-
if node.op == "call_function" and node.name.startswith("clone"):
42+
if is_clone_op(node):
3643
node.replace_all_uses_with(node.args[0])
3744
gm.graph.erase_node(node)
3845

ai_edge_torch/hlfb/mark_pattern/pattern.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,14 @@
1818
from typing import Any, Callable, Optional, Union
1919

2020
from ai_edge_torch import fx_pass_base
21-
from ai_edge_torch.hlfb.mark_pattern import passes
21+
from ai_edge_torch.hlfb.mark_pattern import fx_utils
2222
import torch
23-
from torch.export.graph_signature import TensorArgument
24-
from torch.fx import Graph
25-
from torch.fx import GraphModule
26-
from torch.fx.passes.utils.matcher_utils import InternalMatch
27-
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
23+
24+
Graph = torch.fx.Graph
25+
GraphModule = torch.fx.GraphModule
26+
TensorArgument = torch.export.graph_signature.TensorArgument
27+
InternalMatch = torch.fx.passes.utils.matcher_utils.InternalMatch
28+
SubgraphMatcher = torch.fx.passes.utils.matcher_utils.SubgraphMatcher
2829

2930

3031
def _are_equal(x: Any, y: Any) -> bool:
@@ -219,8 +220,8 @@ def forward(self, *args, **kwargs):
219220
# Sanitize graph_module for more precise pattern matching.
220221
# The graph_module to match against this pattern should apply equivalent
221222
# sanitization.
222-
self.graph_module = passes.remove_clone_ops(self.graph_module)
223-
self.graph_module = passes.remove_dangling_args(self.graph_module)
223+
self.graph_module = fx_utils.remove_clone_ops(self.graph_module)
224+
self.graph_module = fx_utils.remove_dangling_args(self.graph_module)
224225

225226
# Builds list of ordered input and output nodes.
226227
self.graph_nodes_map = {}

ai_edge_torch/hlfb/test/test_mark_pattern.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,32 @@ def forward(self, x):
5858
{"stablehlo.custom_call @mark_tensor": 6},
5959
)
6060

61+
def test_mark_pattern_with_clone_inputs(self):
62+
63+
class TestModel(torch.nn.Module):
64+
65+
def forward(self, x):
66+
return torch.ops.aten.clone.default(x * x) + x
67+
68+
pattern = pattern_module.Pattern(
69+
"test.add",
70+
lambda a, b: a + b,
71+
export_args=(torch.rand(2, 2), torch.rand(2, 2)),
72+
)
73+
74+
model = TestModel().eval()
75+
args = (torch.rand(20, 20),)
76+
exported_program = torch.export.export(model, args)
77+
mark_pattern.mark_pattern(exported_program.graph_module, pattern)
78+
mlir = _export_stablehlo_mlir(exported_program)
79+
80+
lowertools.assert_string_count(
81+
self,
82+
mlir,
83+
{'stablehlo.composite "test.add"': 1},
84+
{"stablehlo.custom_call @mark_tensor": 3},
85+
)
86+
6187
def test_mark_pattern_with_attr_builder(self):
6288
class TestModel(torch.nn.Module):
6389

0 commit comments

Comments
 (0)