|  | 
|  | 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. | 
|  | 2 | +# All rights reserved. | 
|  | 3 | +# | 
|  | 4 | +# This source code is licensed under the BSD-style license found in the | 
|  | 5 | +# LICENSE file in the root directory of this source tree. | 
|  | 6 | + | 
|  | 7 | +# pyre-strict | 
|  | 8 | + | 
|  | 9 | +from typing import Set, Union | 
|  | 10 | + | 
|  | 11 | +import torch | 
|  | 12 | +from executorch.exir.dialects._ops import ops as exir_ops | 
|  | 13 | +from executorch.exir.dialects.edge._ops import EdgeOpOverload | 
|  | 14 | +from executorch.exir.pass_base import ExportPass, PassResult | 
|  | 15 | +from executorch.exir.passes import dead_code_elimination_pass | 
|  | 16 | + | 
|  | 17 | +OpType = Union[str, torch._ops.OpOverload, EdgeOpOverload] | 
|  | 18 | + | 
|  | 19 | + | 
|  | 20 | +class RemoveRedundantOpsTransform(ExportPass): | 
|  | 21 | +    """ | 
|  | 22 | +    Trim certain operators to reduce unnecessary overhead. | 
|  | 23 | +    """ | 
|  | 24 | + | 
|  | 25 | +    redundant_ops: Set[OpType] = { | 
|  | 26 | +        torch.clone, | 
|  | 27 | +        torch.ops.aten.clone.default, | 
|  | 28 | +        exir_ops.edge.aten.clone.default, | 
|  | 29 | +        torch.ops.aten.alias.default, | 
|  | 30 | +        exir_ops.edge.aten.alias.default, | 
|  | 31 | +        exir_ops.edge.aten.lift_fresh_copy.default, | 
|  | 32 | +        exir_ops.edge.dim_order_ops._to_dim_order_copy.default, | 
|  | 33 | +    } | 
|  | 34 | + | 
|  | 35 | +    def __init__(self) -> None: | 
|  | 36 | +        super(RemoveRedundantOpsTransform, self).__init__() | 
|  | 37 | + | 
|  | 38 | +    def _should_remove(self, node: torch.fx.Node) -> bool: | 
|  | 39 | +        if node.target in self.redundant_ops: | 
|  | 40 | +            return True | 
|  | 41 | + | 
|  | 42 | +        # Only remove to_copy if dtype does not change. Otherwise, memory format changes | 
|  | 43 | +        # will be handled internally by the backend. | 
|  | 44 | +        if ( | 
|  | 45 | +            node.target == exir_ops.edge.aten._to_copy.default | 
|  | 46 | +            or node.target == torch.ops.aten._to_copy.default | 
|  | 47 | +        ): | 
|  | 48 | +            src_dtype = node.meta["val"].dtype | 
|  | 49 | +            # pyre-ignore | 
|  | 50 | +            dst_dtype = node.args[0].meta["val"].dtype | 
|  | 51 | +            return src_dtype == dst_dtype | 
|  | 52 | + | 
|  | 53 | +        return False | 
|  | 54 | + | 
|  | 55 | +    def _remove(self, graph_module: torch.fx.GraphModule) -> None: | 
|  | 56 | +        for node in graph_module.graph.nodes: | 
|  | 57 | +            if not self._should_remove(node): | 
|  | 58 | +                continue | 
|  | 59 | + | 
|  | 60 | +            with graph_module.graph.inserting_after(node): | 
|  | 61 | +                node.replace_all_uses_with(node.args[0]) | 
|  | 62 | + | 
|  | 63 | +        graph_module.graph.eliminate_dead_code() | 
|  | 64 | + | 
|  | 65 | +    def call(self, graph_module: torch.fx.GraphModule) -> PassResult: | 
|  | 66 | +        self._remove(graph_module) | 
|  | 67 | +        graph_module.recompile() | 
|  | 68 | +        dead_code_elimination_pass(graph_module) | 
|  | 69 | +        return PassResult(graph_module, True) | 
0 commit comments