Skip to content

Commit 13e6a80

Browse files
Andrew Grebenisanfacebook-github-bot
authored andcommitted
Properly set modified for RemoveNopSliceOrViewOpPass (pytorch#15470)
Summary: Ensure modified is being correctly set for RemoveNopSliceOrViewOpPass. This diff also introduces a common base class that we can use for all remove ops passes. Reviewed By: abeakkas Differential Revision: D85815510
1 parent be9fc4d commit 13e6a80

File tree

2 files changed

+69
-20
lines changed

2 files changed

+69
-20
lines changed

backends/cadence/aot/remove_ops.py

Lines changed: 53 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77
# pyre-strict
88

9-
109
import logging
10+
from abc import abstractmethod
1111
from dataclasses import dataclass, field
1212
from typing import cast, List, Optional, Sequence, Set, Type
1313

@@ -35,6 +35,45 @@
3535
from torch.fx.node import Argument, Node
3636

3737

38+
class RemovePassCommon(ExportPass):
39+
@property
40+
@abstractmethod
41+
def targets(self) -> list[EdgeOpOverload]:
42+
"""
43+
The list of targets to potentially remove.
44+
"""
45+
raise NotImplementedError("`targets` must be defined")
46+
47+
@abstractmethod
48+
def maybe_remove(self, node: Node) -> bool:
49+
"""
50+
If the node should be removed, removes from the graph. Returns
51+
True if the graph was modified, else False.
52+
"""
53+
raise NotImplementedError("`maybe_remove` must be defined")
54+
55+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
56+
"""
57+
For each node in targets, if the node should be removed, removes from
58+
the graph and returns the modified graph and modified set to True.
59+
If no node should be removed, returns a pass result with the original
60+
graph module and False for modified.
61+
"""
62+
changed = False
63+
for target in self.targets:
64+
for module in graph_module.modules():
65+
assert isinstance(module, torch.fx.GraphModule)
66+
for node in module.graph.find_nodes(op="call_function", target=target):
67+
changed |= self.maybe_remove(node)
68+
69+
if changed:
70+
graph_module.graph.eliminate_dead_code()
71+
graph_module.recompile()
72+
return super().call(graph_module)
73+
74+
return PassResult(graph_module, False)
75+
76+
3877
@register_cadence_pass(CadencePassAttribute(opt_level=0))
3978
class RemoveCloneOpsTransformImported(ExportPass):
4079
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
@@ -186,33 +225,27 @@ def call_operator(
186225

187226

188227
@register_cadence_pass(CadencePassAttribute(opt_level=1))
189-
class RemoveNopSliceOrViewOpPass(ExportPass):
228+
class RemoveNopSliceOrViewOpPass(RemovePassCommon):
190229
"""
191230
Remove slice ops that are more like views, and view ops that do not change the shape
192231
"""
193232

194-
def call_operator(
195-
self,
196-
op, # pyre-ignore
197-
args: tuple[Argument, ...],
198-
kwargs: dict[str, Argument],
199-
meta: NodeMetadata,
200-
) -> ProxyValue:
201-
if op not in {
233+
@property
234+
def targets(self) -> list[EdgeOpOverload]:
235+
return [
202236
exir_ops.edge.aten.slice_copy.Tensor,
203237
exir_ops.edge.aten.view_copy.default,
204-
}:
205-
return super().call_operator(op, args, kwargs, meta)
238+
]
206239

207-
arg0 = cast(ProxyValue, args[0])
208-
out_shape = meta["val"].shape
240+
def maybe_remove(self, node: Node) -> bool:
241+
changed = False
242+
input_node = node.args[0]
243+
assert isinstance(input_node, Node)
244+
if input_node.meta["val"].shape == node.meta["val"].shape:
245+
node.replace_all_uses_with(input_node)
246+
changed = True
209247

210-
# If both arg_shape and out_shape are the same, this slice is a nop
211-
return (
212-
arg0
213-
if arg0.to_tensor().shape == out_shape
214-
else super().call_operator(op, args, kwargs, meta)
215-
)
248+
return changed
216249

217250

218251
@register_cadence_pass(CadencePassAttribute(opt_level=1))

backends/cadence/aot/tests/test_remove_ops_passes.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,22 @@ def test_remove_nop_slice(self) -> None:
304304
count_node(graph_after_passes, exir_ops.edge.aten.slice_copy.Tensor), 0
305305
)
306306

307+
def test_remove_nop_slice_or_view_not_modified(self) -> None:
308+
builder = GraphBuilder()
309+
x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
310+
abs_x = builder.call_operator(
311+
op=exir_ops.edge.aten.abs.default,
312+
args=(x,),
313+
)
314+
builder.output([abs_x])
315+
original = builder.get_graph_module()
316+
pass_result = cast(PassResult, RemoveNopSliceOrViewOpPass()(original))
317+
self.assertFalse(pass_result.modified)
318+
graph_after_passes = pass_result.graph_module
319+
self.assertEqual(
320+
count_node(graph_after_passes, exir_ops.edge.aten.abs.default), 1
321+
)
322+
307323
def test_remove_nop_select_before_view(self) -> None:
308324
builder = GraphBuilder()
309325
x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32))

0 commit comments

Comments
 (0)