From 8c09c9a3cc3052a730804a29b31d5a68823e19e2 Mon Sep 17 00:00:00 2001
From: Adrian Lundell
Date: Mon, 8 Sep 2025 11:24:18 +0200
Subject: [PATCH 1/2] Cortex_m backend: Fix add implementation

- Call CMSIS-NN kernel with correct argument order and signs
- Change Python implementation to reflect CMSIS-NN behaviour
- Fix scale calculations
- Remove broken broadcasting support
- Add pass to lower scalar versions of ops
- Remove unused definitions/implementations in operators.py, operators.yaml
  and op_quantized_add.cpp

Note: arm_elementwise_add_s8 does not natively support broadcasting, so
simply resizing the output tensor will not work. Enabling this efficiently
is not straightforward, so these ops are not fused for now to avoid
breaking graphs.

Signed-off-by: Adrian Lundell
Change-Id: Id76db13848f2ce67d7527f40d31c06db663af8fa
---
 backends/cortex_m/ops/op_quantized_add.cpp    | 45 ++-------
 backends/cortex_m/ops/operators.py            | 93 ++++---------------
 backends/cortex_m/ops/operators.yaml          |  7 +-
 .../cortex_m/passes/cortex_m_pass_manager.py  |  4 +
 backends/cortex_m/passes/passes_utils.py      | 28 +++++-
 .../passes/quantized_op_fusion_pass.py        | 23 ++++-
 backends/cortex_m/test/ops/test_add.py        | 90 +++++++++++++-----
 7 files changed, 138 insertions(+), 152 deletions(-)

diff --git a/backends/cortex_m/ops/op_quantized_add.cpp b/backends/cortex_m/ops/op_quantized_add.cpp
index 044c2bd92d5..30be108ffcb 100644
--- a/backends/cortex_m/ops/op_quantized_add.cpp
+++ b/backends/cortex_m/ops/op_quantized_add.cpp
@@ -1,6 +1,7 @@
 /*
  * Copyright (c) Meta Platforms, Inc. and affiliates.
  * All rights reserved.
+ * Copyright 2025 Arm Limited and/or its affiliates.
  *
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
@@ -47,13 +48,6 @@ Tensor& quantized_add_out(
       output_shift,
       out);
 
-  // Broadcast if needed
-  auto result = resize_to_broadcast_target_size(input1_int8, input2_int8, out);
-  ET_CHECK_MSG(
-      (result == Error::Ok),
-      "Failed to resize output tensor. Status: [%d]",
-      result);
-
   ET_LOG(
       Info,
       "quantized_add_out: input1_int8.sizes() = %zu",
@@ -69,7 +63,7 @@ Tensor& quantized_add_out(
   int32_t output_mult = extractScalarToInt32(output_multiplier);
   int output_shift_val = extractScalarToInt(output_shift);
 
-  // Left shift to maximize precision (tune as needed)
+  // Left shift to maximize precision
   const int32_t left_shift = 20;
   const int32_t activation_min = std::numeric_limits<int8_t>::min();
   const int32_t activation_max = std::numeric_limits<int8_t>::max();
@@ -88,10 +82,10 @@ Tensor& quantized_add_out(
   arm_cmsis_nn_status status = arm_elementwise_add_s8(
       input1_int8.const_data_ptr<int8_t>(),
       input2_int8.const_data_ptr<int8_t>(),
-      static_cast<int32_t>(zp1),
+      -static_cast<int32_t>(zp1),
       input1_mult,
       input1_shift_val,
-      static_cast<int32_t>(zp2),
+      -static_cast<int32_t>(zp2),
       input2_mult,
       input2_shift_val,
       left_shift,
@@ -99,9 +93,9 @@ Tensor& quantized_add_out(
       static_cast<int32_t>(out_zp),
       output_mult,
       output_shift_val,
-      static_cast<int32_t>(out.numel()),
       activation_min,
-      activation_max);
+      activation_max,
+      static_cast<int32_t>(out.numel()));
 
   if (status != ARM_CMSIS_NN_SUCCESS) {
     ET_LOG(
@@ -119,32 +113,5 @@ Tensor& quantized_add_out(
   return out;
 }
 
-// Stub Implementation: Non-out variant for compatibility (functional variant)
-// EXIR/ExecuTorch runs an out-variant pass that converts
-// .default operations to .out variants before memory planning.
-// In the pass we are calling quantized_add's default variant
-// but ExecuTorch's kernel dispatch mechanism will end up calling the out
-// variant.
This stub is to make sure that compiler doesn't complain. -Tensor quantized_add( - KernelRuntimeContext& context, - const Tensor& input1_int8, - const Scalar& input1_zero_point, - const Scalar& input1_multiplier, - const Scalar& input1_shift, - const Tensor& input2_int8, - const Scalar& input2_zero_point, - const Scalar& input2_multiplier, - const Scalar& input2_shift, - const Scalar& output_zero_point, - const Scalar& output_multiplier, - const Scalar& output_shift) { - ET_LOG(Info, "quantized_add: input1_int8.sizes() = %zu", input1_int8.sizes()); - - // Crash on Debug builds if invoked - assert(False); - // This is to make sure compiler doesn't complain. - return const_cast(input1_int8); -} - } // namespace native } // namespace cortex_m diff --git a/backends/cortex_m/ops/operators.py b/backends/cortex_m/ops/operators.py index d642531e950..286f938ccc9 100644 --- a/backends/cortex_m/ops/operators.py +++ b/backends/cortex_m/ops/operators.py @@ -1,13 +1,14 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import torch from executorch.backends.cortex_m.passes.passes_utils import ( - dequantize_per_tensor_cmsis, - quantize_per_tensor_cmsis, + requantize_cmsis, + SHIFT_INT8, ) from executorch.exir.dialects._ops import ops as exir_ops @@ -111,52 +112,6 @@ def dequantize_per_tensor_impl( "Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift) -> Tensor" ) - -@register_fake("cortex_m::quantized_add") -def quantized_add_meta( - self: torch.Tensor, - self_zero_point: int, - self_multiplier: int, - self_shift: int, - other: torch.Tensor, - other_zero_point: int, - other_multiplier: int, - other_shift: int, - output_zero_point: int, - output_multiplier: int, - output_shift: int, -) -> torch.Tensor: - broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape) - return torch.empty(broadcasted_shape, dtype=torch.int8, device=self.device) - - -@impl(lib, "quantized_add", "CompositeExplicitAutograd") -def quantized_add_impl( - self: torch.Tensor, - self_zero_point: int, - self_multiplier: int, - self_shift: int, - other: torch.Tensor, - other_zero_point: int, - other_multiplier: int, - other_shift: int, - output_zero_point: int, - output_multiplier: int, - output_shift: int, -) -> torch.Tensor: - self_fp = dequantize_per_tensor_cmsis( - self, self_zero_point, self_multiplier, self_shift - ) - other_fp = dequantize_per_tensor_cmsis( - other, other_zero_point, other_multiplier, other_shift - ) - result_fp = self_fp + other_fp - result_quantized = quantize_per_tensor_cmsis( - result_fp, output_zero_point, output_multiplier, output_shift - ) - return result_quantized - - # Define the operator schema with multipliers and shifts (11 args + out tensor) lib.define( "quantized_add.out(" @@ -167,9 +122,8 @@ def quantized_add_impl( ) -# Fake meta function for shape and dtype inference during compilation -@register_fake("cortex_m::quantized_add.out") -def quantized_add_out_meta( +@register_fake("cortex_m::quantized_add") +def quantized_add_meta( self: torch.Tensor, self_zero_point: int, self_multiplier: int, @@ -181,19 +135,13 @@ def quantized_add_out_meta( output_zero_point: int, output_multiplier: int, output_shift: int, - out: torch.Tensor, ) -> torch.Tensor: - # Validate against correct broadcasted shape - expected_shape = torch.broadcast_shapes(self.shape, other.shape) - 
assert ( - out.shape == expected_shape - ), f"Output shape {out.shape} must match broadcasted shape {expected_shape}" - return out + broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape) + return torch.empty(broadcasted_shape, dtype=torch.int8, device=self.device) -# Actual implementation delegating to backend or custom kernel -@impl(lib, "quantized_add.out", "CompositeExplicitAutograd") -def quantized_add_out_impl( +@impl(lib, "quantized_add", "CompositeExplicitAutograd") +def quantized_add_impl( self: torch.Tensor, self_zero_point: int, self_multiplier: int, @@ -205,24 +153,17 @@ def quantized_add_out_impl( output_zero_point: int, output_multiplier: int, output_shift: int, - *, - out: torch.Tensor, ) -> torch.Tensor: - self_fp = dequantize_per_tensor_cmsis( - self, self_zero_point, self_multiplier, self_shift - ) - other_fp = dequantize_per_tensor_cmsis( - other, other_zero_point, other_multiplier, other_shift - ) - result_fp = self_fp + other_fp - result_quantized = quantize_per_tensor_cmsis( - result_fp, output_zero_point, output_multiplier, output_shift - ) + self_shifted = (self.to(torch.int32) - self_zero_point) << SHIFT_INT8 + self_fp = requantize_cmsis(self_shifted, self_multiplier, self_shift) - # Write into the provided output tensor - out.copy_(result_quantized) + other_shifted = (other.to(torch.int32) - other_zero_point) << SHIFT_INT8 + other_fp = requantize_cmsis(other_shifted, other_multiplier, other_shift) - return out + result_fp = self_fp + other_fp + result_quantized = requantize_cmsis(result_fp, output_multiplier, output_shift) + result = torch.clamp(result_quantized + output_zero_point, -128, 127).to(torch.int8) + return result # =================================================================== diff --git a/backends/cortex_m/ops/operators.yaml b/backends/cortex_m/ops/operators.yaml index b41c0c68fa5..81ebeafc778 100644 --- a/backends/cortex_m/ops/operators.yaml +++ b/backends/cortex_m/ops/operators.yaml @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -16,12 +17,6 @@ - arg_meta: null kernel_name: cortex_m::dequantize_per_tensor_out -- func: cortex_m::quantized_add(Tensor self, Scalar self_zero_point, Scalar self_multiplier, Scalar self_shift, Tensor other, Scalar other_zero_point, Scalar other_multiplier, Scalar other_shift, Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift) -> Tensor - variants: function - kernels: - - arg_meta: null - kernel_name: cortex_m::quantized_add - - func: cortex_m::quantized_add.out(Tensor self, Scalar self_zero_point, Scalar self_multiplier, Scalar self_shift, Tensor other, Scalar other_zero_point, Scalar other_multiplier, Scalar other_shift, Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift, *, Tensor(a!) out) -> Tensor(a!) 
variants: function kernels: diff --git a/backends/cortex_m/passes/cortex_m_pass_manager.py b/backends/cortex_m/passes/cortex_m_pass_manager.py index d31a0b894f6..d4b8cebe400 100644 --- a/backends/cortex_m/passes/cortex_m_pass_manager.py +++ b/backends/cortex_m/passes/cortex_m_pass_manager.py @@ -9,6 +9,9 @@ QuantizedOpFusionPass, ReplaceQuantNodesPass, ) +from executorch.backends.transforms.replace_scalar_with_tensor import ( + ReplaceScalarWithTensorArgPass, +) from executorch.backends.xnnpack._passes import XNNPACKPassManager from executorch.exir.pass_base import ExportPass @@ -16,6 +19,7 @@ class CortexMPassManager(XNNPACKPassManager): pass_list: list[ExportPass] = [ + ReplaceScalarWithTensorArgPass, ReplaceQuantNodesPass, QuantizedOpFusionPass, QuantizedLinearFusionPass, diff --git a/backends/cortex_m/passes/passes_utils.py b/backends/cortex_m/passes/passes_utils.py index 7155f997bf4..b045005d34d 100644 --- a/backends/cortex_m/passes/passes_utils.py +++ b/backends/cortex_m/passes/passes_utils.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -12,6 +13,9 @@ from torch.fx import Node +# L-shift value used in CMSIS-NN for int8 operations +SHIFT_INT8 = 20 + def dequantize_per_tensor_cmsis( qtensor: torch.Tensor, zero_point: int, multiplier: int, shift: int @@ -41,6 +45,21 @@ def quantize_per_tensor_cmsis( return quantized.clamp(qmin, qmax).to(torch.int8) +def requantize_cmsis( + tensor: torch.Tensor, + multiplier: int, + shift: int, +) -> torch.Tensor: + """ + Simulate CMSIS-NN fixed-point requantization: + result = round(tensor * multiplier / (2 ^ shift)) + with double rounding + """ + multiplied = torch.round(tensor.to(torch.int64) * multiplier) + shifted = torch.round(multiplied / (2 ** (31 - shift))) + return shifted.to(torch.int32) + + def extract_scalar_value(node_arg) -> float: """ Extract scalar value from various PyTorch scalar representations. @@ -83,13 +102,14 @@ def is_qualified_int8_node(args) -> bool: def quantize_multiplier_aot(scale: float) -> tuple[int, int]: if scale == 0.0: return 0, 0 - mantissa, exponent = math.frexp(scale) - shift = -exponent + mantissa, shift = math.frexp(scale) q_fixed = int(round(mantissa * (1 << 31))) if q_fixed == (1 << 31): q_fixed //= 2 - shift -= 1 - multiplier = max(-2147483648, min(2147483647, q_fixed)) + shift += 1 + multiplier = max( + torch.iinfo(torch.int32).min, min(torch.iinfo(torch.int32).max, q_fixed) + ) return multiplier, shift diff --git a/backends/cortex_m/passes/quantized_op_fusion_pass.py b/backends/cortex_m/passes/quantized_op_fusion_pass.py index eebf6866d83..888155dcfd0 100644 --- a/backends/cortex_m/passes/quantized_op_fusion_pass.py +++ b/backends/cortex_m/passes/quantized_op_fusion_pass.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
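
The arithmetic driven by the fused parameters can be sanity-checked outside the graph. The standalone sketch below is not part of the patch: it mirrors the requantize_cmsis/quantize_multiplier_aot helpers added above and the AoT scale computation introduced in the quantized_op_fusion_pass.py hunks that follow, but it models the rescaling in float64 rather than CMSIS-NN's true fixed-point rounding, and its helper names (quantize_multiplier, requantize, simulate_quantized_add) are illustrative only.

    # sanity_check_quantized_add.py -- illustrative only, not part of this patch.
    import math

    import torch

    SHIFT_INT8 = 20  # left_shift passed to arm_elementwise_add_s8


    def quantize_multiplier(scale: float) -> tuple[int, int]:
        # Split scale into a Q31 mantissa ("multiplier") and a base-2 exponent
        # ("shift"), as quantize_multiplier_aot does after this patch.
        if scale == 0.0:
            return 0, 0
        mantissa, shift = math.frexp(scale)
        q_fixed = int(round(mantissa * (1 << 31)))
        if q_fixed == (1 << 31):
            q_fixed //= 2
            shift += 1
        return q_fixed, shift


    def requantize(x: torch.Tensor, multiplier: int, shift: int) -> torch.Tensor:
        # Multiply by multiplier * 2**(shift - 31), i.e. by the encoded scale,
        # using float64 instead of the kernel's fixed-point rounding.
        return torch.round(
            x.to(torch.float64) * multiplier / (2 ** (31 - shift))
        ).to(torch.int32)


    def simulate_quantized_add(a, a_scale, a_zp, b, b_scale, b_zp, out_scale, out_zp):
        # Ahead-of-time parameters, as computed in the fusion pass.
        max_scale_2x = 2 * max(a_scale, b_scale)
        a_mult, a_shift = quantize_multiplier(a_scale / max_scale_2x)
        b_mult, b_shift = quantize_multiplier(b_scale / max_scale_2x)
        out_mult, out_shift = quantize_multiplier(
            max_scale_2x / (out_scale * (1 << SHIFT_INT8))
        )
        # Runtime arithmetic, as in quantized_add_impl / arm_elementwise_add_s8.
        a_acc = requantize((a.to(torch.int32) - a_zp) << SHIFT_INT8, a_mult, a_shift)
        b_acc = requantize((b.to(torch.int32) - b_zp) << SHIFT_INT8, b_mult, b_shift)
        acc = requantize(a_acc + b_acc, out_mult, out_shift)
        return torch.clamp(acc + out_zp, -128, 127).to(torch.int8)


    if __name__ == "__main__":
        a = torch.randint(-128, 128, (8,), dtype=torch.int8)
        b = torch.randint(-128, 128, (8,), dtype=torch.int8)
        q = simulate_quantized_add(a, 0.02, 3, b, 0.05, -7, 0.08, 1)
        # Reference: plain float dequantize-add-requantize.
        ref = torch.clamp(
            torch.round((0.02 * (a.int() - 3) + 0.05 * (b.int() + 7)) / 0.08) + 1,
            -128,
            127,
        ).to(torch.int8)
        assert (q.int() - ref.int()).abs().max() <= 1

Rescaling both inputs against 2 * max(scale1, scale2) keeps each rescaled addend within half of the accumulator range so the sum cannot overflow, and the fixed left shift of 20 preserves fractional precision until the final rescale, which is why the output multiplier folds the 1 << 20 back out against the output scale.
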
@@ -13,6 +14,7 @@
 from executorch.backends.cortex_m.passes.passes_utils import (
     extract_scalar_value,
     quantize_multiplier_aot,
+    SHIFT_INT8,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass
@@ -58,7 +60,16 @@ def _get_quant_targets(self) -> Set:
 
     def _is_supported_binary_op(self, node: torch.fx.Node) -> bool:
         """Check if node is a supported binary operation."""
-        return node.op == "call_function" and node.target in self.SUPPORTED_OPS_MAPPING
+        is_supported = (
+            node.op == "call_function" and node.target in self.SUPPORTED_OPS_MAPPING
+        )
+        if not is_supported:
+            return False
+
+        shape1 = node.args[0].meta["val"].shape
+        shape2 = node.args[1].meta["val"].shape
+        is_broadcast = shape1 != shape2
+        return not is_broadcast
 
     def _is_dequant_node(self, node: torch.fx.Node) -> bool:
         """Check if node is a dequantize operation."""
@@ -163,16 +174,18 @@ def _fuse_quantized_binary_patterns(
             zp2_val = int(extract_scalar_value(zero_point2))
             output_zp_val = int(extract_scalar_value(output_zero_point))
 
+            max_scale_2x = 2 * max(scale1_val, scale2_val)
             # AoT COMPUTATION: Calculate multipliers and shifts
+
             input1_mult, input1_shift = quantize_multiplier_aot(
-                scale1_val / output_scale_val
+                scale1_val / max_scale_2x
             )
             input2_mult, input2_shift = quantize_multiplier_aot(
-                scale2_val / output_scale_val
+                scale2_val / max_scale_2x
             )
             output_mult, output_shift = quantize_multiplier_aot(
-                1.0
-            )  # Output multiplier is 1
+                max_scale_2x / (output_scale_val * (1 << SHIFT_INT8))
+            )
 
             logger.info("AoT computed parameters:")
             logger.info(f"  Input1: mult={input1_mult}, shift={input1_shift}")
diff --git a/backends/cortex_m/test/ops/test_add.py b/backends/cortex_m/test/ops/test_add.py
index b7b0ffcbfbc..bd7de56c8df 100644
--- a/backends/cortex_m/test/ops/test_add.py
+++ b/backends/cortex_m/test/ops/test_add.py
@@ -110,6 +110,10 @@ class CortexMAlphaAdd(ModelAlpha):
         CortexMScalarAdd(),
         (1000.0, torch.ones(2, 2)),
     ),
+    "tensor_tensor": McuTestCase(
+        CortexMTensorAdd(),
+        (torch.rand(2, 2) * 10, torch.rand(2, 2)),
+    ),
     "broadcast_1": McuTestCase(
         CortexMTensorAdd(),
         (torch.ones(1), torch.ones(2, 2, 2, 2)),
@@ -136,15 +140,38 @@ class CortexMAlphaAdd(ModelAlpha):
 
 
 dialect_xfails = {
-    "self_scalar": ("'float' object has no attribute 'fake_mode'", AttributeError),
-    "self_rank_1": ("Output 0 does not match reference output", AssertionError),
-    "self_rank_2_pos": ("Output 0 does not match reference output", AssertionError),
-    "self_rank_3_neg": ("Output 0 does not match reference output", AssertionError),
-    "self_rank_4_small": ("Output 0 does not match reference output", AssertionError),
-    "self_rank_5": ("Output 0 does not match reference output", AssertionError),
-    "scalar_scalar": ("'float' object has no attribute 'fake_mode'", AttributeError),
-    "broadcast_3": ("Output 0 does not match reference output", AssertionError),
-    "alpha": ("Expecting kwargs for aten op IR to be empty", AssertionError),
+    "self_scalar": (
+        "'float' object has no attribute 'fake_mode' - scalar only ops not supported.",
+        AttributeError,
+    ),
+    "scalar_scalar": (
+        "'float' object has no attribute 'fake_mode' - scalar only ops not supported.",
+        AttributeError,
+    ),
+    "tensor_scalar": (
+        "Expected to find 'executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default' but did not find it - broadcasting not supported.",
+        RuntimeError,
+    ),
+    "scalar_tensor": (
+        "Expected to find 'executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default' but did not find it - broadcasting not supported.",
+        RuntimeError,
+    ),
+    "broadcast_1": (
+        "Expected to find 'executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default' but did not find it - broadcasting not supported.",
+        RuntimeError,
+    ),
+    "broadcast_2": (
+        "Expected to find 'executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default' but did not find it - broadcasting not supported.",
+        RuntimeError,
+    ),
+    "broadcast_3": (
+        "Expected to find 'executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default' but did not find it - broadcasting not supported.",
+        RuntimeError,
+    ),
+    "alpha": (
+        "Expecting kwargs for aten op IR to be empty - alpha arg not supported.",
+        AssertionError,
+    ),
 }
 
 
@@ -157,19 +184,38 @@ def test_dialect_add(test_case):
 
 
 implementation_xfails = {
-    "self_scalar": ("'float' object has no attribute 'fake_mode'", AttributeError),
-    "self_rank_1": ("Output 0 does not match reference output", AssertionError),
-    "self_rank_2_pos": ("Output 0 does not match reference output", AssertionError),
-    "self_rank_3_neg": ("Output 0 does not match reference output", AssertionError),
-    "self_rank_4_small": ("Output 0 does not match reference output", AssertionError),
-    "self_rank_5": ("Output 0 does not match reference output", AssertionError),
-    "scalar_scalar": ("'float' object has no attribute 'fake_mode'", AttributeError),
-    "tensor_scalar": ("Output 0 does not match reference output", AssertionError),
-    "scalar_tensor": ("Output 0 does not match reference output", AssertionError),
-    "broadcast_1": ("Output 0 does not match reference output", AssertionError),
-    "broadcast_2": ("Output 0 does not match reference output", AssertionError),
-    "broadcast_3": ("Output 0 does not match reference output", AssertionError),
-    "alpha": ("Expecting kwargs for aten op IR to be empty", AssertionError),
+    "self_scalar": (
+        "'float' object has no attribute 'fake_mode' - scalar only ops not supported.",
+        AttributeError,
+    ),
+    "scalar_scalar": (
+        "'float' object has no attribute 'fake_mode' - scalar only ops not supported.",
+        AttributeError,
+    ),
+    "tensor_scalar": (
+        "Missing operator: [2] aten::add.out - broadcasting not supported.",
+        RuntimeError,
+    ),
+    "scalar_tensor": (
+        "Missing operator: [2] aten::add.out - broadcasting not supported.",
+        RuntimeError,
+    ),
+    "broadcast_1": (
+        "Missing operator: [2] aten::add.out - broadcasting not supported.",
+        RuntimeError,
+    ),
+    "broadcast_2": (
+        "Missing operator: [2] aten::add.out - broadcasting not supported.",
+        RuntimeError,
+    ),
+    "broadcast_3": (
+        "Missing operator: [2] aten::add.out - broadcasting not supported.",
+        RuntimeError,
+    ),
+    "alpha": (
+        "Expecting kwargs for aten op IR to be empty - alpha arg not supported.",
+        AssertionError,
+    ),
 }

From 61dddb5f8da8cffd55a70793e6783dd151193341 Mon Sep 17 00:00:00 2001
From: Adrian Lundell
Date: Mon, 27 Oct 2025 13:21:18 +0100
Subject: [PATCH 2/2] Add shape check

Signed-off-by: Adrian Lundell
---
 backends/cortex_m/ops/cortex_m_ops_common.h | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/backends/cortex_m/ops/cortex_m_ops_common.h b/backends/cortex_m/ops/cortex_m_ops_common.h
index eaa7027e46c..c7e6cc8a389 100644
--- a/backends/cortex_m/ops/cortex_m_ops_common.h
+++ b/backends/cortex_m/ops/cortex_m_ops_common.h
@@ -1,6 +1,7 @@
 /*
  * Copyright (c) Meta Platforms, Inc. and affiliates.
  * All rights reserved.
+ * Copyright 2025 Arm Limited and/or its affiliates.
* * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. @@ -49,6 +50,12 @@ inline void validate_cmsis_nn_tensor_requirements( "Output dtype must be %hhd, got %hhd", expected_dtype, output.scalar_type()); + ET_CHECK_MSG( + input1.sizes() == input2.sizes(), + "Input1 and Input2 must have the same sizes"); + ET_CHECK_MSG( + output.sizes() == input1.sizes(), + "Output must have the same sizes as inputs"); // Dim order consistency ET_CHECK_MSG(