Skip to content

Commit db29d3a

Browse files
Andrew Grebenisanfacebook-github-bot
authored andcommitted
Properly set modified for RemoveNopSliceOrViewOpPass
Summary: Ensure modified is being correctly set for RemoveNopSliceOrViewOpPass Differential Revision: D85815510
1 parent 9d68039 commit db29d3a

File tree

2 files changed

+35
-2
lines changed

2 files changed

+35
-2
lines changed

backends/cadence/aot/remove_ops.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,10 @@ class RemoveNopSliceOrViewOpPass(ExportPass):
191191
Remove slice ops that are more like views, and view ops that do not change the shape
192192
"""
193193

194+
def __init__(self) -> None:
195+
super().__init__()
196+
self._modified = False
197+
194198
def call_operator(
195199
self,
196200
op, # pyre-ignore
@@ -204,6 +208,7 @@ def call_operator(
204208
}:
205209
return super().call_operator(op, args, kwargs, meta)
206210

211+
self._modified = True
207212
arg0 = cast(ProxyValue, args[0])
208213
out_shape = meta["val"].shape
209214

@@ -214,6 +219,11 @@ def call_operator(
214219
else super().call_operator(op, args, kwargs, meta)
215220
)
216221

222+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
223+
self._modified = False
224+
result = super().call(graph_module)
225+
return PassResult(result.graph_module, self._modified)
226+
217227

218228
@register_cadence_pass(CadencePassAttribute(opt_level=1))
219229
class RemoveNopLinalgVectorNormOpPass(ExportPass):

backends/cadence/aot/tests/test_remove_ops_passes.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,9 +276,11 @@ def test_remove_nop_view(self, shape: Tuple[int], new_shape: List[int]) -> None:
276276
)
277277
builder.output([view])
278278
original = builder.get_graph_module()
279+
pass_instance = RemoveNopSliceOrViewOpPass()
279280
graph_after_passes = cast(
280-
PassResult, RemoveNopSliceOrViewOpPass()(original)
281+
PassResult, pass_instance(original)
281282
).graph_module
283+
self.assertTrue(pass_instance._modified)
282284
self.assertEqual(
283285
count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), 0
284286
)
@@ -297,13 +299,34 @@ def test_remove_nop_slice(self) -> None:
297299
)
298300
builder.output([slice_])
299301
original = builder.get_graph_module()
302+
pass_instance = RemoveNopSliceOrViewOpPass()
300303
graph_after_passes = cast(
301-
PassResult, RemoveNopSliceOrViewOpPass()(original)
304+
PassResult, pass_instance(original)
302305
).graph_module
306+
self.assertTrue(pass_instance._modified)
303307
self.assertEqual(
304308
count_node(graph_after_passes, exir_ops.edge.aten.slice_copy.Tensor), 0
305309
)
306310

311+
def test_remove_nop_slice_or_view_not_modified(self) -> None:
312+
builder = GraphBuilder()
313+
x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
314+
abs_x = builder.call_operator(
315+
op=exir_ops.edge.aten.abs.default,
316+
args=(x,),
317+
)
318+
builder.output([abs_x])
319+
original = builder.get_graph_module()
320+
pass_instance = RemoveNopSliceOrViewOpPass()
321+
graph_after_passes = cast(
322+
PassResult, pass_instance(original)
323+
).graph_module
324+
self.assertFalse(pass_instance._modified)
325+
self.assertEqual(
326+
count_node(graph_after_passes, exir_ops.edge.aten.abs.default), 1
327+
)
328+
329+
307330
def test_remove_nop_select_before_view(self) -> None:
308331
builder = GraphBuilder()
309332
x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32))

0 commit comments

Comments
 (0)