From fe37ec4b2337c6011af31b100b66672786073fca Mon Sep 17 00:00:00 2001 From: Adrian Lundell Date: Mon, 29 Sep 2025 14:49:57 +0200 Subject: [PATCH 1/2] Extend FuseViewCopyTransform to fuse more views Extends the pass to find chains of unary elementwise ops and fuse all views in each chain. This gives the same result since the shape does not matter for elementwise ops. This change allows fusing patterns like view -> clone -> view. Signed-off-by: Adrian Lundell Change-Id: I41afdbebf27124fa474e02180725ff28660ffef1 --- .../arm/test/passes/test_fuse_view_copy.py | 82 +++++++++++++++++++ backends/transforms/fuse_view_copy.py | 60 +++++++++++--- 2 files changed, 130 insertions(+), 12 deletions(-) create mode 100644 backends/arm/test/passes/test_fuse_view_copy.py diff --git a/backends/arm/test/passes/test_fuse_view_copy.py b/backends/arm/test/passes/test_fuse_view_copy.py new file mode 100644 index 00000000000..7bf931349b6 --- /dev/null +++ b/backends/arm/test/passes/test_fuse_view_copy.py @@ -0,0 +1,82 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import PassPipeline +from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform + + +class FuseSequentialViews(torch.nn.Module): + def forward(self, x: torch.Tensor): + return x.view((1, 2, 3, 4)).view((2, 3, 4, 1)).view((2, 3, 4)) + + data = (torch.randn(2, 3, 1, 4),) + ops_before_pass = { + "executorch_exir_dialects_edge__ops_aten_view_copy": 3, + } + ops_after_pass = { + "executorch_exir_dialects_edge__ops_aten_view_copy": 1, + } + + +class FuseSequentialWithNoopsViews(torch.nn.Module): + def forward(self, x: torch.Tensor): + return ( + x.view((1, 2, 3, 4)) + .clone() + .view((2, 3, 4, 1)) + .to(dtype=torch.int32) + .view((2, 3, 4)) + .abs() + .reciprocal() + .sqrt() + .view((12, 2)) + ) + + data = (torch.randn(2, 3, 1, 4),) + ops_before_pass = { + "executorch_exir_dialects_edge__ops_aten_view_copy": 4, + } + ops_after_pass = { + "executorch_exir_dialects_edge__ops_aten_view_copy": 1, + } + + +class DontFuseBranchingViews(torch.nn.Module): + def forward(self, x: torch.Tensor): + x = x.view((1, 2, 3, 4)) + x1 = x.abs().view((2, 3, 4, 1)) + x2 = x.ceil().view((2, 3, 4, 1)) + return x1 + x2 + + data = (torch.randn(2, 3, 1, 4),) + ops_before_pass = { + "executorch_exir_dialects_edge__ops_aten_view_copy": 3, + } + ops_after_pass = { + "executorch_exir_dialects_edge__ops_aten_view_copy": 3, + } + + +tests = { + "fuse_sequential_views": FuseSequentialViews(), + "fuse_sequential_with_noops_views": FuseSequentialWithNoopsViews(), + "dont_fuse_branching_views": DontFuseBranchingViews(), +} + + +@common.parametrize("model", tests) +def test_fuse_view_copy(model): + pipeline = PassPipeline( + model, + model.data, + quantize=False, + ops_before_pass=model.ops_before_pass, + ops_after_pass=model.ops_after_pass, + pass_list=[FuseViewCopyTransform], + ) + pipeline.run() diff --git a/backends/transforms/fuse_view_copy.py 
b/backends/transforms/fuse_view_copy.py index 1972513d2ef..75467df4aa8 100644 --- a/backends/transforms/fuse_view_copy.py +++ b/backends/transforms/fuse_view_copy.py @@ -14,9 +14,37 @@ from executorch.exir.pass_base import ExportPass, PassResult +UNARY_ELEMENTWISE_OPS = [ + exir_ops.edge.aten.view_copy.default, + exir_ops.edge.aten.alias_copy.default, + exir_ops.edge.aten.clone.default, + exir_ops.edge.dim_order_ops._clone_dim_order.default, + exir_ops.edge.aten._to_copy.default, + exir_ops.edge.dim_order_ops._to_dim_order_copy.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.aten.abs.default, + exir_ops.edge.aten.clamp.default, + exir_ops.edge.aten.ceil.default, + exir_ops.edge.aten.floor.default, + exir_ops.edge.aten.neg.default, + exir_ops.edge.aten.relu.default, + exir_ops.edge.aten.round.default, + exir_ops.edge.aten.sigmoid.default, + exir_ops.edge.aten.silu.default, + exir_ops.edge.aten.sqrt.default, + exir_ops.edge.aten.tanh.default, + exir_ops.edge.aten.sign.default, + exir_ops.edge.aten.reciprocal.default, +] + + def merge_view_copy_chains(graph: torch.fx.Graph) -> tuple[torch.fx.Graph, bool]: """ - Find chains of view_copy nodes and merge them into one view_copy node. + Find chains of view_copy nodes and unary elementwise ops and set all + view_copy nodes to have the final shape. The views will then be removed + by the remove_noop_view_copy call. + Only merges view_copy nodes that are not used by any other nodes. 
""" ops = exir_ops.edge @@ -24,21 +52,25 @@ def merge_view_copy_chains(graph: torch.fx.Graph) -> tuple[torch.fx.Graph, bool] modified = False for node in graph.nodes: if node.op == "call_function" and node.target == view_op: - # find ending view_copy node in chain + # Find a chain of unary elementwise ops and save all view_copy nodes end_node = node + view_ops = [node] while ( end_node.op == "call_function" - and end_node.target == view_op + and end_node.target in UNARY_ELEMENTWISE_OPS and len(end_node.users) == 1 - and list(end_node.users)[0].target == view_op + and list(end_node.users)[0].target in UNARY_ELEMENTWISE_OPS ): end_node = list(end_node.users)[0] - # we can swap the first node's shape arg with the last node's shape arg - if node != end_node: - with graph.inserting_after(node): - new_args = (node.args[0], end_node.args[1]) + if end_node.target == view_op: + view_ops.append(end_node) + + # Set all view_copy nodes to have the final shape + if len(view_ops) > 1: + final_shape = view_ops[-1].args[1] + for node in view_ops: + new_args = (node.args[0], final_shape) node.args = new_args - end_node.replace_all_uses_with(node) modified = True graph.eliminate_dead_code() @@ -67,10 +99,14 @@ class FuseViewCopyTransform(ExportPass): _passes_required_after: Set[Type[ExportPass]] = set() def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - graph_module.graph, merge_modified = merge_view_copy_chains(graph_module.graph) - graph_module.graph, noop_modified = remove_noop_view_copy(graph_module.graph) - modified = merge_modified or noop_modified + graph_module.graph, modified = merge_view_copy_chains(graph_module.graph) if modified: graph_module.recompile() graph_module = super().call(graph_module).graph_module + + graph_module.graph, modified = remove_noop_view_copy(graph_module.graph) + if modified: + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + return PassResult(graph_module, modified) From 
829d167214099cd54b5d00cb47a225c24045e8e8 Mon Sep 17 00:00:00 2001 From: Adrian Lundell Date: Wed, 8 Oct 2025 16:48:08 +0200 Subject: [PATCH 2/2] Add more unary ops Signed-off-by: Adrian Lundell --- backends/transforms/fuse_view_copy.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/backends/transforms/fuse_view_copy.py b/backends/transforms/fuse_view_copy.py index 75467df4aa8..b7c52f95fa3 100644 --- a/backends/transforms/fuse_view_copy.py +++ b/backends/transforms/fuse_view_copy.py @@ -36,6 +36,10 @@ exir_ops.edge.aten.tanh.default, exir_ops.edge.aten.sign.default, exir_ops.edge.aten.reciprocal.default, + exir_ops.edge.aten.gelu.default, + exir_ops.edge.aten.rsqrt.default, + exir_ops.edge.aten.exp.default, + exir_ops.edge.aten.log.default, ]