Skip to content

Commit bc04928

Browse files
Andrew Grebenisanfacebook-github-bot
authored andcommitted
Properly set modified for RemoveNopSliceOrViewOpPass (#15470)
Summary: Ensure modified is being correctly set for RemoveNopSliceOrViewOpPass Differential Revision: D85815510
1 parent 5b1004c commit bc04928

File tree

2 files changed

+41
-2
lines changed

2 files changed

+41
-2
lines changed

backends/cadence/aot/remove_ops.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,11 @@ class RemoveNopSliceOrViewOpPass(ExportPass):
191191
Remove slice ops that are more like views, and view ops that do not change the shape
192192
"""
193193

194+
@staticmethod
195+
def _is_nop(input_shape: tuple[int, ...], output_shape: tuple[int, ...]) -> bool:
196+
# If both arg_shape and out_shape are the same, this slice is a nop
197+
return input_shape == output_shape
198+
194199
def call_operator(
195200
self,
196201
op, # pyre-ignore
@@ -207,13 +212,27 @@ def call_operator(
207212
arg0 = cast(ProxyValue, args[0])
208213
out_shape = meta["val"].shape
209214

210-
# If both arg_shape and out_shape are the same, this slice is a nop
211215
return (
212216
arg0
213-
if arg0.to_tensor().shape == out_shape
217+
if self._is_nop(arg0.to_tensor().shape, out_shape)
214218
else super().call_operator(op, args, kwargs, meta)
215219
)
216220

221+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
222+
for target in [
223+
exir_ops.edge.aten.slice_copy.Tensor,
224+
exir_ops.edge.aten.view_copy.default,
225+
]:
226+
for node in graph_module.graph.find_nodes(
227+
op="call_function", target=target
228+
):
229+
input_node = node.args[0]
230+
assert isinstance(input_node, torch.fx.Node)
231+
if self._is_nop(input_node.meta["val"].shape, node.meta["val"].shape):
232+
return super().call(graph_module)
233+
234+
return PassResult(graph_module, False)
235+
217236

218237
@register_cadence_pass(CadencePassAttribute(opt_level=1))
219238
class RemoveNopLinalgVectorNormOpPass(ExportPass):

backends/cadence/aot/tests/test_remove_ops_passes.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,7 @@ def test_remove_nop_view(self, shape: Tuple[int], new_shape: List[int]) -> None:
279279
graph_after_passes = cast(
280280
PassResult, RemoveNopSliceOrViewOpPass()(original)
281281
).graph_module
282+
assert original is not graph_after_passes
282283
self.assertEqual(
283284
count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), 0
284285
)
@@ -300,10 +301,29 @@ def test_remove_nop_slice(self) -> None:
300301
graph_after_passes = cast(
301302
PassResult, RemoveNopSliceOrViewOpPass()(original)
302303
).graph_module
304+
assert original is not graph_after_passes
303305
self.assertEqual(
304306
count_node(graph_after_passes, exir_ops.edge.aten.slice_copy.Tensor), 0
305307
)
306308

309+
def test_remove_nop_slice_or_view_not_modified(self) -> None:
310+
builder = GraphBuilder()
311+
x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
312+
abs_x = builder.call_operator(
313+
op=exir_ops.edge.aten.abs.default,
314+
args=(x,),
315+
)
316+
builder.output([abs_x])
317+
original = builder.get_graph_module()
318+
graph_after_passes = cast(
319+
PassResult, RemoveNopSliceOrViewOpPass()(original)
320+
).graph_module
321+
# If the graph is not modified, the same object should be returned
322+
assert original is graph_after_passes
323+
self.assertEqual(
324+
count_node(graph_after_passes, exir_ops.edge.aten.abs.default), 1
325+
)
326+
307327
def test_remove_nop_select_before_view(self) -> None:
308328
builder = GraphBuilder()
309329
x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32))

0 commit comments

Comments
 (0)