From 4ad83fa15e637741a94aa665aa753d3547322914 Mon Sep 17 00:00:00 2001
From: Erik Lundell
Date: Fri, 24 Oct 2025 11:59:57 +0200
Subject: [PATCH] Arm backend: Avoid not decomposing linears we reject

If a linear is not properly quantized, we reject it when partitioning.
However, if we tell ExecuTorch to _not_ decompose an op, we are
required to partition it. We must therefore decide whether we will
partition the linear already in the ops_to_not_decompose filter
function.

Also turn off grad in the Arm tester to fix an error that popped up in
the GRU model. Since we only run inference, grad is never relevant.

Signed-off-by: Erik Lundell
Change-Id: Iaef8671cc89b290ec539ad236fb3450a3dde7c73
---
 backends/arm/test/models/test_nn_modules.py |  2 --
 backends/arm/test/tester/arm_tester.py      |  1 +
 backends/arm/tosa/partitioner.py            | 41 ++++++++++++++-------
 3 files changed, 29 insertions(+), 15 deletions(-)

diff --git a/backends/arm/test/models/test_nn_modules.py b/backends/arm/test/models/test_nn_modules.py
index 158a8a587e2..8192ec6887b 100644
--- a/backends/arm/test/models/test_nn_modules.py
+++ b/backends/arm/test/models/test_nn_modules.py
@@ -137,8 +137,6 @@ def test_nn_Modules_FP(test_data):
     "test_data",
     test_parameters,
     xfails={
-        "GRUModule": "RuntimeError: Node aten_linear_default with op was not decomposed or delegated.",
-        "PReLUModule": "RuntimeError: mul(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.",
         "TransformerModule": "AssertionError: Output 0 does not match reference output.",
     },
 )
diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py
index 7be249609b0..604253b6c92 100644
--- a/backends/arm/test/tester/arm_tester.py
+++ b/backends/arm/test/tester/arm_tester.py
@@ -266,6 +266,7 @@ def __init__(
             StageType.QUANTIZE,
             StageType.EXPORT,
         ]
+        self.original_module.requires_grad_(False)

         # Initial model needs to be set as a *possible* but not yet added Stage, therefore add None entry.
         self.stages[StageType.INITIAL_MODEL] = None
diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py
index 1ae743101b6..850e607bef2 100644
--- a/backends/arm/tosa/partitioner.py
+++ b/backends/arm/tosa/partitioner.py
@@ -338,6 +338,7 @@ def ops_to_not_decompose(
         ops_to_not_decompose_if_quant_op = [
             torch.ops.aten.hardsigmoid.default,
             torch.ops.aten.hardswish.default,
+            torch.ops.aten.linear.default,
         ]

         def filter_fn(node: torch.fx.Node) -> bool:
@@ -355,31 +356,45 @@ def filter_fn(node: torch.fx.Node) -> bool:
                 bool: True to keep the op intact; otherwise, False.

""" - dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default - q = torch.ops.quantized_decomposed.quantize_per_tensor.default + dq = ( + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_channel.default, + ) + q = ( + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_channel.default, + ) if node.target in ops_to_not_decompose_if_quant_op: # Assume we should not decompose the operator (it is quantized) - should_not_decompose = True + correct_output_quant = True + correct_input_quant = True input_nodes = node.all_input_nodes - ouput_nodes = node.users + output_nodes = node.users for inp in input_nodes: - if inp.target != dq: - should_not_decompose = False - - for out in ouput_nodes: - if out.target != q: - should_not_decompose = False - - return should_not_decompose + if inp.target not in dq: + correct_input_quant = False + + for out in output_nodes: + if out.target not in q: + correct_output_quant = False + # In some cases, a linear is quantized together with its activation. + if ( + node.target == torch.ops.aten.linear.default + and len(output_nodes) == 1 + and list(output_nodes)[0].target + in (torch.ops.aten.relu.default, torch.ops.aten.hardtanh.default) + ): + correct_output_quant = True + + return correct_input_quant and correct_output_quant # By default, do not decompose the operator return True ops_to_not_decompose = [ - torch.ops.aten.linear.default, torch.ops.aten.eye.default, torch.ops.aten.linspace.default, torch.ops.aten.logit.default,