Skip to content

Commit fe37ec4

Browse files
committed
Extend FuseViewCopyTransform to fuse more views
Extends the pass to find chains of unary elementwise ops and fuse all views in each chain. This gives the same result since the shape does not matter for elementwise ops. This change allows fusing patterns like view -> clone -> view. Signed-off-by: Adrian Lundell <[email protected]> Change-Id: I41afdbebf27124fa474e02180725ff28660ffef1
1 parent f24351a commit fe37ec4

File tree

2 files changed

+130
-12
lines changed

2 files changed

+130
-12
lines changed
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
import torch
8+
from executorch.backends.arm.test import common
9+
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
10+
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
11+
12+
13+
class FuseSequentialViews(torch.nn.Module):
    """Test model: three back-to-back views with nothing in between."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape ``x`` through (1, 2, 3, 4) and (2, 3, 4, 1) to (2, 3, 4)."""
        return x.view((1, 2, 3, 4)).view((2, 3, 4, 1)).view((2, 3, 4))

    # Example input handed to the test pipeline (24 elements).
    data = (torch.randn(2, 3, 1, 4),)
    # Expected view_copy count before the pass: one per view() call above.
    ops_before_pass = {
        "executorch_exir_dialects_edge__ops_aten_view_copy": 3,
    }
    # Expected view_copy count after the pass: the chain collapses to one.
    ops_after_pass = {
        "executorch_exir_dialects_edge__ops_aten_view_copy": 1,
    }
24+
25+
26+
class FuseSequentialWithNoopsViews(torch.nn.Module):
    """Test model: four views separated by single-input elementwise ops.

    The ops between the views (clone, dtype cast, abs, reciprocal, sqrt) do
    not depend on the tensor's shape, so the views form one fusable chain.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the view/elementwise chain and return a (12, 2) tensor."""
        return (
            x.view((1, 2, 3, 4))
            .clone()
            .view((2, 3, 4, 1))
            .to(dtype=torch.int32)
            .view((2, 3, 4))
            .abs()
            .reciprocal()
            .sqrt()
            .view((12, 2))
        )

    # Example input handed to the test pipeline (24 elements).
    data = (torch.randn(2, 3, 1, 4),)
    # Expected view_copy count before the pass: one per view() call above.
    ops_before_pass = {
        "executorch_exir_dialects_edge__ops_aten_view_copy": 4,
    }
    # Expected view_copy count after the pass: only one view remains.
    ops_after_pass = {
        "executorch_exir_dialects_edge__ops_aten_view_copy": 1,
    }
47+
48+
49+
class DontFuseBranchingViews(torch.nn.Module):
    """Test model: the first view feeds two branches, so nothing may be fused."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return abs(x) + ceil(x), each branch reshaped to (2, 3, 4, 1)."""
        x = x.view((1, 2, 3, 4))
        # The view above has two users (abs and ceil), which breaks the chain.
        x1 = x.abs().view((2, 3, 4, 1))
        x2 = x.ceil().view((2, 3, 4, 1))
        return x1 + x2

    # Example input handed to the test pipeline (24 elements).
    data = (torch.randn(2, 3, 1, 4),)
    # Expected view_copy count before the pass: one per view() call above.
    ops_before_pass = {
        "executorch_exir_dialects_edge__ops_aten_view_copy": 3,
    }
    # Expected view_copy count after the pass: unchanged, no fusion expected.
    ops_after_pass = {
        "executorch_exir_dialects_edge__ops_aten_view_copy": 3,
    }
63+
64+
65+
# Test-case id -> model instance; consumed by the parametrized test below.
tests = {
    "fuse_sequential_views": FuseSequentialViews(),
    "fuse_sequential_with_noops_views": FuseSequentialWithNoopsViews(),
    "dont_fuse_branching_views": DontFuseBranchingViews(),
}
70+
71+
72+
@common.parametrize("model", tests)
def test_fuse_view_copy(model):
    """Run FuseViewCopyTransform on ``model`` through a PassPipeline.

    The model supplies its own example input and the view_copy counts that
    are handed to the pipeline as the expected state before and after the
    pass (NOTE(review): the actual count checking happens inside
    PassPipeline, which is not visible here).
    """
    pipeline = PassPipeline(
        model,
        model.data,
        quantize=False,
        ops_before_pass=model.ops_before_pass,
        ops_after_pass=model.ops_after_pass,
        pass_list=[FuseViewCopyTransform],
    )
    pipeline.run()

backends/transforms/fuse_view_copy.py

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,31 +14,63 @@
1414
from executorch.exir.pass_base import ExportPass, PassResult
1515

1616

17+
# Ops that merge_view_copy_chains is allowed to walk through when following a
# chain that starts at a view_copy node. view_copy itself is a member so that
# the walk continues across intermediate views; the remaining entries are
# treated by the pass as elementwise ops whose result does not depend on the
# tensor's shape.
UNARY_ELEMENTWISE_OPS = [
    exir_ops.edge.aten.view_copy.default,
    exir_ops.edge.aten.alias_copy.default,
    exir_ops.edge.aten.clone.default,
    exir_ops.edge.dim_order_ops._clone_dim_order.default,
    exir_ops.edge.aten._to_copy.default,
    exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
    exir_ops.edge.aten.abs.default,
    exir_ops.edge.aten.clamp.default,
    exir_ops.edge.aten.ceil.default,
    exir_ops.edge.aten.floor.default,
    exir_ops.edge.aten.neg.default,
    exir_ops.edge.aten.relu.default,
    exir_ops.edge.aten.round.default,
    exir_ops.edge.aten.sigmoid.default,
    exir_ops.edge.aten.silu.default,
    exir_ops.edge.aten.sqrt.default,
    exir_ops.edge.aten.tanh.default,
    exir_ops.edge.aten.sign.default,
    exir_ops.edge.aten.reciprocal.default,
]
40+
41+
1742
def merge_view_copy_chains(graph: torch.fx.Graph) -> tuple[torch.fx.Graph, bool]:
1843
"""
19-
Find chains of view_copy nodes and merge them into one view_copy node.
44+
Find chains of view_copy nodes and unary elementwise ops and set all
45+
view_copy nodes to have the final shape. The views will then be removed
46+
by the remove_noop_view_copy call.
47+
2048
Only merges view_copy nodes that are not used by any other nodes.
2149
"""
2250
ops = exir_ops.edge
2351
view_op = ops.aten.view_copy.default
2452
modified = False
2553
for node in graph.nodes:
2654
if node.op == "call_function" and node.target == view_op:
27-
# find ending view_copy node in chain
55+
# Find a chain of unary elementwise ops and save all view_copy nodes
2856
end_node = node
57+
view_ops = [node]
2958
while (
3059
end_node.op == "call_function"
31-
and end_node.target == view_op
60+
and end_node.target in UNARY_ELEMENTWISE_OPS
3261
and len(end_node.users) == 1
33-
and list(end_node.users)[0].target == view_op
62+
and list(end_node.users)[0].target in UNARY_ELEMENTWISE_OPS
3463
):
3564
end_node = list(end_node.users)[0]
36-
# we can swap the first node's shape arg with the last node's shape arg
37-
if node != end_node:
38-
with graph.inserting_after(node):
39-
new_args = (node.args[0], end_node.args[1])
65+
if end_node.target == view_op:
66+
view_ops.append(end_node)
67+
68+
# Set all view_copy nodes to have the final shape
69+
if len(view_ops) > 1:
70+
final_shape = view_ops[-1].args[1]
71+
for node in view_ops:
72+
new_args = (node.args[0], final_shape)
4073
node.args = new_args
41-
end_node.replace_all_uses_with(node)
4274
modified = True
4375

4476
graph.eliminate_dead_code()
@@ -67,10 +99,14 @@ class FuseViewCopyTransform(ExportPass):
6799
_passes_required_after: Set[Type[ExportPass]] = set()
68100

69101
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
    """Fuse view_copy chains, then remove the view_copy nodes left as no-ops.

    Args:
        graph_module: The graph module to transform.

    Returns:
        A PassResult holding the transformed graph module and a flag that is
        True when either the merge step or the no-op removal step changed
        the graph.
    """
    graph_module.graph, merge_modified = merge_view_copy_chains(graph_module.graph)
    if merge_modified:
        graph_module.recompile()
        # Re-run the base pass on the rewritten graph before looking for
        # no-op views (presumably this refreshes node metadata such as
        # shapes -- confirm against ExportPass.call).
        graph_module = super().call(graph_module).graph_module

    graph_module.graph, noop_modified = remove_noop_view_copy(graph_module.graph)
    if noop_modified:
        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module

    # Report a change if either step modified the graph; reusing one flag
    # for both steps would let the second result hide the first.
    return PassResult(graph_module, merge_modified or noop_modified)

0 commit comments

Comments
 (0)