diff --git a/backends/cadence/aot/pass_utils.py b/backends/cadence/aot/pass_utils.py index 9aedef2ce2f..96c30bcdf59 100644 --- a/backends/cadence/aot/pass_utils.py +++ b/backends/cadence/aot/pass_utils.py @@ -6,6 +6,7 @@ # pyre-strict +from abc import abstractmethod from dataclasses import dataclass from typing import Callable, List, Optional, Set, Type, Union @@ -13,9 +14,10 @@ from executorch.backends.cadence.aot.utils import get_edge_overload_packet from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket -from executorch.exir.pass_base import PassBase, PassResult +from executorch.exir.pass_base import ExportPass, PassBase, PassResult from torch._ops import OpOverloadPacket +from torch.fx import Node # Is an overlap in tensor lifetime and storage allowed at the current opt level? @@ -229,3 +231,44 @@ def set_arg( def none_throws(x: Optional[PassResult]) -> PassResult: assert x is not None return x + + +class RemoveOrReplacePassInterface(ExportPass): + @property + @abstractmethod + def targets(self) -> list[EdgeOpOverload]: + """ + The list of targets to potentially remove or replace. + """ + raise NotImplementedError("`targets` must be implemented") + + @abstractmethod + def maybe_remove_or_replace(self, node: Node) -> bool: + """ + If the node should be removed/replaced, removes/replaces from the graph. Returns + True if the graph was modified, else False. + """ + raise NotImplementedError("`maybe_remove_or_replace` must be implemented") + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + """ + For each node in targets, if the node should be removed/replaced, + removes/replaces from the graph and returns the modified graph and modified + set to True. + If no node should be removed/replaced, returns a pass result with the original + graph module and False for modified. + """ + changed = False + for target in self.targets: + for module in filter( + lambda m: isinstance(m, torch.fx.GraphModule), graph_module.modules() + ): + for node in module.graph.find_nodes(op="call_function", target=target): + changed |= self.maybe_remove_or_replace(node) + + if changed: + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + return super().call(graph_module) + + return PassResult(graph_module, False) diff --git a/backends/cadence/aot/remove_ops.py b/backends/cadence/aot/remove_ops.py index 2e2aa5fc17a..289c5ffeeec 100644 --- a/backends/cadence/aot/remove_ops.py +++ b/backends/cadence/aot/remove_ops.py @@ -6,7 +6,6 @@ # pyre-strict - import logging from dataclasses import dataclass, field from typing import cast, List, Optional, Sequence, Set, Type @@ -20,6 +19,7 @@ CadencePassAttribute, get_arg, register_cadence_pass, + RemoveOrReplacePassInterface, set_arg, ) @@ -186,33 +186,27 @@ def call_operator( @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class RemoveNopSliceOrViewOpPass(ExportPass): +class RemoveNopSliceOrViewOpPass(RemoveOrReplacePassInterface): """ Remove slice ops that are more like views, and view ops that do not change the shape """ - def call_operator( - self, - op, # pyre-ignore - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op not in { + @property + def targets(self) -> list[EdgeOpOverload]: + return [ exir_ops.edge.aten.slice_copy.Tensor, exir_ops.edge.aten.view_copy.default, - }: - return super().call_operator(op, args, kwargs, meta) + ] - arg0 = cast(ProxyValue, args[0]) - out_shape = meta["val"].shape + def maybe_remove_or_replace(self, node: Node) -> bool: + changed = False + input_node = node.args[0] + assert isinstance(input_node, Node) + if input_node.meta["val"].shape == node.meta["val"].shape: + node.replace_all_uses_with(input_node) + changed = True - # If both arg_shape and out_shape are the same, this slice is a nop - return ( - arg0 - if arg0.to_tensor().shape == out_shape - else super().call_operator(op, args, kwargs, meta) - ) + return changed @register_cadence_pass(CadencePassAttribute(opt_level=1)) diff --git a/backends/cadence/aot/tests/test_remove_ops_passes.py b/backends/cadence/aot/tests/test_remove_ops_passes.py index a38416c0ff1..483d737f97d 100644 --- a/backends/cadence/aot/tests/test_remove_ops_passes.py +++ b/backends/cadence/aot/tests/test_remove_ops_passes.py @@ -304,6 +304,22 @@ def test_remove_nop_slice(self) -> None: count_node(graph_after_passes, exir_ops.edge.aten.slice_copy.Tensor), 0 ) + def test_remove_nop_slice_or_view_not_modified(self) -> None: + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32)) + abs_x = builder.call_operator( + op=exir_ops.edge.aten.abs.default, + args=(x,), + ) + builder.output([abs_x]) + original = builder.get_graph_module() + pass_result = cast(PassResult, RemoveNopSliceOrViewOpPass()(original)) + self.assertFalse(pass_result.modified) + graph_after_passes = pass_result.graph_module + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.abs.default), 1 + ) + def test_remove_nop_select_before_view(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32))