Skip to content

Commit aef19e1

Browse files
Andrew Grebenisanfacebook-github-bot
authored andcommitted
Properly set modified for RemoveNopSliceOrViewOpPass (pytorch#15470)
Summary: Ensure modified is being correctly set for RemoveNopSliceOrViewOpPass. This diff also introduces a common base class that we can use for all remove ops passes. Reviewed By: abeakkas Differential Revision: D85815510
1 parent 4a75896 commit aef19e1

File tree

2 files changed

+70
-20
lines changed

2 files changed

+70
-20
lines changed

backends/cadence/aot/remove_ops.py

Lines changed: 52 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
# pyre-strict
88

9-
9+
from abc import abstractmethod
1010
import logging
1111
from dataclasses import dataclass, field
1212
from typing import cast, List, Optional, Sequence, Set, Type
@@ -34,6 +34,44 @@
3434
from executorch.exir.passes.spec_prop_pass import SpecPropPass
3535
from torch.fx.node import Argument, Node
3636

37+
class RemovePassCommon(ExportPass):
38+
@property
39+
@abstractmethod
40+
def targets(self) -> list[EdgeOpOverload]:
41+
"""
42+
The list of targets to potentially remove.
43+
"""
44+
...
45+
46+
@abstractmethod
47+
def maybe_remove(self, node: Node) -> bool:
48+
"""
49+
If the node should be removed, removes from the graph. Returns
50+
True if the graph was modified, else False.
51+
"""
52+
...
53+
54+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
55+
"""
56+
For each node in targets, if the node should be removed, removes from
57+
the graph and returns the modified graph and modified set to True.
58+
If no node should be removed, returns a pass result with the original
59+
graph module and False for modified.
60+
"""
61+
changed = False
62+
for target in self.targets:
63+
for node in graph_module.graph.find_nodes(op="call_function", target=target):
64+
changed = changed or self.maybe_remove(node)
65+
66+
if changed:
67+
graph_module.graph.eliminate_dead_code()
68+
graph_module.recompile()
69+
return super().call(graph_module)
70+
71+
return PassResult(graph_module, False)
72+
73+
74+
3775

3876
@register_cadence_pass(CadencePassAttribute(opt_level=0))
3977
class RemoveCloneOpsTransformImported(ExportPass):
@@ -186,33 +224,27 @@ def call_operator(
186224

187225

188226
@register_cadence_pass(CadencePassAttribute(opt_level=1))
189-
class RemoveNopSliceOrViewOpPass(ExportPass):
227+
class RemoveNopSliceOrViewOpPass(RemovePassCommon):
190228
"""
191229
Remove slice ops that are more like views, and view ops that do not change the shape
192230
"""
193231

194-
def call_operator(
195-
self,
196-
op, # pyre-ignore
197-
args: tuple[Argument, ...],
198-
kwargs: dict[str, Argument],
199-
meta: NodeMetadata,
200-
) -> ProxyValue:
201-
if op not in {
232+
@property
233+
def targets(self) -> list[EdgeOpOverload]:
234+
return [
202235
exir_ops.edge.aten.slice_copy.Tensor,
203236
exir_ops.edge.aten.view_copy.default,
204-
}:
205-
return super().call_operator(op, args, kwargs, meta)
237+
]
206238

207-
arg0 = cast(ProxyValue, args[0])
208-
out_shape = meta["val"].shape
239+
def maybe_remove(self, node: Node) -> bool:
240+
changed = False
241+
input_node = node.args[0]
242+
assert isinstance(input_node, Node)
243+
if input_node.meta["val"].shape == node.meta["val"].shape:
244+
node.replace_all_uses_with(input_node)
245+
changed = True
209246

210-
# If both arg_shape and out_shape are the same, this slice is a nop
211-
return (
212-
arg0
213-
if arg0.to_tensor().shape == out_shape
214-
else super().call_operator(op, args, kwargs, meta)
215-
)
247+
return changed
216248

217249

218250
@register_cadence_pass(CadencePassAttribute(opt_level=1))

backends/cadence/aot/tests/test_remove_ops_passes.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,24 @@ def test_remove_nop_slice(self) -> None:
304304
count_node(graph_after_passes, exir_ops.edge.aten.slice_copy.Tensor), 0
305305
)
306306

307+
def test_remove_nop_slice_or_view_not_modified(self) -> None:
308+
builder = GraphBuilder()
309+
x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
310+
abs_x = builder.call_operator(
311+
op=exir_ops.edge.aten.abs.default,
312+
args=(x,),
313+
)
314+
builder.output([abs_x])
315+
original = builder.get_graph_module()
316+
pass_result = cast(
317+
PassResult, RemoveNopSliceOrViewOpPass()(original)
318+
)
319+
self.assertFalse(pass_result.modified)
320+
graph_after_passes = pass_result.graph_module
321+
self.assertEqual(
322+
count_node(graph_after_passes, exir_ops.edge.aten.abs.default), 1
323+
)
324+
307325
def test_remove_nop_select_before_view(self) -> None:
308326
builder = GraphBuilder()
309327
x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32))

0 commit comments

Comments
 (0)