Skip to content

Commit 0fef04a

Browse files
committed
Add is_dequant check for trace back when inserting permute
1 parent 2905b98 commit 0fef04a

File tree

2 files changed

+20
-17
lines changed

2 files changed

+20
-17
lines changed

backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -283,29 +283,32 @@ def input_to_nhwc(
283283
]
284284
else:
285285
# Need to create NHWC node
286-
source_node = input_node
287-
288-
# TODO: safe/correct to always trace back?
289-
# Trace back to find original source node
290-
while (
291-
hasattr(source_node, "args")
292-
and isinstance(source_node.args, tuple)
293-
and len(source_node.args) > 0
294-
):
295-
source_node = source_node.args[0]
286+
# TODO: Best way to determine if trace back required?
287+
is_dequant = (
288+
input_node.op == "call_function"
289+
and getattr(input_node.target, "__name__", "")
290+
== "quantized_decomposed.dequantize_per_tensor.tensor"
291+
)
292+
293+
if is_dequant:
294+
# Trace back to find original source node
295+
while (
296+
hasattr(input_node, "args")
297+
and isinstance(input_node.args, tuple)
298+
and len(input_node.args) > 0
299+
):
300+
input_node = input_node.args[0]
296301

297-
with graph_module.graph.inserting_after(source_node):
302+
with graph_module.graph.inserting_after(input_node):
298303
input_node_nhwc = self.create_call_function_node(
299304
graph_module=graph_module,
300305
target=exir_ops.edge.aten._to_copy.default,
301-
args=(source_node,),
306+
args=(input_node,),
302307
memory_format=torch.channels_last,
303308
)
304309

305-
# If input_node was not the original source node
306-
if source_node != input_node:
307-
input_node = source_node
308-
# Replace downstream source node with NHWC node
310+
if is_dequant:
311+
# Replace downstream input_nodes with NHWC node
309312
for user in list(input_node.users):
310313
if user != input_node_nhwc:
311314
user.replace_input_with(input_node, input_node_nhwc)

backends/xnnpack/test/ops/test_conv2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ def _test_dq_conv2d(
248248
)
249249

250250
tester = Tester(m, inputs, dynamic_shapes=dynamic_shapes)
251-
tester = tester.quantize(Quantize(quantization_config=quant_config))
251+
tester.quantize(Quantize(quantization_config=quant_config))
252252
tester.export()
253253

254254
tester.check(["torch.ops.quantized_decomposed.choose_qparams"])

0 commit comments

Comments
 (0)