Commit f4e1bd0

Properly set modified for RemoveNopSliceOrViewOpPass
Differential Revision: D85815510
Pull Request resolved: #15470
1 parent: 465dce2

3 files changed: +74 −21 lines


backends/cadence/aot/pass_utils.py

Lines changed: 44 additions & 1 deletion
@@ -6,16 +6,18 @@
 
 # pyre-strict
 
+from abc import abstractmethod
 from dataclasses import dataclass
 from typing import Callable, List, Optional, Set, Type, Union
 
 import torch
 from executorch.backends.cadence.aot.utils import get_edge_overload_packet
 
 from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
-from executorch.exir.pass_base import PassBase, PassResult
+from executorch.exir.pass_base import ExportPass, PassBase, PassResult
 
 from torch._ops import OpOverloadPacket
+from torch.fx import Node
 
 
 # Is an overlap in tensor lifetime and storage allowed at the current opt level?
@@ -229,3 +231,44 @@ def set_arg(
 def none_throws(x: Optional[PassResult]) -> PassResult:
     assert x is not None
     return x
+
+
+class RemoveOrReplacePassInterface(ExportPass):
+    @property
+    @abstractmethod
+    def targets(self) -> list[EdgeOpOverload]:
+        """
+        The list of targets to potentially remove or replace.
+        """
+        raise NotImplementedError("`targets` must be implemented")
+
+    @abstractmethod
+    def maybe_remove_or_replace(self, node: Node) -> bool:
+        """
+        If the node should be removed/replaced, removes/replaces from the graph. Returns
+        True if the graph was modified, else False.
+        """
+        raise NotImplementedError("`maybe_remove_or_replace` must be implemented")
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        """
+        For each node in targets, if the node should be removed/replaced,
+        removes/replaces from the graph and returns the modified graph and modified
+        set to True.
+        If no node should be removed/replaced, returns a pass result with the original
+        graph module and False for modified.
+        """
+        changed = False
+        for target in self.targets:
+            for module in filter(
+                lambda m: isinstance(m, torch.fx.GraphModule), graph_module.modules()
+            ):
+                for node in module.graph.find_nodes(op="call_function", target=target):
+                    changed |= self.maybe_remove_or_replace(node)
+
+        if changed:
+            graph_module.graph.eliminate_dead_code()
+            graph_module.recompile()
+            return super().call(graph_module)
+
+        return PassResult(graph_module, False)
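
For orientation, here is a minimal sketch of how a pass can build on the new RemoveOrReplacePassInterface: implement `targets` and `maybe_remove_or_replace`, and let the base class handle dead-code elimination, recompilation, and the `modified` flag. The pass name and target op below are illustrative only, not part of this commit (the real migration is the RemoveNopSliceOrViewOpPass change in the next file).

# Illustrative subclass, not part of this commit.
from executorch.backends.cadence.aot.pass_utils import RemoveOrReplacePassInterface
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from torch.fx import Node


class RemoveNopExpandOpExamplePass(RemoveOrReplacePassInterface):
    """Drop expand_copy nodes whose output shape matches the input shape."""

    @property
    def targets(self) -> list[EdgeOpOverload]:
        return [exir_ops.edge.aten.expand_copy.default]

    def maybe_remove_or_replace(self, node: Node) -> bool:
        input_node = node.args[0]
        assert isinstance(input_node, Node)
        if input_node.meta["val"].shape == node.meta["val"].shape:
            # Rewire users to the input; `call` in the base class then runs
            # dead-code elimination, recompiles, and reports modified=True.
            node.replace_all_uses_with(input_node)
            return True
        return False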

backends/cadence/aot/remove_ops.py

Lines changed: 14 additions & 20 deletions
@@ -6,7 +6,6 @@
 
 # pyre-strict
 
-
 import logging
 from dataclasses import dataclass, field
 from typing import cast, List, Optional, Sequence, Set, Type
@@ -20,6 +19,7 @@
     CadencePassAttribute,
     get_arg,
     register_cadence_pass,
+    RemoveOrReplacePassInterface,
    set_arg,
 )
 
@@ -186,33 +186,27 @@ def call_operator(
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
-class RemoveNopSliceOrViewOpPass(ExportPass):
+class RemoveNopSliceOrViewOpPass(RemoveOrReplacePassInterface):
     """
     Remove slice ops that are more like views, and view ops that do not change the shape
     """
 
-    def call_operator(
-        self,
-        op,  # pyre-ignore
-        args: tuple[Argument, ...],
-        kwargs: dict[str, Argument],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        if op not in {
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [
             exir_ops.edge.aten.slice_copy.Tensor,
             exir_ops.edge.aten.view_copy.default,
-        }:
-            return super().call_operator(op, args, kwargs, meta)
+        ]
 
-        arg0 = cast(ProxyValue, args[0])
-        out_shape = meta["val"].shape
+    def maybe_remove_or_replace(self, node: Node) -> bool:
+        changed = False
+        input_node = node.args[0]
+        assert isinstance(input_node, Node)
+        if input_node.meta["val"].shape == node.meta["val"].shape:
+            node.replace_all_uses_with(input_node)
+            changed = True
 
-        # If both arg_shape and out_shape are the same, this slice is a nop
-        return (
-            arg0
-            if arg0.to_tensor().shape == out_shape
-            else super().call_operator(op, args, kwargs, meta)
-        )
+        return changed
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))

backends/cadence/aot/tests/test_remove_ops_passes.py

Lines changed: 16 additions & 0 deletions
@@ -304,6 +304,22 @@ def test_remove_nop_slice(self) -> None:
             count_node(graph_after_passes, exir_ops.edge.aten.slice_copy.Tensor), 0
         )
 
+    def test_remove_nop_slice_or_view_not_modified(self) -> None:
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
+        abs_x = builder.call_operator(
+            op=exir_ops.edge.aten.abs.default,
+            args=(x,),
+        )
+        builder.output([abs_x])
+        original = builder.get_graph_module()
+        pass_result = cast(PassResult, RemoveNopSliceOrViewOpPass()(original))
+        self.assertFalse(pass_result.modified)
+        graph_after_passes = pass_result.graph_module
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.aten.abs.default), 1
+        )
+
     def test_remove_nop_select_before_view(self) -> None:
         builder = GraphBuilder()
         x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32))
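
As a usage note, a correctly reported `modified` flag is what lets callers run this kind of cleanup to a fixed point. A minimal sketch, assuming a hand-rolled driver loop (this is not the actual ExecuTorch/Cadence pass manager):

# Illustrative driver: rerun the pass until it reports no further changes.
import torch

from executorch.backends.cadence.aot.remove_ops import RemoveNopSliceOrViewOpPass


def remove_nop_slices_to_fixed_point(
    graph_module: torch.fx.GraphModule,
) -> torch.fx.GraphModule:
    while True:
        result = RemoveNopSliceOrViewOpPass()(graph_module)
        assert result is not None
        graph_module = result.graph_module
        if not result.modified:
            # Relies on the pass reporting modified=False when it leaves the
            # graph untouched, which the new test above asserts.
            return graph_module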
