@@ -283,29 +283,32 @@ def input_to_nhwc(
283283 ]
284284 else :
285285 # Need to create NHWC node
286- source_node = input_node
287-
288- # TODO: safe/correct to always trace back?
289- # Trace back to find original source node
290- while (
291- hasattr (source_node , "args" )
292- and isinstance (source_node .args , tuple )
293- and len (source_node .args ) > 0
294- ):
295- source_node = source_node .args [0 ]
286+ # TODO: Best way to determine if trace back required?
287+ is_dequant = (
288+ input_node .op == "call_function"
289+ and getattr (input_node .target , "__name__" , "" )
290+ == "quantized_decomposed.dequantize_per_tensor.tensor"
291+ )
292+
293+ if is_dequant :
294+ # Trace back to find original source node
295+ while (
296+ hasattr (input_node , "args" )
297+ and isinstance (input_node .args , tuple )
298+ and len (input_node .args ) > 0
299+ ):
300+ input_node = input_node .args [0 ]
296301
297- with graph_module .graph .inserting_after (source_node ):
302+ with graph_module .graph .inserting_after (input_node ):
298303 input_node_nhwc = self .create_call_function_node (
299304 graph_module = graph_module ,
300305 target = exir_ops .edge .aten ._to_copy .default ,
301- args = (source_node ,),
306+ args = (input_node ,),
302307 memory_format = torch .channels_last ,
303308 )
304309
305- # If input_node was not the original source node
306- if source_node != input_node :
307- input_node = source_node
308- # Replace downstream source node with NHWC node
310+ if is_dequant :
311+ # Replace downstream input_nodes with NHWC node
309312 for user in list (input_node .users ):
310313 if user != input_node_nhwc :
311314 user .replace_input_with (input_node , input_node_nhwc )
0 commit comments