99import torch
1010from executorch .backends .xnnpack ._passes .xnnpack_pass import XNNPACKPass
1111from executorch .backends .xnnpack .utils .utils import is_param_node
12+ from executorch .backends .xnnpack .utils .quant_utils import is_dequant
1213from executorch .exir .dialects ._ops import ops as exir_ops
1314from executorch .exir .pass_base import PassResult
1415
15-
1616# TODO(T151254305) use subgraph_rewriter
1717class ChannelsLastTaggedReshapePass (XNNPACKPass ):
1818 """
@@ -283,20 +283,12 @@ def input_to_nhwc(
283283 ]
284284 else :
285285 # Need to create NHWC node
286- # TODO: Best way to determine if trace back required?
287- is_dequant = (
288- input_node .op == "call_function"
289- and getattr (input_node .target , "__name__" , "" )
290- == "quantized_decomposed.dequantize_per_tensor.tensor"
291- )
286+ # TODO: If input is dequant does that it's from dynamic quantization?
287+ input_is_dequant = is_dequant (input_node )
292288
293- if is_dequant :
289+ if input_is_dequant :
294290 # Trace back to find original source node
295- while (
296- hasattr (input_node , "args" )
297- and isinstance (input_node .args , tuple )
298- and len (input_node .args ) > 0
299- ):
291+ while getattr (input_node , "args" , None ):
300292 input_node = input_node .args [0 ]
301293
302294 with graph_module .graph .inserting_after (input_node ):
@@ -307,7 +299,7 @@ def input_to_nhwc(
307299 memory_format = torch .channels_last ,
308300 )
309301
310- if is_dequant :
302+ if input_is_dequant :
311303 # Replace downstream input_nodes with NHWC node
312304 for user in list (input_node .users ):
313305 if user is not input_node_nhwc :
0 commit comments