@@ -191,6 +191,11 @@ class RemoveNopSliceOrViewOpPass(ExportPass):
191191 Remove slice ops that are more like views, and view ops that do not change the shape
192192 """
193193
194+ @staticmethod
195+ def _is_nop (input_shape : tuple [int , ...], output_shape : tuple [int , ...]) -> bool :
196+ # If both arg_shape and out_shape are the same, this slice is a nop
197+ return input_shape == output_shape
198+
194199 def call_operator (
195200 self ,
196201 op , # pyre-ignore
@@ -207,13 +212,27 @@ def call_operator(
207212 arg0 = cast (ProxyValue , args [0 ])
208213 out_shape = meta ["val" ].shape
209214
210- # If both arg_shape and out_shape are the same, this slice is a nop
211215 return (
212216 arg0
213- if arg0 .to_tensor ().shape == out_shape
217+ if self . _is_nop ( arg0 .to_tensor ().shape , out_shape )
214218 else super ().call_operator (op , args , kwargs , meta )
215219 )
216220
221+ def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
222+ for target in [
223+ exir_ops .edge .aten .slice_copy .Tensor ,
224+ exir_ops .edge .aten .view_copy .default ,
225+ ]:
226+ for node in graph_module .graph .find_nodes (
227+ op = "call_function" , target = target
228+ ):
229+ input_node = node .args [0 ]
230+ assert isinstance (input_node , torch .fx .Node )
231+ if self ._is_nop (input_node .meta ["val" ].shape , node .meta ["val" ].shape ):
232+ return super ().call (graph_module )
233+
234+ return PassResult (graph_module , False )
235+
217236
218237@register_cadence_pass (CadencePassAttribute (opt_level = 1 ))
219238class RemoveNopLinalgVectorNormOpPass (ExportPass ):
0 commit comments