@@ -47,7 +47,7 @@ class RemoveNoopPass(ExportPass):
     """
     Removes noops that pass through arguments.
    """
-
+
    def call(self, graph_module: GraphModule) -> PassResult:
 
        # In this list we'll collect all the dequant nodes that are inputs to ops that
@@ -56,35 +56,10 @@ def call(self, graph_module: GraphModule) -> PassResult:
        dequant_nodes = []
 
        for node in graph_module.graph.nodes:
-            if node.op != "call_function":
-                continue
-
-            if node.target not in (
-                torch.ops.aten.to.dtype,
-                torch.ops.aten.dropout.default,
-                torch.ops.aten.slice_copy.Tensor,
-            ):
-                continue
-
-            orig_tensor = node.args[0].meta["val"]
-
-            if orig_tensor is node.meta["val"]:
-                # If the graph is quantized, we must remove the entire pattern consisting of dq->op->q.
-                # Otherwise, removing only the op will suffice.
+            if RemoveNoopPass._should_remove_node(node):
                if node.args[0].target in _DEQUANT_OPS:
                    dequant_nodes += [node.args[0]]
                node.replace_all_uses_with(node.args[0])
-                continue
-
-            if node.target == torch.ops.aten.slice_copy.Tensor:
-                # Only do this check if all the dims are static.
-                if all(isinstance(dim, int) for dim in orig_tensor.size()):
-                    if orig_tensor.shape == node.meta["val"].shape:
-                        # If the graph is quantized, we must remove the entire pattern consisting of dq->op->q.
-                        # Otherwise, removing only the op will suffice.
-                        if node.args[0].target in _DEQUANT_OPS:
-                            dequant_nodes += [node.args[0]]
-                        node.replace_all_uses_with(node.args[0])
 
        graph_module.graph.eliminate_dead_code()
        eliminate_dq_q(graph_module, dequant_nodes)
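
Note on the dq->op->q handling above: when the graph is quantized, the no-op sits between a dequantize and a quantize node, so re-pointing users past the op alone would leave a dangling dq->q pair; collecting the dequantize inputs lets `eliminate_dq_q` fold those pairs afterwards. Below is a minimal sketch of the rewiring itself, my illustration rather than code from this PR: the real pass matches `call_function` nodes in an exported aten graph, whereas `symbolic_trace` records `.to()` as a `call_method` node.

```python
import torch
from torch.fx import symbolic_trace


class M(torch.nn.Module):
    def forward(self, x):
        # .to(float32) is a no-op when x is already float32
        return x.to(torch.float32) + 1


gm = symbolic_trace(M())
for node in gm.graph.nodes:
    if node.op == "call_method" and node.target == "to":
        # Re-point every user of the no-op at its input...
        node.replace_all_uses_with(node.args[0])
# ...then drop the now-unused node.
gm.graph.eliminate_dead_code()
gm.recompile()
print(gm.code)  # forward() no longer contains the .to() call
```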
@@ -93,6 +68,29 @@ def call(self, graph_module: GraphModule) -> PassResult:
 
        return PassResult(graph_module, True)
 
+    @staticmethod
+    def _should_remove_node(node: torch.fx.Node) -> bool:
+        if node.op != "call_function":
+            return False
+
+        input_meta_val = node.args[0].meta.get("val", None) if len(node.args) > 0 and hasattr(node.args[0], "meta") else None
+
+        if input_meta_val is not None:
+            if node.target in (
+                torch.ops.aten.to.dtype,
+                torch.ops.aten.dropout.default,
+            ):
+                return input_meta_val is node.meta["val"]
+            elif node.target == torch.ops.aten.slice_copy.Tensor:
+                # Only do this check if all the dims are static.
+                return all(isinstance(dim, int) for dim in input_meta_val.size()) and input_meta_val.shape == node.meta["val"].shape
+            elif node.target == torch.ops.aten.clone.default:
+                # Remove if memory_format=None, preserve_format, or input already has the target memory format.
+                dest_memory_format = node.kwargs.get("memory_format", None) or torch.preserve_format
+                return dest_memory_format == torch.preserve_format or input_meta_val.is_contiguous(memory_format=dest_memory_format)
+
+        return False
+
 
class RemoveToCopyPass(ExportPass):
    """
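
The new `clone.default` branch treats a clone as removable exactly when it cannot change the tensor's layout: `memory_format` is absent (or `preserve_format`), or the input already satisfies the requested format. A quick sketch of the `Tensor.is_contiguous(memory_format=...)` check it relies on, my illustration rather than code from this PR:

```python
import torch

x = torch.randn(1, 3, 8, 8)  # NCHW, default contiguous layout
# clone() with no memory_format (preserve_format) copies without
# changing layout, so the pass can always drop it.
# An explicit format makes clone a no-op only if the input already
# has that layout:
print(x.is_contiguous(memory_format=torch.contiguous_format))  # True  -> removable
print(x.is_contiguous(memory_format=torch.channels_last))      # False -> keep the clone
y = x.to(memory_format=torch.channels_last)
print(y.is_contiguous(memory_format=torch.channels_last))      # True  -> removable
```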