Skip to content

Commit 27dc313

Browse files
Andrew Grebenisanfacebook-github-bot
authored andcommitted
Properly set modified for RemoveNopSliceOrViewOpPass (#15470)
Summary: Ensure modified is being correctly set for RemoveNopSliceOrViewOpPass Reviewed By: abeakkas Differential Revision: D85815510
1 parent 5b1004c commit 27dc313

File tree

2 files changed

+39
-2
lines changed

2 files changed

+39
-2
lines changed

backends/cadence/aot/remove_ops.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,11 @@ class RemoveNopSliceOrViewOpPass(ExportPass):
191191
Remove slice ops that are more like views, and view ops that do not change the shape
192192
"""
193193

194+
@staticmethod
195+
def _is_nop(input_shape: tuple[int, ...], output_shape: tuple[int, ...]) -> bool:
196+
# If both arg_shape and out_shape are the same, this slice is a nop
197+
return input_shape == output_shape
198+
194199
def call_operator(
195200
self,
196201
op, # pyre-ignore
@@ -207,13 +212,27 @@ def call_operator(
207212
arg0 = cast(ProxyValue, args[0])
208213
out_shape = meta["val"].shape
209214

210-
# If both arg_shape and out_shape are the same, this slice is a nop
211215
return (
212216
arg0
213-
if arg0.to_tensor().shape == out_shape
217+
if self._is_nop(arg0.to_tensor().shape, out_shape)
214218
else super().call_operator(op, args, kwargs, meta)
215219
)
216220

221+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
222+
for target in [
223+
exir_ops.edge.aten.slice_copy.Tensor,
224+
exir_ops.edge.aten.view_copy.default,
225+
]:
226+
for node in graph_module.graph.find_nodes(
227+
op="call_function", target=target
228+
):
229+
input_node = node.args[0]
230+
assert isinstance(input_node, torch.fx.Node)
231+
if self._is_nop(input_node.meta["val"].shape, node.meta["val"].shape):
232+
return super().call(graph_module)
233+
234+
return PassResult(graph_module, False)
235+
217236

218237
@register_cadence_pass(CadencePassAttribute(opt_level=1))
219238
class RemoveNopLinalgVectorNormOpPass(ExportPass):

backends/cadence/aot/tests/test_remove_ops_passes.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,24 @@ def test_remove_nop_slice(self) -> None:
304304
count_node(graph_after_passes, exir_ops.edge.aten.slice_copy.Tensor), 0
305305
)
306306

307+
def test_remove_nop_slice_or_view_not_modified(self) -> None:
308+
builder = GraphBuilder()
309+
x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
310+
abs_x = builder.call_operator(
311+
op=exir_ops.edge.aten.abs.default,
312+
args=(x,),
313+
)
314+
builder.output([abs_x])
315+
original = builder.get_graph_module()
316+
pass_result = cast(
317+
PassResult, RemoveNopSliceOrViewOpPass()(original)
318+
)
319+
self.assertFalse(pass_result.modified)
320+
graph_after_passes = pass_result.graph_module
321+
self.assertEqual(
322+
count_node(graph_after_passes, exir_ops.edge.aten.abs.default), 1
323+
)
324+
307325
def test_remove_nop_select_before_view(self) -> None:
308326
builder = GraphBuilder()
309327
x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32))

0 commit comments

Comments
 (0)