@@ -279,6 +279,7 @@ def test_remove_nop_view(self, shape: Tuple[int], new_shape: List[int]) -> None:
279279 graph_after_passes = cast (
280280 PassResult , RemoveNopSliceOrViewOpPass ()(original )
281281 ).graph_module
282+ assert original is not graph_after_passes
282283 self .assertEqual (
283284 count_node (graph_after_passes , exir_ops .edge .aten .view_copy .default ), 0
284285 )
@@ -300,10 +301,29 @@ def test_remove_nop_slice(self) -> None:
300301 graph_after_passes = cast (
301302 PassResult , RemoveNopSliceOrViewOpPass ()(original )
302303 ).graph_module
304+ assert original is not graph_after_passes
303305 self .assertEqual (
304306 count_node (graph_after_passes , exir_ops .edge .aten .slice_copy .Tensor ), 0
305307 )
306308
309+ def test_remove_nop_slice_or_view_not_modified (self ) -> None :
310+ builder = GraphBuilder ()
311+ x = builder .placeholder ("x" , torch .randn (3 , 5 , dtype = torch .float32 ))
312+ abs_x = builder .call_operator (
313+ op = exir_ops .edge .aten .abs .default ,
314+ args = (x ,),
315+ )
316+ builder .output ([abs_x ])
317+ original = builder .get_graph_module ()
318+ graph_after_passes = cast (
319+ PassResult , RemoveNopSliceOrViewOpPass ()(original )
320+ ).graph_module
321+ # If the graph is not modified, the same object should be returned
322+ assert original is graph_after_passes
323+ self .assertEqual (
324+ count_node (graph_after_passes , exir_ops .edge .aten .abs .default ), 1
325+ )
326+
307327 def test_remove_nop_select_before_view (self ) -> None :
308328 builder = GraphBuilder ()
309329 x = builder .placeholder ("x" , torch .randn (1 , 5 , 6 , dtype = torch .float32 ))
0 commit comments