Skip to content

Commit 8be6327

Browse files
Arm backend: make tag_io_quant_pass tag all output dq-nodes (pytorch#6547)
Make tag_io_quant_pass tag all output dq-nodes Change-Id: I8dba320be9c1be65db28227fe81b83d4514fd825 Signed-off-by: Oscar Andersson <[email protected]>
1 parent ffad824 commit 8be6327

File tree

2 files changed

+12
-12
lines changed

2 files changed

+12
-12
lines changed

backends/arm/_passes/tag_io_quant_pass.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ def call(self, graph_module: torch.fx.GraphModule):
4343

4444
# tag dq of outputs
4545
if node.op == "output":
46-
quant, *_ = node.args[0]
47-
if self.is_dequant_node(quant):
48-
quant.meta["arm_override_partition"] = False
46+
for quant in node.args[0]:
47+
if self.is_dequant_node(quant):
48+
quant.meta["arm_override_partition"] = False
4949

5050
graph_module.recompile()
5151
return PassResult(graph_module, True)

backends/arm/test/passes/test_tag_io_quant_pass.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,13 @@
1212
from executorch.backends.arm.test.tester.arm_tester import ArmTester
1313

1414

15-
class Add(torch.nn.Module):
15+
class TwoInputsTwoOutputs(torch.nn.Module):
1616

1717
def get_inputs(self):
18-
return (torch.rand(1, 10, 10, 10),)
18+
return (torch.rand(1, 10, 10, 10), (torch.rand(1, 10, 10, 10)))
1919

20-
def forward(self, x):
21-
return x + x
20+
def forward(self, x, y):
21+
return (x + y, x * y)
2222

2323

2424
class TestTagIOQuantPass(unittest.TestCase):
@@ -36,29 +36,29 @@ def _tosa_BI_u55_pipeline(self, module: torch.nn.Module):
3636
.to_edge()
3737
.check_count(
3838
{
39-
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2
39+
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 4
4040
}
4141
)
4242
.check_count(
4343
{
44-
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2
44+
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 6
4545
}
4646
)
4747
.partition()
4848
.check_count(
4949
{
50-
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 1
50+
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2
5151
}
5252
)
5353
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
5454
.check_count(
5555
{
56-
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1
56+
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2
5757
}
5858
)
5959
# .to_executorch() requires additional steps
6060
)
6161

6262
def test_BI_u55_artifact(self):
63-
model = Add()
63+
model = TwoInputsTwoOutputs()
6464
self._tosa_BI_u55_pipeline(model)

0 commit comments

Comments
 (0)