|  | 
| 14 | 14 | from executorch.exir.pass_base import ExportPass, PassResult | 
| 15 | 15 | 
 | 
| 16 | 16 | 
 | 
|  | 17 | +UNARY_ELEMENTWISE_OPS = [ | 
|  | 18 | +    exir_ops.edge.aten.view_copy.default, | 
|  | 19 | +    exir_ops.edge.aten.alias_copy.default, | 
|  | 20 | +    exir_ops.edge.aten.clone.default, | 
|  | 21 | +    exir_ops.edge.dim_order_ops._clone_dim_order.default, | 
|  | 22 | +    exir_ops.edge.aten._to_copy.default, | 
|  | 23 | +    exir_ops.edge.dim_order_ops._to_dim_order_copy.default, | 
|  | 24 | +    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, | 
|  | 25 | +    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, | 
|  | 26 | +    exir_ops.edge.aten.abs.default, | 
|  | 27 | +    exir_ops.edge.aten.clamp.default, | 
|  | 28 | +    exir_ops.edge.aten.ceil.default, | 
|  | 29 | +    exir_ops.edge.aten.floor.default, | 
|  | 30 | +    exir_ops.edge.aten.neg.default, | 
|  | 31 | +    exir_ops.edge.aten.relu.default, | 
|  | 32 | +    exir_ops.edge.aten.round.default, | 
|  | 33 | +    exir_ops.edge.aten.sigmoid.default, | 
|  | 34 | +    exir_ops.edge.aten.silu.default, | 
|  | 35 | +    exir_ops.edge.aten.sqrt.default, | 
|  | 36 | +    exir_ops.edge.aten.tanh.default, | 
|  | 37 | +    exir_ops.edge.aten.sign.default, | 
|  | 38 | +    exir_ops.edge.aten.reciprocal.default, | 
|  | 39 | +] | 
|  | 40 | + | 
|  | 41 | + | 
| 17 | 42 | def merge_view_copy_chains(graph: torch.fx.Graph) -> tuple[torch.fx.Graph, bool]: | 
| 18 | 43 |     """ | 
| 19 |  | -    Find chains of view_copy nodes and merge them into one view_copy node. | 
|  | 44 | +    Find chains of view_copy nodes and unary elementwise ops and set all | 
|  | 45 | +    view_copy nodes to have the final shape. The views will then be removed | 
|  | 46 | +    by the remove_noop_view_copy call. | 
|  | 47 | +
 | 
| 20 | 48 |     Only merges view_copy nodes that are not used by any other nodes. | 
| 21 | 49 |     """ | 
| 22 | 50 |     ops = exir_ops.edge | 
| 23 | 51 |     view_op = ops.aten.view_copy.default | 
| 24 | 52 |     modified = False | 
| 25 | 53 |     for node in graph.nodes: | 
| 26 | 54 |         if node.op == "call_function" and node.target == view_op: | 
| 27 |  | -            # find ending view_copy node in chain | 
|  | 55 | +            # Find a chain of unary elementwise ops and save all view_copy nodes | 
| 28 | 56 |             end_node = node | 
|  | 57 | +            view_ops = [node] | 
| 29 | 58 |             while ( | 
| 30 | 59 |                 end_node.op == "call_function" | 
| 31 |  | -                and end_node.target == view_op | 
|  | 60 | +                and end_node.target in UNARY_ELEMENTWISE_OPS | 
| 32 | 61 |                 and len(end_node.users) == 1 | 
| 33 |  | -                and list(end_node.users)[0].target == view_op | 
|  | 62 | +                and list(end_node.users)[0].target in UNARY_ELEMENTWISE_OPS | 
| 34 | 63 |             ): | 
| 35 | 64 |                 end_node = list(end_node.users)[0] | 
| 36 |  | -            # we can swap the first node's shape arg with the last node's shape arg | 
| 37 |  | -            if node != end_node: | 
| 38 |  | -                with graph.inserting_after(node): | 
| 39 |  | -                    new_args = (node.args[0], end_node.args[1]) | 
|  | 65 | +                if end_node.target == view_op: | 
|  | 66 | +                    view_ops.append(end_node) | 
|  | 67 | + | 
|  | 68 | +            # Set all view_copy nodes to have the final shape | 
|  | 69 | +            if len(view_ops) > 1: | 
|  | 70 | +                final_shape = view_ops[-1].args[1] | 
|  | 71 | +                for node in view_ops: | 
|  | 72 | +                    new_args = (node.args[0], final_shape) | 
| 40 | 73 |                     node.args = new_args | 
| 41 |  | -                    end_node.replace_all_uses_with(node) | 
| 42 | 74 |                 modified = True | 
| 43 | 75 | 
 | 
| 44 | 76 |     graph.eliminate_dead_code() | 
| @@ -67,10 +99,14 @@ class FuseViewCopyTransform(ExportPass): | 
| 67 | 99 |     _passes_required_after: Set[Type[ExportPass]] = set() | 
| 68 | 100 | 
 | 
| 69 | 101 |     def call(self, graph_module: torch.fx.GraphModule) -> PassResult: | 
| 70 |  | -        graph_module.graph, merge_modified = merge_view_copy_chains(graph_module.graph) | 
| 71 |  | -        graph_module.graph, noop_modified = remove_noop_view_copy(graph_module.graph) | 
| 72 |  | -        modified = merge_modified or noop_modified | 
|  | 102 | +        graph_module.graph, modified = merge_view_copy_chains(graph_module.graph) | 
| 73 | 103 |         if modified: | 
| 74 | 104 |             graph_module.recompile() | 
| 75 | 105 |             graph_module = super().call(graph_module).graph_module | 
|  | 106 | + | 
|  | 107 | +        graph_module.graph, modified = remove_noop_view_copy(graph_module.graph) | 
|  | 108 | +        if modified: | 
|  | 109 | +            graph_module.recompile() | 
|  | 110 | +            graph_module = super().call(graph_module).graph_module | 
|  | 111 | + | 
| 76 | 112 |         return PassResult(graph_module, modified) | 
0 commit comments