45 changes: 44 additions & 1 deletion backends/cadence/aot/pass_utils.py
@@ -6,16 +6,18 @@

# pyre-strict

from abc import abstractmethod
from dataclasses import dataclass
from typing import Callable, List, Optional, Set, Type, Union

import torch
from executorch.backends.cadence.aot.utils import get_edge_overload_packet

from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
from executorch.exir.pass_base import PassBase, PassResult
from executorch.exir.pass_base import ExportPass, PassBase, PassResult

from torch._ops import OpOverloadPacket
from torch.fx import Node


# Is an overlap in tensor lifetime and storage allowed at the current opt level?
@@ -229,3 +231,44 @@ def set_arg(
def none_throws(x: Optional[PassResult]) -> PassResult:
assert x is not None
return x


class RemoveOrReplacePassInterface(ExportPass):
@property
@abstractmethod
def targets(self) -> list[EdgeOpOverload]:
"""
The list of targets to potentially remove or replace.
"""
raise NotImplementedError("`targets` must be implemented")

@abstractmethod
def maybe_remove_or_replace(self, node: Node) -> bool:
"""
        If the node should be removed or replaced, removes or replaces it in the
        graph. Returns True if the graph was modified, else False.
"""
raise NotImplementedError("`maybe_remove_or_replace` must be implemented")

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
"""
        For each node matching one of the targets, removes or replaces it if
        appropriate. If any node was removed or replaced, eliminates dead code,
        recompiles the module, and returns the updated graph module with modified
        set to True.
        If no node was removed or replaced, returns a PassResult with the original
        graph module and modified set to False.
"""
changed = False
for target in self.targets:
for module in filter(
lambda m: isinstance(m, torch.fx.GraphModule), graph_module.modules()
):
for node in module.graph.find_nodes(op="call_function", target=target):
changed |= self.maybe_remove_or_replace(node)

if changed:
graph_module.graph.eliminate_dead_code()
graph_module.recompile()
return super().call(graph_module)

return PassResult(graph_module, False)
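
For illustration only, here is a minimal sketch (not part of this PR) of how a subclass might plug into the new interface; the pass name RemoveNopCloneOpPass and its clone target are hypothetical stand-ins. A subclass only declares targets and the per-node rewrite, while the base class drives traversal, dead-code elimination, and recompilation:

# Hypothetical sketch, not from this PR: the pass name and the target op are
# illustrative stand-ins for how a subclass would use the interface.
from executorch.backends.cadence.aot.pass_utils import RemoveOrReplacePassInterface
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from torch.fx import Node


class RemoveNopCloneOpPass(RemoveOrReplacePassInterface):
    """Illustrative only: reroute users of clone ops directly to the clone's input."""

    @property
    def targets(self) -> list[EdgeOpOverload]:
        # Only call_function nodes with these targets are visited by call().
        return [exir_ops.edge.aten.clone.default]

    def maybe_remove_or_replace(self, node: Node) -> bool:
        input_node = node.args[0]
        assert isinstance(input_node, Node)
        # A plain clone does not change the value, so every user can read the
        # input directly; the dead clone node is removed by eliminate_dead_code().
        node.replace_all_uses_with(input_node)
        return True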
34 changes: 14 additions & 20 deletions backends/cadence/aot/remove_ops.py
@@ -6,7 +6,6 @@

# pyre-strict


import logging
from dataclasses import dataclass, field
from typing import cast, List, Optional, Sequence, Set, Type
@@ -20,6 +19,7 @@
CadencePassAttribute,
get_arg,
register_cadence_pass,
RemoveOrReplacePassInterface,
set_arg,
)

@@ -186,33 +186,27 @@ def call_operator(


@register_cadence_pass(CadencePassAttribute(opt_level=1))
-class RemoveNopSliceOrViewOpPass(ExportPass):
+class RemoveNopSliceOrViewOpPass(RemoveOrReplacePassInterface):
    """
    Remove slice ops that are more like views, and view ops that do not change the shape
    """

-    def call_operator(
-        self,
-        op, # pyre-ignore
-        args: tuple[Argument, ...],
-        kwargs: dict[str, Argument],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        if op not in {
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [
            exir_ops.edge.aten.slice_copy.Tensor,
            exir_ops.edge.aten.view_copy.default,
-        }:
-            return super().call_operator(op, args, kwargs, meta)
+        ]

-        arg0 = cast(ProxyValue, args[0])
-        out_shape = meta["val"].shape
+    def maybe_remove_or_replace(self, node: Node) -> bool:
+        changed = False
+        input_node = node.args[0]
+        assert isinstance(input_node, Node)
+        if input_node.meta["val"].shape == node.meta["val"].shape:
+            node.replace_all_uses_with(input_node)
+            changed = True

-        # If both arg_shape and out_shape are the same, this slice is a nop
-        return (
-            arg0
-            if arg0.to_tensor().shape == out_shape
-            else super().call_operator(op, args, kwargs, meta)
-        )
+        return changed


@register_cadence_pass(CadencePassAttribute(opt_level=1))
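For reference, a rough usage sketch (not from the diff) of running the rewritten pass on a graph containing a nop view; the GraphBuilder import path is an assumption based on the tests below, and the asserted PassResult fields are likewise assumptions:

# Hypothetical usage sketch; the GraphBuilder import path and the asserted
# PassResult fields are assumptions, not taken from this PR.
import torch
from executorch.backends.cadence.aot.graph_builder import GraphBuilder
from executorch.backends.cadence.aot.remove_ops import RemoveNopSliceOrViewOpPass
from executorch.exir.dialects._ops import ops as exir_ops

builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
# A view_copy to the identical shape is a nop and should be rerouted away.
nop_view = builder.call_operator(
    op=exir_ops.edge.aten.view_copy.default,
    args=(x, [3, 5]),
)
builder.output([nop_view])
result = RemoveNopSliceOrViewOpPass()(builder.get_graph_module())
assert result is not None and result.modified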
16 changes: 16 additions & 0 deletions backends/cadence/aot/tests/test_remove_ops_passes.py
@@ -304,6 +304,22 @@ def test_remove_nop_slice(self) -> None:
count_node(graph_after_passes, exir_ops.edge.aten.slice_copy.Tensor), 0
)

def test_remove_nop_slice_or_view_not_modified(self) -> None:
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
abs_x = builder.call_operator(
op=exir_ops.edge.aten.abs.default,
args=(x,),
)
builder.output([abs_x])
original = builder.get_graph_module()
pass_result = cast(PassResult, RemoveNopSliceOrViewOpPass()(original))
self.assertFalse(pass_result.modified)
graph_after_passes = pass_result.graph_module
self.assertEqual(
count_node(graph_after_passes, exir_ops.edge.aten.abs.default), 1
)

def test_remove_nop_select_before_view(self) -> None:
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32))