@@ -276,9 +276,11 @@ def test_remove_nop_view(self, shape: Tuple[int], new_shape: List[int]) -> None:
276276 )
277277 builder .output ([view ])
278278 original = builder .get_graph_module ()
279+ pass_instance = RemoveNopSliceOrViewOpPass ()
279280 graph_after_passes = cast (
280- PassResult , RemoveNopSliceOrViewOpPass () (original )
281+ PassResult , pass_instance (original )
281282 ).graph_module
283+ self .assertTrue (pass_instance ._modified )
282284 self .assertEqual (
283285 count_node (graph_after_passes , exir_ops .edge .aten .view_copy .default ), 0
284286 )
@@ -297,13 +299,34 @@ def test_remove_nop_slice(self) -> None:
297299 )
298300 builder .output ([slice_ ])
299301 original = builder .get_graph_module ()
302+ pass_instance = RemoveNopSliceOrViewOpPass ()
300303 graph_after_passes = cast (
301- PassResult , RemoveNopSliceOrViewOpPass () (original )
304+ PassResult , pass_instance (original )
302305 ).graph_module
306+ self .assertTrue (pass_instance ._modified )
303307 self .assertEqual (
304308 count_node (graph_after_passes , exir_ops .edge .aten .slice_copy .Tensor ), 0
305309 )
306310
311+ def test_remove_nop_slice_or_view_not_modified (self ) -> None :
312+ builder = GraphBuilder ()
313+ x = builder .placeholder ("x" , torch .randn (3 , 5 , dtype = torch .float32 ))
314+ abs_x = builder .call_operator (
315+ op = exir_ops .edge .aten .abs .default ,
316+ args = (x ,),
317+ )
318+ builder .output ([abs_x ])
319+ original = builder .get_graph_module ()
320+ pass_instance = RemoveNopSliceOrViewOpPass ()
321+ graph_after_passes = cast (
322+ PassResult , pass_instance (original )
323+ ).graph_module
324+ self .assertFalse (pass_instance ._modified )
325+ self .assertEqual (
326+ count_node (graph_after_passes , exir_ops .edge .aten .abs .default ), 1
327+ )
328+
329+
307330 def test_remove_nop_select_before_view (self ) -> None :
308331 builder = GraphBuilder ()
309332 x = builder .placeholder ("x" , torch .randn (1 , 5 , 6 , dtype = torch .float32 ))
0 commit comments