diff --git a/backends/arm/test/models/test_nn_modules.py b/backends/arm/test/models/test_nn_modules.py
index 158a8a587e2..8192ec6887b 100644
--- a/backends/arm/test/models/test_nn_modules.py
+++ b/backends/arm/test/models/test_nn_modules.py
@@ -137,8 +137,6 @@ def test_nn_Modules_FP(test_data):
     "test_data",
     test_parameters,
     xfails={
-        "GRUModule": "RuntimeError: Node aten_linear_default with op was not decomposed or delegated.",
-        "PReLUModule": "RuntimeError: mul(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.",
         "TransformerModule": "AssertionError: Output 0 does not match reference output.",
     },
 )
diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py
index 7be249609b0..604253b6c92 100644
--- a/backends/arm/test/tester/arm_tester.py
+++ b/backends/arm/test/tester/arm_tester.py
@@ -266,6 +266,7 @@ def __init__(
             StageType.QUANTIZE,
             StageType.EXPORT,
         ]
+        self.original_module.requires_grad_(False)
 
         # Initial model needs to be set as a *possible* but not yet added Stage, therefore add None entry.
         self.stages[StageType.INITIAL_MODEL] = None
diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py
index 1ae743101b6..850e607bef2 100644
--- a/backends/arm/tosa/partitioner.py
+++ b/backends/arm/tosa/partitioner.py
@@ -338,6 +338,7 @@ def ops_to_not_decompose(
         ops_to_not_decompose_if_quant_op = [
             torch.ops.aten.hardsigmoid.default,
             torch.ops.aten.hardswish.default,
+            torch.ops.aten.linear.default,
         ]
 
         def filter_fn(node: torch.fx.Node) -> bool:
@@ -355,31 +356,45 @@ def filter_fn(node: torch.fx.Node) -> bool:
                 bool: True to keep the op intact; otherwise, False.
 
             """
-            dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default
-            q = torch.ops.quantized_decomposed.quantize_per_tensor.default
+            dq = (
+                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                torch.ops.quantized_decomposed.dequantize_per_channel.default,
+            )
+            q = (
+                torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                torch.ops.quantized_decomposed.quantize_per_channel.default,
+            )
             if node.target in ops_to_not_decompose_if_quant_op:
                 # Assume we should not decompose the operator (it is quantized)
-                should_not_decompose = True
+                correct_output_quant = True
+                correct_input_quant = True
 
                 input_nodes = node.all_input_nodes
-                ouput_nodes = node.users
+                output_nodes = node.users
 
                 for inp in input_nodes:
-                    if inp.target != dq:
-                        should_not_decompose = False
-
-                for out in ouput_nodes:
-                    if out.target != q:
-                        should_not_decompose = False
-
-                return should_not_decompose
+                    if inp.target not in dq:
+                        correct_input_quant = False
+
+                for out in output_nodes:
+                    if out.target not in q:
+                        correct_output_quant = False
+                # In some cases, a linear is quantized together with its activation.
+                if (
+                    node.target == torch.ops.aten.linear.default
+                    and len(output_nodes) == 1
+                    and list(output_nodes)[0].target
+                    in (torch.ops.aten.relu.default, torch.ops.aten.hardtanh.default)
+                ):
+                    correct_output_quant = True
+
+                return correct_input_quant and correct_output_quant
 
             # By default, do not decompose the operator
             return True
 
         ops_to_not_decompose = [
-            torch.ops.aten.linear.default,
             torch.ops.aten.eye.default,
             torch.ops.aten.linspace.default,
             torch.ops.aten.logit.default,