diff --git a/backends/arm/_passes/tag_io_quant_pass.py b/backends/arm/_passes/tag_io_quant_pass.py index 2fce6cf3fd4..49990d8e5fd 100644 --- a/backends/arm/_passes/tag_io_quant_pass.py +++ b/backends/arm/_passes/tag_io_quant_pass.py @@ -43,9 +43,9 @@ def call(self, graph_module: torch.fx.GraphModule): # tag dq of outputs if node.op == "output": - quant, *_ = node.args[0] - if self.is_dequant_node(quant): - quant.meta["arm_override_partition"] = False + for quant in node.args[0]: + if self.is_dequant_node(quant): + quant.meta["arm_override_partition"] = False graph_module.recompile() return PassResult(graph_module, True) diff --git a/backends/arm/test/passes/test_tag_io_quant_pass.py b/backends/arm/test/passes/test_tag_io_quant_pass.py index 9f292bb7caa..639bf478bcf 100644 --- a/backends/arm/test/passes/test_tag_io_quant_pass.py +++ b/backends/arm/test/passes/test_tag_io_quant_pass.py @@ -12,13 +12,13 @@ from executorch.backends.arm.test.tester.arm_tester import ArmTester -class Add(torch.nn.Module): +class TwoInputsTwoOutputs(torch.nn.Module): def get_inputs(self): - return (torch.rand(1, 10, 10, 10),) + return (torch.rand(1, 10, 10, 10), (torch.rand(1, 10, 10, 10))) - def forward(self, x): - return x + x + def forward(self, x, y): + return (x + y, x * y) class TestTagIOQuantPass(unittest.TestCase): @@ -36,29 +36,29 @@ def _tosa_BI_u55_pipeline(self, module: torch.nn.Module): .to_edge() .check_count( { - "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2 + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 4 } ) .check_count( { - "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2 + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 6 } ) .partition() .check_count( { - "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 1 + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2 } ) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_count( { - "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1 + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2 } ) # .to_executorch() requires additional steps ) def test_BI_u55_artifact(self): - model = Add() + model = TwoInputsTwoOutputs() self._tosa_BI_u55_pipeline(model)