From 94365226b4265a34c8e9c7c588a5042f04015151 Mon Sep 17 00:00:00 2001 From: Adrian Lundell Date: Mon, 17 Nov 2025 10:36:37 +0100 Subject: [PATCH] Cortex-M backend: Fuse Relu, Hardtanh and Hardsigmoid Implements a new pass which fuses activation passes with preceding cortex-m ops if possible. Removed quantization of conv1d, conv3d as they are not tested + moves Conv+relu test to test_activation. Propagate qmin, qmax to conv kernel. Change-Id: Ic7d5709d72a3d8254aff4455bc6dc8eafda14801 Signed-off-by: Adrian Lundell --- backends/cortex_m/passes/__init__.py | 1 + .../cortex_m/passes/activation_fusion_pass.py | 170 ++++++++ .../passes/convert_to_cortex_m_pass.py | 11 +- .../cortex_m/passes/cortex_m_pass_manager.py | 2 + .../cortex_m/quantizer/operator_configs.py | 13 +- backends/cortex_m/test/ops/test_activation.py | 409 ++++++++++++++++++ backends/cortex_m/test/ops/test_conv.py | 35 -- backends/cortex_m/test/tester.py | 8 +- 8 files changed, 607 insertions(+), 42 deletions(-) create mode 100644 backends/cortex_m/passes/activation_fusion_pass.py create mode 100644 backends/cortex_m/test/ops/test_activation.py diff --git a/backends/cortex_m/passes/__init__.py b/backends/cortex_m/passes/__init__.py index d1bb580d871..5aeb60be514 100644 --- a/backends/cortex_m/passes/__init__.py +++ b/backends/cortex_m/passes/__init__.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+from .activation_fusion_pass import ActivationFusionPass # noqa from .convert_to_cortex_m_pass import ConvertToCortexMPass # noqa from .quantized_op_fusion_pass import QuantizedOpFusionPass # noqa from .replace_quant_nodes_pass import ReplaceQuantNodesPass # noqa diff --git a/backends/cortex_m/passes/activation_fusion_pass.py b/backends/cortex_m/passes/activation_fusion_pass.py new file mode 100644 index 00000000000..b200348cc9d --- /dev/null +++ b/backends/cortex_m/passes/activation_fusion_pass.py @@ -0,0 +1,170 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import logging + +import executorch.backends.cortex_m.ops.operators # noqa: F401 +from executorch.backends.arm._passes.quant_args import QuantArgs + +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + +from torch.fx import GraphModule, Node +from torch.fx.passes.infra.pass_manager import PassResult + +logger = logging.getLogger(__name__) + + +class ActivationFusionPass(ExportPass): + """Fuse activations into preceding Cortex-M quantized operators. + + Supported activation patterns: + q-> [conv2d, linear] -> [relu, hardtanh, hardsigmoid] -> dq + + Fusing works by clamping the quantized output range (and zero-point when + required) of the preceding Cortex-M operator, then removing the activation + node from the graph. 
+ """ + + TARGETS = { + exir_ops.edge.aten.relu.default, + exir_ops.edge.aten.hardtanh.default, + exir_ops.edge.aten.hardsigmoid.default, + } + + FUSE_OPS = { + exir_ops.edge.aten.linear.default, + exir_ops.edge.aten.convolution.default, + } + + def _quantize(self, val, scale, zp, qmin, qmax): + return min(max(round(val / scale + zp), qmin), qmax) + + def _get_validated_qparams(self, node, input_node): + + if "input_qparams" not in input_node.meta or "output_qparams" not in node.meta: + logger.warning( + f"Cannot fuse activation for {input_node.name}->{node.name} as the pattern wasn't quantized properly." + ) + return None + + qparams_dict = node.meta["output_qparams"][0]._asdict() + zp = qparams_dict["zp"] + scale = qparams_dict["scale"] + qmin = qparams_dict["qmin"] + qmax = qparams_dict["qmax"] + + if not isinstance(scale, float) or not isinstance(zp, int): + logger.warning( + f"Cannot fuse activation {node.name} as quantization parameters are not per tensor." + ) + return None + + match node.target: + case exir_ops.edge.aten.relu.default: + quantized_min_val = self._quantize(0, scale, zp, qmin, qmax) + quantized_max_val = qmax + case exir_ops.edge.aten.hardtanh.default: + quantized_min_val = self._quantize(node.args[1], scale, zp, qmin, qmax) + quantized_max_val = self._quantize(node.args[2], scale, zp, qmin, qmax) + case exir_ops.edge.aten.hardsigmoid.default: + quantized_min_val = self._quantize(0, scale, zp, qmin, qmax) + quantized_max_val = self._quantize(1, scale, zp, qmin, qmax) + case _: + raise RuntimeError(f"Unexpected target {node.target}.") + + # If the minimal quantized value is larger than the qmin, it means that the quantized range contains + # invalid values [qmin, ..., quantized_min_val-1], indicating bad quantization parameters. + if qparams_dict["qmin"] != quantized_min_val: + logger.warning( + f"Cannot fuse activation {node.name} as qmin is out of range." 
+ ) + return None + + # If the maximal quantized value is smaller than the qmax, it means that the quantized range contains + # invalid values [quantized_max_val + 1, ... , qmax], indicating bad quantization parameters. + if quantized_max_val != qparams_dict["qmax"]: + logger.warning( + f"Cannot fuse activation {node.name} as qmax is out of range." + ) + return None + + return qparams_dict + + def _update_qparams_hardsigmoid(self, quant_dict): + """ + Updates quant_dict in place so that scale and zp match the hardsigmoid activation. + + The quantized output from the hard sigmoid is defined by + Q(y) = clamp(round(y/scale + zp), qmin, qmax) + y = clamp(x/6 + 1/2, 0, 1) + where x is the output of the fused activation op, conv or linear. + + Q(y) can be rewritten as a function of only x: + Q(y) = clamp(round(clamp(x/6 + 1/2, 0, 1)/scale + zp), qmin, qmax) + Q(y) = clamp(round(clamp(x/(6*scale) + 1/(2*scale) + zp, zp, 1/scale + zp)), qmin, qmax) + + From definition of the qparams mapping the output in the range [0,1] to quantized range + [qmin, qmax], we have: + zp = Q(0) <= qmin + 1/scale + zp = Q(1) >= qmax + which makes the inner clamp redundant. 
+ + Therefore, hardsigmoid is equivalent to a quantization with modified parameters + new_scale = 6*scale + new_zp = zp + 1/(2*scale) ~= zp + round(1/(2*scale)) + """ + + new_scale = quant_dict["scale"] * 6 + + new_zp = quant_dict["zp"] + round(1 / (2 * quant_dict["scale"])) + clamped_new_zp = max(quant_dict["qmin"], min(quant_dict["qmax"], new_zp)) + + quant_dict["scale"] = new_scale + quant_dict["zp"] = clamped_new_zp + + def call(self, graph_module: GraphModule) -> PassResult: + modified = False + nodes_to_erase: list[Node] = [] + + for node in list(graph_module.graph.nodes): + if node.op != "call_function" or node.target not in self.TARGETS: + continue + + input_node = node.args[0] + if ( + input_node.op != "call_function" + or input_node.target not in self.FUSE_OPS + ): + logger.warning( + f"Cannot fuse activation {node.name} as input node {input_node.name} is not a supported fused activation op." + ) + continue + if len(input_node.users) > 1: + logger.warning( + f"Cannot fuse activation {node.name} as input node {input_node.name} has multiple users." 
+ ) + continue + + if (qparams_dict := self._get_validated_qparams(node, input_node)) is None: + continue + + if node.target == exir_ops.edge.aten.hardsigmoid.default: + self._update_qparams_hardsigmoid(qparams_dict) + + input_node.meta["output_qparams"][0] = QuantArgs(**qparams_dict) + + node.replace_all_uses_with(input_node) + nodes_to_erase.append(node) + modified = True + + for node in nodes_to_erase: + graph_module.graph.erase_node(node) + + if modified: + graph_module.recompile() + + return PassResult(graph_module, modified) diff --git a/backends/cortex_m/passes/convert_to_cortex_m_pass.py b/backends/cortex_m/passes/convert_to_cortex_m_pass.py index c849b2949bf..721a1951753 100644 --- a/backends/cortex_m/passes/convert_to_cortex_m_pass.py +++ b/backends/cortex_m/passes/convert_to_cortex_m_pass.py @@ -139,8 +139,11 @@ def _get_convolution_replacement(self, node) -> int: if not isinstance(weight_scales, list): weight_scales = [weight_scales] * weight.data.shape[0] - output_scale = node.meta["output_qparams"][0].scale - output_zero_point = node.meta["output_qparams"][0].zp + output_qparams = node.meta["output_qparams"][0] + output_scale = output_qparams.scale + output_zero_point = output_qparams.zp + output_qmin = output_qparams.qmin + output_qmax = output_qparams.qmax quantized_multipliers = [] quantized_shifts = [] @@ -177,8 +180,8 @@ def _get_convolution_replacement(self, node) -> int: output_zero_point, torch.tensor(quantized_multipliers, dtype=torch.int32), torch.tensor(quantized_shifts, dtype=torch.int32), - -128, - 127, + output_qmin, + output_qmax, ) return exir_ops.edge.cortex_m.quantized_conv2d.default, new_args diff --git a/backends/cortex_m/passes/cortex_m_pass_manager.py b/backends/cortex_m/passes/cortex_m_pass_manager.py index 948a60121b4..fd89986cef0 100644 --- a/backends/cortex_m/passes/cortex_m_pass_manager.py +++ b/backends/cortex_m/passes/cortex_m_pass_manager.py @@ -11,6 +11,7 @@ ScalarsToAttributePass, ) from 
executorch.backends.cortex_m.passes import ( + ActivationFusionPass, ConvertToCortexMPass, QuantizedOpFusionPass, ReplaceQuantNodesPass, @@ -31,6 +32,7 @@ class CortexMPassManager(PassManager): ReplaceScalarWithTensorArgPass, ReplaceQuantNodesPass, QuantizedOpFusionPass, + ActivationFusionPass, ConvertToCortexMPass, ] diff --git a/backends/cortex_m/quantizer/operator_configs.py b/backends/cortex_m/quantizer/operator_configs.py index c6b15fb9a78..25d3626a147 100644 --- a/backends/cortex_m/quantizer/operator_configs.py +++ b/backends/cortex_m/quantizer/operator_configs.py @@ -24,12 +24,21 @@ LINEAR_OP_PATTERNS = [ [torch.ops.aten.linear.default], [torch.ops.aten.linear.default, torch.ops.aten.relu.default], + [torch.ops.aten.linear.default, torch.ops.aten.relu_.default], + [torch.ops.aten.linear.default, torch.ops.aten.hardtanh.default], + [torch.ops.aten.linear.default, torch.ops.aten.hardtanh_.default], + [torch.ops.aten.linear.default, torch.ops.aten.hardsigmoid.default], + [torch.ops.aten.linear.default, torch.ops.aten.hardsigmoid_.default], ] CONV_OP_PATTERNS = [ - [torch.ops.aten.conv1d.default], [torch.ops.aten.conv2d.default], - [torch.ops.aten.conv3d.default], + [torch.ops.aten.conv2d.default, torch.ops.aten.relu.default], + [torch.ops.aten.conv2d.default, torch.ops.aten.relu_.default], + [torch.ops.aten.conv2d.default, torch.ops.aten.hardtanh.default], + [torch.ops.aten.conv2d.default, torch.ops.aten.hardtanh_.default], + [torch.ops.aten.conv2d.default, torch.ops.aten.hardsigmoid.default], + [torch.ops.aten.conv2d.default, torch.ops.aten.hardsigmoid_.default], ] # ----------------- OPERATOR CONFIG PRESETS ----------------- diff --git a/backends/cortex_m/test/ops/test_activation.py b/backends/cortex_m/test/ops/test_activation.py new file mode 100644 index 00000000000..bc20d364674 --- /dev/null +++ b/backends/cortex_m/test/ops/test_activation.py @@ -0,0 +1,409 @@ +# Copyright 2025 Arm Limited and/or its affiliates. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +from executorch.backends.arm.test.common import parametrize +from executorch.backends.cortex_m.test.tester import ( + CortexMTester, + McuTestCase, + ramp_tensor, +) + + +class CortexMLinearReLU(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_linear_default": 1, + "executorch_exir_dialects_edge__ops_aten_relu_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, in_features=4, out_features=3): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features, bias=False) + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(self.linear(x)) + + +class CortexMLinearHardtanh(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_linear_default": 1, + "executorch_exir_dialects_edge__ops_aten_hardtanh_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, min_val=-0.25, max_val=0.75): + super().__init__() + self.linear = torch.nn.Linear(4, 4, 
bias=False) + self.act = torch.nn.Hardtanh(min_val=min_val, max_val=max_val) + self.min_val = min_val + self.max_val = max_val + + def forward(self, x): + return self.act(self.linear(x)) + + +class CortexMLinearReLU6(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_linear_default": 1, + "executorch_exir_dialects_edge__ops_aten_hardtanh_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, in_features=8, out_features=8): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features, bias=False) + self.relu6 = torch.nn.ReLU6() + + def forward(self, x): + return self.relu6(self.linear(x)) + + +class CortexMLinearReLUInplace(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_linear_default": 1, + "executorch_exir_dialects_edge__ops_aten_relu_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, in_features=8, out_features=8): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features, bias=False) + self.relu = torch.nn.ReLU(inplace=True) + + def forward(self, x): + return 
self.relu(self.linear(x)) + + +class CortexMLinearHardtanhInplace(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_linear_default": 1, + "executorch_exir_dialects_edge__ops_aten_hardtanh_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, min_val=-1.0, max_val=1.0): + super().__init__() + self.linear = torch.nn.Linear(8, 8, bias=False) + self.act = torch.nn.Hardtanh(min_val=min_val, max_val=max_val, inplace=True) + + def forward(self, x): + return self.act(self.linear(x)) + + +class CortexMLinearHardsigmoid(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_linear_default": 1, + "executorch_exir_dialects_edge__ops_aten_hardsigmoid_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, in_features=6, out_features=6): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features, bias=False) + self.act = torch.nn.Hardsigmoid() + + def forward(self, x): + return self.act(self.linear(x)) + + +class CortexMConv2DReLU(torch.nn.Module): + ops_before_transforms = { + 
"executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_aten_relu_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 1, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_conv2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(4, 8, 3, padding=1, bias=False) + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(self.conv(x)) + + +class CortexMConv2DReLU6(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_aten_hardtanh_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 1, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_conv2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 6, 3, stride=2, padding=1, bias=False) + self.relu6 = torch.nn.ReLU6() + + def forward(self, x): + return self.relu6(self.conv(x)) + + +class CortexMConv2DHardtanh(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_convolution_default": 
1, + "executorch_exir_dialects_edge__ops_aten_hardtanh_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_conv2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, min_val=-2.0, max_val=2.0): + super().__init__() + self.conv = torch.nn.Conv2d(4, 8, 3, padding=1, bias=True) + self.act = torch.nn.Hardtanh(min_val=min_val, max_val=max_val) + + def forward(self, x): + return self.act(self.conv(x)) + + +class CortexMConv2DReLUInplace(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_aten_relu_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 1, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_conv2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(4, 8, 3, padding=1, bias=False) + self.relu = torch.nn.ReLU(inplace=True) + + def forward(self, x): + return self.relu(self.conv(x)) + + +class CortexMConv2DHardtanhInplace(torch.nn.Module): + ops_before_transforms = { + 
"executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_aten_hardtanh_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 1, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_conv2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, min_val=-0.5, max_val=0.5): + super().__init__() + self.conv = torch.nn.Conv2d(4, 8, 3, padding=1, bias=False) + self.act = torch.nn.Hardtanh(min_val=min_val, max_val=max_val, inplace=True) + torch.nn.init.ones_(self.conv.weight) + + def forward(self, x): + return self.act(self.conv(x)) + + +class CortexMConv2DHardsigmoid(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_aten_hardsigmoid_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 1, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_conv2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(1, 1, 1, bias=False) + self.act = torch.nn.Hardsigmoid(inplace=True) + self.conv.weight.data.fill_(1) + + def forward(self, x): + return 
self.act(self.conv(x)) + + +test_cases = { + # Linear + activation tests with various data ranges + "linear_relu_small_range": McuTestCase( + model=CortexMLinearReLU(), + example_inputs=(ramp_tensor(-10, 10, (1, 4)),), + ), + "linear_relu_large_range": McuTestCase( + model=CortexMLinearReLU(in_features=16, out_features=16), + example_inputs=(ramp_tensor(-100, 100, (2, 16)),), + ), + "linear_relu_negative": McuTestCase( + model=CortexMLinearReLU(in_features=8, out_features=8), + example_inputs=(ramp_tensor(-50, 0, (1, 8)),), + ), + "linear_relu6": McuTestCase( + model=CortexMLinearReLU6(), + example_inputs=(ramp_tensor(-2, 10, (1, 8)),), + ), + "linear_relu_inplace": McuTestCase( + model=CortexMLinearReLUInplace(), + example_inputs=(ramp_tensor(-5, 5, (2, 8)),), + ), + "linear_hardtanh_symmetric": McuTestCase( + model=CortexMLinearHardtanh(min_val=-0.5, max_val=0.5), + example_inputs=(ramp_tensor(-1, 1, (2, 1, 4)),), + ), + "linear_hardtanh_asymmetric": McuTestCase( + model=CortexMLinearHardtanh(min_val=-1.5, max_val=0.25), + example_inputs=(ramp_tensor(-2, 1, (1, 4)),), + ), + "linear_hardtanh_large_range": McuTestCase( + model=CortexMLinearHardtanh(min_val=-10.0, max_val=10.0), + example_inputs=(ramp_tensor(-20, 20, (2, 4)),), + ), + "linear_hardtanh_inplace": McuTestCase( + model=CortexMLinearHardtanhInplace(min_val=-0.75, max_val=0.75), + example_inputs=(ramp_tensor(-2, 2, (1, 8)),), + ), + # Convolution + activation tests with various configurations + "conv2d_relu_small_kernel": McuTestCase( + model=CortexMConv2DReLU(), + example_inputs=( + ramp_tensor(-5, 5, (1, 4, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_relu_large_range": McuTestCase( + model=CortexMConv2DReLU(), + example_inputs=( + ramp_tensor(-50, 50, (2, 4, 16, 16)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_relu6_stride": McuTestCase( + model=CortexMConv2DReLU6(), + example_inputs=( + ramp_tensor(-10, 20, (1, 3, 12, 12)).to(memory_format=torch.channels_last), + 
), + ), + "conv2d_relu_inplace": McuTestCase( + model=CortexMConv2DReLUInplace(), + example_inputs=( + ramp_tensor(-3, 3, (1, 4, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_hardtanh_narrow": McuTestCase( + model=CortexMConv2DHardtanh(min_val=-0.5, max_val=0.5), + example_inputs=( + ramp_tensor(-2, 2, (1, 4, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_hardtanh_wide": McuTestCase( + model=CortexMConv2DHardtanh(min_val=-5.0, max_val=5.0), + example_inputs=( + ramp_tensor(-10, 10, (1, 4, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_hardtanh_inplace": McuTestCase( + model=CortexMConv2DHardtanhInplace(min_val=-10.0, max_val=10.0), + example_inputs=( + ramp_tensor(-15, 15, (1, 4, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "linear_hardsigmoid": McuTestCase( + model=CortexMLinearHardsigmoid(in_features=6, out_features=4), + example_inputs=(ramp_tensor(-8, 8, (2, 6)),), + ), + "conv2d_hardsigmoid_inplace": McuTestCase( + model=CortexMConv2DHardsigmoid(), + example_inputs=( + ramp_tensor(-4, 4, (1, 1, 6, 6)).to(memory_format=torch.channels_last), + ), + ), +} + + +@parametrize("test_case", test_cases) +def test_dialect_activation(test_case): + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_dialect( + test_case.model.ops_before_transforms, + test_case.model.ops_after_transforms, + qtol=1, + ) + + +@parametrize("test_case", test_cases) +def test_implementation_activation(test_case): + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_implementation(qtol=1) diff --git a/backends/cortex_m/test/ops/test_conv.py b/backends/cortex_m/test/ops/test_conv.py index c6bb4815dca..8a67d1b7de1 100644 --- a/backends/cortex_m/test/ops/test_conv.py +++ b/backends/cortex_m/test/ops/test_conv.py @@ -112,33 +112,6 @@ def forward(self, x): return x -class CortexMConv2DReLU(torch.nn.Module): - ops_before_transforms = { - 
"executorch_exir_dialects_edge__ops_aten_convolution_default": 1, - "executorch_exir_dialects_edge__ops_aten_relu_default": 1, - "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3, - "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, - "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 1, - } - - ops_after_transforms = { - "executorch_exir_dialects_edge__ops_cortex_m_quantized_conv2d_default": 1, - "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, - "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, - "executorch_exir_dialects_edge__ops_aten_relu_default": 1, - } - - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(4, 8, 3, padding=1, bias=True) - self.relu = torch.nn.ReLU() - - def forward(self, x): - x = self.conv(x) - x = self.relu(x) - return x - - # in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode test_cases = { "conv2d": McuTestCase( @@ -205,12 +178,6 @@ def forward(self, x): ramp_tensor(0, 10, (1, 3, 8, 8)).to(memory_format=torch.channels_last), ), ), - "conv2d_relu": McuTestCase( - model=CortexMConv2DReLU(), - example_inputs=( - ramp_tensor(-5, 5, (1, 4, 8, 8)).to(memory_format=torch.channels_last), - ), - ), } @@ -219,7 +186,6 @@ def forward(self, x): "conv1d": "Currently not supported.", "conv2d_nchw": "Currently not supported.", "conv3d": "Currently not supported.", - "conv2d_relu": "Currently not supported.", } @@ -237,7 +203,6 @@ def test_dialect_conv2d(test_case): "conv1d": "Currently not supported.", "conv2d_nchw": "Currently not supported.", "conv3d": "Currently not supported.", - "conv2d_relu": "Currently not supported.", } diff --git a/backends/cortex_m/test/tester.py b/backends/cortex_m/test/tester.py index 010cc7e4ace..70f91b3f1dc 100644 --- a/backends/cortex_m/test/tester.py +++ 
b/backends/cortex_m/test/tester.py @@ -34,7 +34,13 @@ def __init__(self): class CortexMToEdge(ToEdge): def __init__(self): - config = EdgeCompileConfig(preserve_ops=[torch.ops.aten.linear.default]) + config = EdgeCompileConfig( + preserve_ops=[ + torch.ops.aten.linear.default, + torch.ops.aten.hardsigmoid.default, + torch.ops.aten.hardsigmoid_.default, + ] + ) super().__init__(config)