From 42923f3dad40154812df8d3982434e9576c3e500 Mon Sep 17 00:00:00 2001 From: Adrian Lundell Date: Wed, 24 Sep 2025 15:34:36 +0200 Subject: [PATCH 1/2] Cortex-M backend: Add mul and linear tests Minor included fixes: - Make quantized_linear_fusion_pass an XNNPACK pass to initialize it with an exported program - Add TO_EXECUTORCH as a valid stage after RUN_PASSES - Add ramp_tensor function to simplify creating dummy data Signed-off-by: Adrian Lundell Change-Id: Id13be6427390483aa1df1b76fc363ae4d0eae876 --- .../passes/quantized_linear_fusion_pass.py | 9 +- backends/cortex_m/test/ops/test_add.py | 22 +- backends/cortex_m/test/ops/test_linear.py | 211 ++++++++++++++++++ backends/cortex_m/test/ops/test_mul.py | 131 +++++++++++ backends/cortex_m/test/tester.py | 16 +- backends/test/harness/tester.py | 6 + 6 files changed, 381 insertions(+), 14 deletions(-) create mode 100644 backends/cortex_m/test/ops/test_linear.py create mode 100644 backends/cortex_m/test/ops/test_mul.py diff --git a/backends/cortex_m/passes/quantized_linear_fusion_pass.py b/backends/cortex_m/passes/quantized_linear_fusion_pass.py index 8f8a90eec2f..11a49beb2f4 100644 --- a/backends/cortex_m/passes/quantized_linear_fusion_pass.py +++ b/backends/cortex_m/passes/quantized_linear_fusion_pass.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -19,9 +20,10 @@ ) from executorch.backends.transforms.utils import create_mutable_buffer, get_param_tensor + +from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass from torch.fx import Node from torch.fx.passes.infra.pass_manager import PassResult @@ -29,7 +31,7 @@ logger.setLevel(logging.INFO) -class QuantizedLinearFusionPass(ExportPass): +class QuantizedLinearFusionPass(XNNPACKPass): """ Cortex-M backend pass that fuses quantized linear-like patterns. Fuses: dequantize -> [linear/addmm/fc_ops] -> quantize @@ -44,8 +46,7 @@ class QuantizedLinearFusionPass(ExportPass): requires_exported_program = True def __init__(self, exported_program: ExportedProgram): - super().__init__() - self._exported_program = exported_program + super().__init__(exported_program) self.nodes_to_erase = [] def call(self, graph_module: torch.fx.GraphModule) -> PassResult: diff --git a/backends/cortex_m/test/ops/test_add.py b/backends/cortex_m/test/ops/test_add.py index 10edacb5a11..b7b0ffcbfbc 100644 --- a/backends/cortex_m/test/ops/test_add.py +++ b/backends/cortex_m/test/ops/test_add.py @@ -6,7 +6,11 @@ import torch from executorch.backends.arm.test.common import parametrize -from executorch.backends.cortex_m.test.tester import CortexMTester, McuTestCase +from executorch.backends.cortex_m.test.tester import ( + CortexMTester, + McuTestCase, + ramp_tensor, +) from executorch.backends.test.suite.operators.test_add import Model, ModelAlpha @@ -80,19 +84,19 @@ class CortexMAlphaAdd(ModelAlpha): ), "self_rank_2_pos": McuTestCase( CortexMSelfAdd(), - (torch.linspace(0, 1000, 10).reshape((10, 1)),), + (ramp_tensor(0, 1000, (10, 1)),), ), "self_rank_3_neg": McuTestCase( CortexMSelfAdd(), - (torch.linspace(-100, 0, 8).reshape((2, 2, 2)),), + (ramp_tensor(-100, 0, (2, 2, 2)),), ), "self_rank_4_small": McuTestCase( CortexMSelfAdd(), - (torch.linspace(-0.1, 0.1, 16).reshape(2, 2, 2, 2),), + (ramp_tensor(-0.1, 0.1, (2, 2, 2, 2)),), ), "self_rank_5": McuTestCase( CortexMSelfAdd(), - (torch.linspace(-5, 5, 32).reshape(2, 2, 2, 2, 2),), + (ramp_tensor(-5, 5, (2, 2, 2, 2, 2)),), ), "scalar_scalar": McuTestCase( CortexMScalarAdd(), @@ -117,15 +121,15 @@ class CortexMAlphaAdd(ModelAlpha): "broadcast_3": McuTestCase( CortexMTensorAdd(), ( - torch.linspace(-2, 2, 4).reshape(2, 1, 2, 1), - torch.linspace(-5, 5, 4).reshape(1, 2, 1, 2), + ramp_tensor(-2, 2, (2, 1, 2, 1)), + ramp_tensor(-5, 5, (1, 2, 1, 2)), ), ), "alpha": McuTestCase( CortexMAlphaAdd(0.5), ( - torch.linspace(-10, 10, 20).reshape(4, 5), - torch.linspace(-20, 20, 20).reshape(4, 5), + ramp_tensor(-10, 10, (4, 5)), + ramp_tensor(-20, 20, (4, 5)), ), ), } diff --git a/backends/cortex_m/test/ops/test_linear.py b/backends/cortex_m/test/ops/test_linear.py new file mode 100644 index 00000000000..a1275352fcf --- /dev/null +++ b/backends/cortex_m/test/ops/test_linear.py @@ -0,0 +1,211 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +from executorch.backends.arm.test.common import parametrize +from executorch.backends.cortex_m.test.tester import ( + CortexMTester, + McuTestCase, + ramp_tensor, +) + + +class CortexMMm(torch.nn.Module): + def forward(self, x, y): + return torch.mm(x, y) + + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_mm_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + +class CortexMBmm(torch.nn.Module): + def forward(self, x, y): + return torch.bmm(x, y) + + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_bmm_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + +class CortexMAddmm(torch.nn.Module): + def forward(self, x, y, z, alpha=None, beta=None): + return torch.addmm(beta, x, alpha, y, z) + + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_addmm_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + +class CortexMAt(CortexMMm): + def forward(self, x, y): + return x @ y + + +class CortexMMatmul(CortexMMm): + def forward(self, x, y): + return torch.matmul(x, y) + + +class CortexMLinear(CortexMMatmul): + def __init__(self, *args, **kwargs): + super().__init__() + self.linear = torch.nn.Linear(*args, bias=False) + + def forward(self, x): + return self.linear(x) + + +class CortexMLinearBias(CortexMAddmm): + def __init__(self, *args, **kwargs): + super().__init__() + self.linear = torch.nn.Linear(*args, bias=True) + + def forward(self, x): + return self.linear(x) + + +test_cases = { + "mm": McuTestCase( + model=CortexMMm(), + example_inputs=( + ramp_tensor(0, 10, (1, 16)), + ramp_tensor(0, 10, (16, 16)), + ), + ), + "bmm": McuTestCase( + model=CortexMBmm(), + example_inputs=( + ramp_tensor(0, 10, (1, 16, 16)), + ramp_tensor(0, 10, (1, 16, 16)), + ), + ), + "addmm": McuTestCase( + model=CortexMAddmm(), + example_inputs=( + ramp_tensor(0, 10, (1, 16)), + ramp_tensor(0, 10, (16, 16)), + ramp_tensor(0, 10, (16, 16)), + 2, + 4, + ), + ), + "addmm_scalars": McuTestCase( + model=CortexMAddmm(), + example_inputs=( + ramp_tensor(0, 10, (1, 16)), + ramp_tensor(0, 10, (16, 16)), + ramp_tensor(0, 10, (16, 16)), + ), + ), + "@-operator": McuTestCase( + model=CortexMAt(), + example_inputs=( + ramp_tensor(0, 10, (1, 16)), + ramp_tensor(0, 10, (16, 16)), + ), + ), + "matmul": McuTestCase( + model=CortexMMatmul(), + example_inputs=( + ramp_tensor(0, 10, (1, 16)), + ramp_tensor(0, 10, (16, 16)), + ), + ), + "linear_rank1": McuTestCase( + model=CortexMLinear(2, 3), + example_inputs=(ramp_tensor(-1, 1, (2,)),), + ), + "linear_rank2_pos": McuTestCase( + model=CortexMLinear(8, 3), + example_inputs=(ramp_tensor(0, 10, (2, 8)),), + ), + "linear_rank3_neg": McuTestCase( + model=CortexMLinear(5, 3), + example_inputs=(ramp_tensor(-40, 0, (4, 2, 5)),), + ), + "linear_rank4": McuTestCase( + model=CortexMLinear(16, 32), + example_inputs=(ramp_tensor(-100, 100, (2, 1, 2, 16)),), + ), + "linear_rank5": McuTestCase( + model=CortexMLinear(4, 3), + example_inputs=(ramp_tensor(-2, 2, (5, 2, 1, 2, 4)),), + ), + "linear_bias": McuTestCase( + model=CortexMLinearBias(61, 37), + example_inputs=(ramp_tensor(0, 10, (8, 61)),), + ), +} + +dialect_xfails = { + "mm": ("torch.mm ops are currently not quantized", RuntimeError), + "bmm": ("torch.bmm ops are currently not quantized", RuntimeError), + "addmm": ("torch.addmm ops are currently not quantized", RuntimeError), + "addmm_scalars": ("torch.addmm ops are currently not quantized", RuntimeError), + "matmul": ("torch.matmul ops are currently not quantized", RuntimeError), + "@-operator": ("@ ops are currently not quantized", RuntimeError), + "linear_rank1": ("Only rank 2 linear ops are fused currently", RuntimeError), + "linear_rank2_pos": ("name 'int32' is not defined", NameError), + "linear_rank3_neg": ("Only rank 2 linear ops are fused currently", RuntimeError), + "linear_rank4": ("Only rank 2 linear ops are fused currently", RuntimeError), + "linear_rank5": ("Only rank 2 linear ops are fused currently", RuntimeError), + "linear_bias": ("name 'int32' is not defined", NameError), +} + + +@parametrize("test_case", test_cases, dialect_xfails) +def test_dialect_linear(test_case): + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_dialect( + test_case.model.ops_before_transforms, test_case.model.ops_after_transforms + ) + + +implementation_xfails = { + "mm": ("torch.mm ops are currently not quantized", RuntimeError), + "bmm": ("torch.bmm ops are currently not quantized", RuntimeError), + "addmm": ("torch.addmm ops are currently not quantized", RuntimeError), + "addmm_scalars": ("torch.addmm ops are currently not quantized", RuntimeError), + "matmul": ("torch.matmul ops are currently not quantized", RuntimeError), + "@-operator": ("@ ops are currently not quantized", RuntimeError), + "linear_rank1": ("Only rank 2 linear ops are fused currently", RuntimeError), + "linear_rank2_pos": ("Output 0 does not match reference output.", AssertionError), + "linear_rank3_neg": ("Only rank 2 linear ops are fused currently", RuntimeError), + "linear_rank4": ("Only rank 2 linear ops are fused currently", RuntimeError), + "linear_rank5": ("Only rank 2 linear ops are fused currently", RuntimeError), + "linear_bias": ("Output 0 does not match reference output.", AssertionError), +} + + +@parametrize("test_case", test_cases, implementation_xfails) +def test_implementation_linear(test_case): + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_implementation() diff --git a/backends/cortex_m/test/ops/test_mul.py b/backends/cortex_m/test/ops/test_mul.py new file mode 100644 index 00000000000..a2f13760bf0 --- /dev/null +++ b/backends/cortex_m/test/ops/test_mul.py @@ -0,0 +1,131 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import pytest +import torch +from executorch.backends.arm.test.common import parametrize +from executorch.backends.cortex_m.test.tester import ( + CortexMTester, + McuTestCase, + ramp_tensor, +) +from executorch.backends.test.suite.operators.test_mul import Model + + +class CortexMSelfMul(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_mul_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def forward(self, x): + return x * x + + +class CortexMScalarMul(Model): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_mul_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + +class CortexMTensorMul(Model): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_mul_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + +test_cases = { + "self_scalar": McuTestCase( + CortexMSelfMul(), + (10.0,), + ), + "self_rank_1": McuTestCase( + CortexMSelfMul(), + (ramp_tensor(-5, 5, (10,)),), + ), + "self_rank_2_pos": McuTestCase( + CortexMSelfMul(), + (ramp_tensor(0, 1000, (10, 1)),), + ), + "self_rank_3_neg": McuTestCase( + CortexMSelfMul(), + (ramp_tensor(-100, 0, (2, 2, 2)),), + ), + "self_rank_4_small": McuTestCase( + CortexMSelfMul(), + (ramp_tensor(-0.1, 0.1, (2, 2, 2, 2)),), + ), + "self_rank_5": McuTestCase( + CortexMSelfMul(), + (ramp_tensor(-5, 5, (2, 2, 2, 2, 2)),), + ), + "scalar_scalar": McuTestCase( + CortexMScalarMul(), + (-0.5, 1.0), + ), + "tensor_scalar": McuTestCase( + CortexMScalarMul(), + (torch.ones(2, 2), 1.0), + ), + "scalar_tensor": McuTestCase( + CortexMScalarMul(), + (1000.0, torch.ones(2, 2)), + ), + "broadcast_1": McuTestCase( + CortexMTensorMul(), + (torch.ones(1), torch.ones(2, 2, 2, 2)), + ), + "broadcast_2": McuTestCase( + CortexMTensorMul(), + (torch.ones((2, 1, 1, 1)), torch.ones(1)), + ), + "broadcast_3": McuTestCase( + CortexMTensorMul(), + ( + ramp_tensor(-2, 2, (2, 1, 2, 1)), + ramp_tensor(-5, 5, (1, 2, 1, 2)), + ), + ), +} + + +@pytest.mark.skip(reason="Not implemented yet") +@parametrize("test_case", test_cases) +def test_dialect_mul(test_case): + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_dialect( + test_case.model.ops_before_transforms, test_case.model.ops_after_transforms + ) + + +@pytest.mark.skip(reason="Not implemented yet") +@parametrize("test_case", test_cases) +def test_implementation_mul(test_case): + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_implementation() diff --git a/backends/cortex_m/test/tester.py b/backends/cortex_m/test/tester.py index 8af31e58cd7..71412a9c475 100644 --- a/backends/cortex_m/test/tester.py +++ b/backends/cortex_m/test/tester.py @@ -15,6 +15,9 @@ ) from executorch.backends.arm.test.common import get_u55_compile_spec from executorch.backends.arm.test.tester.arm_tester import Serialize +from executorch.backends.cortex_m.passes.quantized_linear_fusion_pass import ( + QuantizedLinearFusionPass, +) from executorch.backends.cortex_m.passes.quantized_op_fusion_pass import ( QuantizedOpFusionPass, ) @@ -44,7 +47,12 @@ def __init__(self): class CortexMRunPasses(RunPasses): def __init__(self): super().__init__( - XNNPACKPassManager, pass_list=[QuantizedOpFusionPass, ReplaceQuantNodesPass] + XNNPACKPassManager, + pass_list=[ + ReplaceQuantNodesPass, + QuantizedLinearFusionPass, + QuantizedOpFusionPass, + ], ) @@ -98,3 +106,9 @@ def test_implementation(self, qtol=0): class McuTestCase: model: torch.nn.Module example_inputs: tuple[Any] + + +def ramp_tensor(start: int, end: int, shape: tuple[int]) -> torch.Tensor: + return torch.linspace(start, end, steps=torch.prod(torch.tensor(shape))).reshape( + shape + ) diff --git a/backends/test/harness/tester.py b/backends/test/harness/tester.py index 351bab4a605..02c6fc4c82d 100644 --- a/backends/test/harness/tester.py +++ b/backends/test/harness/tester.py @@ -1,3 +1,8 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + import random from collections import Counter, OrderedDict from typing import Any, Callable, Dict, List, Optional, Tuple @@ -62,6 +67,7 @@ def __init__( StageType.RUN_PASSES: [ StageType.PARTITION, StageType.TO_EDGE_TRANSFORM_AND_LOWER, + StageType.TO_EXECUTORCH, ], # TODO Make this Stage optional StageType.PARTITION: [StageType.TO_EXECUTORCH], From 048b0c0780904eb71605a2685392d6e22b6acfc5 Mon Sep 17 00:00:00 2001 From: Adrian Lundell Date: Thu, 9 Oct 2025 17:50:40 +0200 Subject: [PATCH 2/2] Correct import path Signed-off-by: Adrian Lundell --- backends/cortex_m/test/tester.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/backends/cortex_m/test/tester.py b/backends/cortex_m/test/tester.py index 71412a9c475..c492d3c8443 100644 --- a/backends/cortex_m/test/tester.py +++ b/backends/cortex_m/test/tester.py @@ -8,11 +8,6 @@ from typing import Any import torch - -from backends.xnnpack.quantizer.xnnpack_quantizer import ( - get_symmetric_quantization_config, - XNNPACKQuantizer, -) from executorch.backends.arm.test.common import get_u55_compile_spec from executorch.backends.arm.test.tester.arm_tester import Serialize from executorch.backends.cortex_m.passes.quantized_linear_fusion_pass import ( @@ -36,6 +31,11 @@ ) from executorch.backends.xnnpack._passes import XNNPACKPassManager +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, + XNNPACKQuantizer, +) + class CortexMQuantize(Quantize): def __init__(self):