11# Copyright (c) Meta Platforms, Inc. and affiliates.
22# All rights reserved.
3+ # Copyright 2025 Arm Limited and/or its affiliates.
34#
45# This source code is licensed under the BSD-style license found in the
56# LICENSE file in the root directory of this source tree.
1112from executorch .exir .pass_base import ExportPass , PassResult
1213
1314
14- def merge_view_copy_chains (graph : torch .fx .Graph ) -> torch .fx .Graph :
15+ def merge_view_copy_chains (graph : torch .fx .Graph ) -> tuple [ torch .fx .Graph , bool ] :
1516 """
1617 Find chains of view_copy nodes and merge them into one view_copy node.
1718 Only merges view_copy nodes that are not used by any other nodes.
1819 """
1920 ops = exir_ops .edge
2021 view_op = ops .aten .view_copy .default
22+ modified = False
2123 for node in graph .nodes :
2224 if node .op == "call_function" and node .target == view_op :
2325 # find ending view_copy node in chain
@@ -35,29 +37,36 @@ def merge_view_copy_chains(graph: torch.fx.Graph) -> torch.fx.Graph:
3537 new_args = (node .args [0 ], end_node .args [1 ])
3638 node .args = new_args
3739 end_node .replace_all_uses_with (node )
40+ modified = True
3841
3942 graph .eliminate_dead_code ()
40- return graph
43+ return graph , modified
4144
4245
43- def remove_noop_view_copy (graph : torch .fx .Graph ) -> torch .fx .Graph :
46+ def remove_noop_view_copy (graph : torch .fx .Graph ) -> tuple [ torch .fx .Graph , bool ] :
4447 """
4548 Remove view_copy nodes that are no-ops.
4649 """
4750 ops = exir_ops .edge
4851 view_op = ops .aten .view_copy .default
52+ modified = False
4953 for node in graph .nodes :
5054 if node .op == "call_function" and node .target == view_op :
5155 input_shape = list (node .args [0 ].meta ["val" ].shape )
5256 target_shape = node .args [1 ]
5357 if input_shape == target_shape :
5458 node .replace_all_uses_with (node .args [0 ])
59+ modified = True
5560 graph .eliminate_dead_code ()
56- return graph
61+ return graph , modified
5762
5863
5964class FuseViewCopyTransform (ExportPass ):
6065 def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
61- graph_module .graph = merge_view_copy_chains (graph_module .graph )
62- graph_module .graph = remove_noop_view_copy (graph_module .graph )
63- return PassResult (graph_module , True )
66+ graph_module .graph , merge_modified = merge_view_copy_chains (graph_module .graph )
67+ graph_module .graph , noop_modified = remove_noop_view_copy (graph_module .graph )
68+ modified = merge_modified or noop_modified
69+ if modified :
70+ graph_module .recompile ()
71+ graph_module = super ().call (graph_module ).graph_module
72+ return PassResult (graph_module , modified )
0 commit comments