Skip to content

Commit 2efe9bb

Browse files
committed
Use existing is_dequant check and update atol
1 parent f8f998c commit 2efe9bb

File tree

2 files changed

+7
-15
lines changed

2 files changed

+7
-15
lines changed

backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
import torch
1010
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
1111
from executorch.backends.xnnpack.utils.utils import is_param_node
12+
from executorch.backends.xnnpack.utils.quant_utils import is_dequant
1213
from executorch.exir.dialects._ops import ops as exir_ops
1314
from executorch.exir.pass_base import PassResult
1415

15-
1616
# TODO(T151254305) use subgraph_rewriter
1717
class ChannelsLastTaggedReshapePass(XNNPACKPass):
1818
"""
@@ -283,20 +283,12 @@ def input_to_nhwc(
283283
]
284284
else:
285285
# Need to create NHWC node
286-
# TODO: Best way to determine if trace back required?
287-
is_dequant = (
288-
input_node.op == "call_function"
289-
and getattr(input_node.target, "__name__", "")
290-
== "quantized_decomposed.dequantize_per_tensor.tensor"
291-
)
286+
# TODO: If input is dequant does that it's from dynamic quantization?
287+
input_is_dequant = is_dequant(input_node)
292288

293-
if is_dequant:
289+
if input_is_dequant:
294290
# Trace back to find original source node
295-
while (
296-
hasattr(input_node, "args")
297-
and isinstance(input_node.args, tuple)
298-
and len(input_node.args) > 0
299-
):
291+
while getattr(input_node, "args", None):
300292
input_node = input_node.args[0]
301293

302294
with graph_module.graph.inserting_after(input_node):
@@ -307,7 +299,7 @@ def input_to_nhwc(
307299
memory_format=torch.channels_last,
308300
)
309301

310-
if is_dequant:
302+
if input_is_dequant:
311303
# Replace downstream input_nodes with NHWC node
312304
for user in list(input_node.users):
313305
if user is not input_node_nhwc:

backends/xnnpack/test/ops/test_conv2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -766,5 +766,5 @@ def get_inputs(self):
766766
model,
767767
model.get_inputs(),
768768
dynamic_shapes=None,
769-
atol=5e-2,
769+
atol=3.0,
770770
)

0 commit comments

Comments
 (0)