Skip to content

Commit 7150872

Browse files
committed
Add dynamic quant check before NHWC permute
1 parent cdd6f2d commit 7150872

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import torch
1010
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
1111
from executorch.backends.xnnpack.utils.utils import is_param_node
12-
from executorch.backends.xnnpack.utils.quant_utils import is_dequant
12+
from executorch.backends.xnnpack.utils.quant_utils import is_dynamic_qdq
1313
from executorch.exir.dialects._ops import ops as exir_ops
1414
from executorch.exir.pass_base import PassResult
1515

@@ -283,11 +283,11 @@ def input_to_nhwc(
283283
]
284284
else:
285285
# Need to create NHWC node
286-
# TODO: Replace with check to determine if dynamic quant
287-
input_is_dequant = is_dequant(input_node)
286+
# Check if input uses dynamic quantization
287+
is_dynamic_input = is_dynamic_qdq(input_node)
288288

289-
if input_is_dequant:
290-
# Trace back to find original source node
289+
if is_dynamic_input:
290+
# Trace back to original source node
291291
while getattr(input_node, "args", None):
292292
input_node = input_node.args[0]
293293

@@ -299,7 +299,7 @@ def input_to_nhwc(
299299
memory_format=torch.channels_last,
300300
)
301301

302-
if input_is_dequant:
302+
if is_dynamic_input:
303303
# Replace downstream input_nodes with NHWC node
304304
input_node.replace_all_uses_with(input_node_nhwc)
305305
input_node_nhwc.args = (input_node,)

0 commit comments

Comments
 (0)