diff --git a/backends/cortex_m/ops/cortex_m_ops_common.h b/backends/cortex_m/ops/cortex_m_ops_common.h
index eaa7027e46c..c7e6cc8a389 100644
--- a/backends/cortex_m/ops/cortex_m_ops_common.h
+++ b/backends/cortex_m/ops/cortex_m_ops_common.h
@@ -1,6 +1,7 @@
 /*
  * Copyright (c) Meta Platforms, Inc. and affiliates.
  * All rights reserved.
+ * Copyright 2025 Arm Limited and/or its affiliates.
  *
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
@@ -49,6 +50,12 @@ inline void validate_cmsis_nn_tensor_requirements(
       "Output dtype must be %hhd, got %hhd",
       expected_dtype,
       output.scalar_type());
+  ET_CHECK_MSG(
+      input1.sizes() == input2.sizes(),
+      "Input1 and Input2 must have the same sizes");
+  ET_CHECK_MSG(
+      output.sizes() == input1.sizes(),
+      "Output must have the same sizes as inputs");

   // Dim order consistency
   ET_CHECK_MSG(
diff --git a/backends/cortex_m/ops/op_quantized_add.cpp b/backends/cortex_m/ops/op_quantized_add.cpp
index 044c2bd92d5..30be108ffcb 100644
--- a/backends/cortex_m/ops/op_quantized_add.cpp
+++ b/backends/cortex_m/ops/op_quantized_add.cpp
@@ -1,6 +1,7 @@
 /*
  * Copyright (c) Meta Platforms, Inc. and affiliates.
  * All rights reserved.
+ * Copyright 2025 Arm Limited and/or its affiliates.
  *
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
@@ -47,13 +48,6 @@ Tensor& quantized_add_out(
       output_shift,
       out);

-  // Broadcast if needed
-  auto result = resize_to_broadcast_target_size(input1_int8, input2_int8, out);
-  ET_CHECK_MSG(
-      (result == Error::Ok),
-      "Failed to resize output tensor. Status: [%d]",
-      result);
-
   ET_LOG(
       Info,
       "quantized_add_out: input1_int8.sizes() = %zu",
@@ -69,7 +63,7 @@ Tensor& quantized_add_out(
   int32_t output_mult = extractScalarToInt32(output_multiplier);
   int output_shift_val = extractScalarToInt(output_shift);

-  // Left shift to maximize precision (tune as needed)
+  // Left shift to maximize precision
   const int32_t left_shift = 20;
   const int32_t activation_min = std::numeric_limits<int8_t>::min();
   const int32_t activation_max = std::numeric_limits<int8_t>::max();
@@ -88,10 +82,10 @@
   arm_cmsis_nn_status status = arm_elementwise_add_s8(
       input1_int8.const_data_ptr<int8_t>(),
       input2_int8.const_data_ptr<int8_t>(),
-      static_cast<int32_t>(zp1),
+      -static_cast<int32_t>(zp1),
       input1_mult,
       input1_shift_val,
-      static_cast<int32_t>(zp2),
+      -static_cast<int32_t>(zp2),
       input2_mult,
       input2_shift_val,
       left_shift,
@@ -99,9 +93,9 @@
       out.mutable_data_ptr<int8_t>(),
       static_cast<int32_t>(out_zp),
       output_mult,
       output_shift_val,
-      static_cast<int32_t>(out.numel()),
       activation_min,
-      activation_max);
+      activation_max,
+      static_cast<int32_t>(out.numel()));

   if (status != ARM_CMSIS_NN_SUCCESS) {
     ET_LOG(
@@ -119,32 +113,5 @@
   return out;
 }

-// Stub Implementation: Non-out variant for compatibility (functional variant)
-// EXIR/ExecuTorch runs an out-variant pass that converts
-// .default operations to .out variants before memory planning.
-// In the pass we are calling quantized_add's default variant
-// but ExecuTorch's kernel dispatch mechanism will end up calling the out
-// variant. This stub is to make sure that compiler doesn't complain.
-Tensor quantized_add(
-    KernelRuntimeContext& context,
-    const Tensor& input1_int8,
-    const Scalar& input1_zero_point,
-    const Scalar& input1_multiplier,
-    const Scalar& input1_shift,
-    const Tensor& input2_int8,
-    const Scalar& input2_zero_point,
-    const Scalar& input2_multiplier,
-    const Scalar& input2_shift,
-    const Scalar& output_zero_point,
-    const Scalar& output_multiplier,
-    const Scalar& output_shift) {
-  ET_LOG(Info, "quantized_add: input1_int8.sizes() = %zu", input1_int8.sizes());
-
-  // Crash on Debug builds if invoked
-  assert(False);
-  // This is to make sure compiler doesn't complain.
-  return const_cast<Tensor&>(input1_int8);
-}
-
 } // namespace native
 } // namespace cortex_m
diff --git a/backends/cortex_m/ops/operators.py b/backends/cortex_m/ops/operators.py
index d642531e950..286f938ccc9 100644
--- a/backends/cortex_m/ops/operators.py
+++ b/backends/cortex_m/ops/operators.py
@@ -1,13 +1,14 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

 import torch
 from executorch.backends.cortex_m.passes.passes_utils import (
-    dequantize_per_tensor_cmsis,
-    quantize_per_tensor_cmsis,
+    requantize_cmsis,
+    SHIFT_INT8,
 )
 from executorch.exir.dialects._ops import ops as exir_ops

@@ -111,52 +112,6 @@ def dequantize_per_tensor_impl(
     "Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift) -> Tensor"
 )
-
-@register_fake("cortex_m::quantized_add")
-def quantized_add_meta(
-    self: torch.Tensor,
-    self_zero_point: int,
-    self_multiplier: int,
-    self_shift: int,
-    other: torch.Tensor,
-    other_zero_point: int,
-    other_multiplier: int,
-    other_shift: int,
-    output_zero_point: int,
-    output_multiplier: int,
-    output_shift: int,
-) -> torch.Tensor:
-    broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape)
-    return torch.empty(broadcasted_shape, dtype=torch.int8, device=self.device)
-
-
-@impl(lib, "quantized_add", "CompositeExplicitAutograd")
-def quantized_add_impl(
-    self: torch.Tensor,
-    self_zero_point: int,
-    self_multiplier: int,
-    self_shift: int,
-    other: torch.Tensor,
-    other_zero_point: int,
-    other_multiplier: int,
-    other_shift: int,
-    output_zero_point: int,
-    output_multiplier: int,
-    output_shift: int,
-) -> torch.Tensor:
-    self_fp = dequantize_per_tensor_cmsis(
-        self, self_zero_point, self_multiplier, self_shift
-    )
-    other_fp = dequantize_per_tensor_cmsis(
-        other, other_zero_point, other_multiplier, other_shift
-    )
-    result_fp = self_fp + other_fp
-    result_quantized = quantize_per_tensor_cmsis(
-        result_fp, output_zero_point, output_multiplier, output_shift
-    )
-    return result_quantized
-
-
 # Define the operator schema with multipliers and shifts (11 args + out tensor)
 lib.define(
     "quantized_add.out("
@@ -167,9 +122,8 @@ def quantized_add_impl(
 )

-# Fake meta function for shape and dtype inference during compilation
-@register_fake("cortex_m::quantized_add.out")
-def quantized_add_out_meta(
+@register_fake("cortex_m::quantized_add")
+def quantized_add_meta(
     self: torch.Tensor,
     self_zero_point: int,
     self_multiplier: int,
@@ -181,19 +135,13 @@
     output_zero_point: int,
     output_multiplier: int,
     output_shift: int,
-    out: torch.Tensor,
 ) -> torch.Tensor:
-    # Validate against correct broadcasted shape
-    expected_shape = torch.broadcast_shapes(self.shape, other.shape)
-    assert (
-        out.shape == expected_shape
-    ), f"Output shape {out.shape} must match broadcasted shape {expected_shape}"
-    return out
+    broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape)
+    return torch.empty(broadcasted_shape, dtype=torch.int8, device=self.device)

-# Actual implementation delegating to backend or custom kernel
-@impl(lib, "quantized_add.out", "CompositeExplicitAutograd")
-def quantized_add_out_impl(
+@impl(lib, "quantized_add", "CompositeExplicitAutograd")
+def quantized_add_impl(
     self: torch.Tensor,
     self_zero_point: int,
     self_multiplier: int,
@@ -205,24 +153,17 @@
     output_zero_point: int,
     output_multiplier: int,
     output_shift: int,
-    *,
-    out: torch.Tensor,
 ) -> torch.Tensor:
-    self_fp = dequantize_per_tensor_cmsis(
-        self, self_zero_point, self_multiplier, self_shift
-    )
-    other_fp = dequantize_per_tensor_cmsis(
-        other, other_zero_point, other_multiplier, other_shift
-    )
-    result_fp = self_fp + other_fp
-    result_quantized = quantize_per_tensor_cmsis(
-        result_fp, output_zero_point, output_multiplier, output_shift
-    )
+    self_shifted = (self.to(torch.int32) - self_zero_point) << SHIFT_INT8
+    self_fp = requantize_cmsis(self_shifted, self_multiplier, self_shift)

-    # Write into the provided output tensor
-    out.copy_(result_quantized)
+    other_shifted = (other.to(torch.int32) - other_zero_point) << SHIFT_INT8
+    other_fp = requantize_cmsis(other_shifted, other_multiplier, other_shift)

-    return out
+    result_fp = self_fp + other_fp
+    result_quantized = requantize_cmsis(result_fp, output_multiplier, output_shift)
+    result = torch.clamp(result_quantized + output_zero_point, -128, 127).to(torch.int8)
+    return result

 # ===================================================================
diff --git a/backends/cortex_m/ops/operators.yaml b/backends/cortex_m/ops/operators.yaml
index b41c0c68fa5..81ebeafc778 100644
--- a/backends/cortex_m/ops/operators.yaml
+++ b/backends/cortex_m/ops/operators.yaml
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -16,12 +17,6 @@
   - arg_meta: null
     kernel_name: cortex_m::dequantize_per_tensor_out

-- func: cortex_m::quantized_add(Tensor self, Scalar self_zero_point, Scalar self_multiplier, Scalar self_shift, Tensor other, Scalar other_zero_point, Scalar other_multiplier, Scalar other_shift, Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift) -> Tensor
-  variants: function
-  kernels:
-    - arg_meta: null
-      kernel_name: cortex_m::quantized_add
-
 - func: cortex_m::quantized_add.out(Tensor self, Scalar self_zero_point, Scalar self_multiplier, Scalar self_shift, Tensor other, Scalar other_zero_point, Scalar other_multiplier, Scalar other_shift, Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:
diff --git a/backends/cortex_m/passes/cortex_m_pass_manager.py b/backends/cortex_m/passes/cortex_m_pass_manager.py
index d31a0b894f6..d4b8cebe400 100644
--- a/backends/cortex_m/passes/cortex_m_pass_manager.py
+++ b/backends/cortex_m/passes/cortex_m_pass_manager.py
@@ -9,6 +9,9 @@
     QuantizedOpFusionPass,
     ReplaceQuantNodesPass,
 )
+from executorch.backends.transforms.replace_scalar_with_tensor import (
+    ReplaceScalarWithTensorArgPass,
+)
 from executorch.backends.xnnpack._passes import XNNPACKPassManager
 from executorch.exir.pass_base import ExportPass

@@ -16,6 +19,7 @@ class CortexMPassManager(XNNPACKPassManager):

     pass_list: list[ExportPass] = [
+        ReplaceScalarWithTensorArgPass,
         ReplaceQuantNodesPass,
         QuantizedOpFusionPass,
         QuantizedLinearFusionPass,
diff --git a/backends/cortex_m/passes/passes_utils.py b/backends/cortex_m/passes/passes_utils.py
index 7155f997bf4..b045005d34d 100644
--- a/backends/cortex_m/passes/passes_utils.py
+++ b/backends/cortex_m/passes/passes_utils.py
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -12,6 +13,9 @@
 from torch.fx import Node

+# Left-shift value used by CMSIS-NN for int8 operations
+SHIFT_INT8 = 20
+

 def dequantize_per_tensor_cmsis(
     qtensor: torch.Tensor, zero_point: int, multiplier: int, shift: int
@@ -41,6 +45,21 @@ def quantize_per_tensor_cmsis(
     return quantized.clamp(qmin, qmax).to(torch.int8)


+def requantize_cmsis(
+    tensor: torch.Tensor,
+    multiplier: int,
+    shift: int,
+) -> torch.Tensor:
+    """
+    Simulate CMSIS-NN fixed-point requantization with double rounding:
+    result = round(tensor * multiplier / 2^(31 - shift))
+    where multiplier is a Q31 fixed-point value.
+    """
+    multiplied = torch.round(tensor.to(torch.int64) * multiplier)
+    shifted = torch.round(multiplied / (2 ** (31 - shift)))
+    return shifted.to(torch.int32)
+
+
 def extract_scalar_value(node_arg) -> float:
     """
     Extract scalar value from various PyTorch scalar representations.
@@ -83,13 +102,14 @@ def is_qualified_int8_node(args) -> bool:


 def quantize_multiplier_aot(scale: float) -> tuple[int, int]:
     if scale == 0.0:
         return 0, 0
-    mantissa, exponent = math.frexp(scale)
-    shift = -exponent
+    mantissa, shift = math.frexp(scale)
     q_fixed = int(round(mantissa * (1 << 31)))
     if q_fixed == (1 << 31):
         q_fixed //= 2
-        shift -= 1
+        shift += 1
-    multiplier = max(-2147483648, min(2147483647, q_fixed))
+    multiplier = max(
+        torch.iinfo(torch.int32).min, min(torch.iinfo(torch.int32).max, q_fixed)
+    )
     return multiplier, shift
diff --git a/backends/cortex_m/passes/quantized_op_fusion_pass.py b/backends/cortex_m/passes/quantized_op_fusion_pass.py
index eebf6866d83..888155dcfd0 100644
--- a/backends/cortex_m/passes/quantized_op_fusion_pass.py
+++ b/backends/cortex_m/passes/quantized_op_fusion_pass.py
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -13,6 +14,7 @@
 from executorch.backends.cortex_m.passes.passes_utils import (
     extract_scalar_value,
     quantize_multiplier_aot,
+    SHIFT_INT8,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass
@@ -58,7 +60,16 @@ def _get_quant_targets(self) -> Set:

     def _is_supported_binary_op(self, node: torch.fx.Node) -> bool:
         """Check if node is a supported binary operation."""
-        return node.op == "call_function" and node.target in self.SUPPORTED_OPS_MAPPING
+        is_supported = (
+            node.op == "call_function" and node.target in self.SUPPORTED_OPS_MAPPING
+        )
+        if not is_supported:
+            return False
+
+        shape1 = node.args[0].meta["val"].shape
+        shape2 = node.args[1].meta["val"].shape
+        is_broadcast = shape1 != shape2
+        return not is_broadcast

     def _is_dequant_node(self, node: torch.fx.Node) -> bool:
         """Check if node is a dequantize operation."""
@@ -163,16 +174,18 @@ def _fuse_quantized_binary_patterns(
         zp2_val = int(extract_scalar_value(zero_point2))
         output_zp_val = int(extract_scalar_value(output_zero_point))

+        max_scale_2x = 2 * max(scale1_val, scale2_val)
         # AoT COMPUTATION: Calculate multipliers and shifts
+
         input1_mult, input1_shift = quantize_multiplier_aot(
-            scale1_val / output_scale_val
+            scale1_val / max_scale_2x
         )
         input2_mult, input2_shift = quantize_multiplier_aot(
-            scale2_val / output_scale_val
+            scale2_val / max_scale_2x
         )
-        output_mult, output_shift = quantize_multiplier_aot(
-            1.0
-        )  # Output multiplier is 1
+        output_mult, output_shift = quantize_multiplier_aot(
+            max_scale_2x / (output_scale_val * (1 << SHIFT_INT8))
+        )

         logger.info("AoT computed parameters:")
         logger.info(f"  Input1: mult={input1_mult}, shift={input1_shift}")
diff --git a/backends/cortex_m/test/ops/test_add.py b/backends/cortex_m/test/ops/test_add.py
index b7b0ffcbfbc..bd7de56c8df 100644
--- a/backends/cortex_m/test/ops/test_add.py
+++ b/backends/cortex_m/test/ops/test_add.py
@@ -110,6 +110,10 @@ class CortexMAlphaAdd(ModelAlpha):
         CortexMScalarAdd(),
         (1000.0, torch.ones(2, 2)),
     ),
+    "tensor_tensor": McuTestCase(
+        CortexMTensorAdd(),
+        (torch.rand(2, 2) * 10, torch.rand(2, 2)),
+    ),
     "broadcast_1": McuTestCase(
         CortexMTensorAdd(),
         (torch.ones(1), torch.ones(2, 2, 2, 2)),
     ),
@@ -136,15 +140,38 @@

 dialect_xfails = {
-    "self_scalar": ("'float' object has no attribute 'fake_mode'", AttributeError),
-    "self_rank_1": ("Output 0 does not match reference output", AssertionError),
-    "self_rank_2_pos": ("Output 0 does not match reference output", AssertionError),
-    "self_rank_3_neg": ("Output 0 does not match reference output", AssertionError),
-    "self_rank_4_small": ("Output 0 does not match reference output", AssertionError),
-    "self_rank_5": ("Output 0 does not match reference output", AssertionError),
-    "scalar_scalar": ("'float' object has no attribute 'fake_mode'", AttributeError),
-    "broadcast_3": ("Output 0 does not match reference output", AssertionError),
-    "alpha": ("Expecting kwargs for aten op IR to be empty", AssertionError),
+    "self_scalar": (
+        "'float' object has no attribute 'fake_mode' - scalar only ops not supported.",
+        AttributeError,
+    ),
+    "scalar_scalar": (
+        "'float' object has no attribute 'fake_mode' - scalar only ops not supported.",
+        AttributeError,
+    ),
+    "tensor_scalar": (
+        "Expected to find 'executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default' but did not find it - broadcasting not supported.",
+        RuntimeError,
+    ),
+    "scalar_tensor": (
+        "Expected to find 'executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default' but did not find it - broadcasting not supported.",
+        RuntimeError,
+    ),
+    "broadcast_1": (
+        "Expected to find 'executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default' but did not find it - broadcasting not supported.",
+        RuntimeError,
+    ),
+    "broadcast_2": (
+        "Expected to find 'executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default' but did not find it - broadcasting not supported.",
+        RuntimeError,
+    ),
+    "broadcast_3": (
+        "Expected to find 'executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default' but did not find it - broadcasting not supported.",
+        RuntimeError,
+    ),
+    "alpha": (
+        "Expecting kwargs for aten op IR to be empty - alpha arg not supported.",
+        AssertionError,
+    ),
 }


@@ -157,19 +184,38 @@ def test_dialect_add(test_case):

 implementation_xfails = {
-    "self_scalar": ("'float' object has no attribute 'fake_mode'", AttributeError),
-    "self_rank_1": ("Output 0 does not match reference output", AssertionError),
-    "self_rank_2_pos": ("Output 0 does not match reference output", AssertionError),
-    "self_rank_3_neg": ("Output 0 does not match reference output", AssertionError),
-    "self_rank_4_small": ("Output 0 does not match reference output", AssertionError),
-    "self_rank_5": ("Output 0 does not match reference output", AssertionError),
-    "scalar_scalar": ("'float' object has no attribute 'fake_mode'", AttributeError),
-    "tensor_scalar": ("Output 0 does not match reference output", AssertionError),
-    "scalar_tensor": ("Output 0 does not match reference output", AssertionError),
-    "broadcast_1": ("Output 0 does not match reference output", AssertionError),
-    "broadcast_2": ("Output 0 does not match reference output", AssertionError),
-    "broadcast_3": ("Output 0 does not match reference output", AssertionError),
-    "alpha": ("Expecting kwargs for aten op IR to be empty", AssertionError),
+    "self_scalar": (
+        "'float' object has no attribute 'fake_mode' - scalar only ops not supported.",
+        AttributeError,
+    ),
+    "scalar_scalar": (
+        "'float' object has no attribute 'fake_mode' - scalar only ops not supported.",
+        AttributeError,
+    ),
+    "tensor_scalar": (
+        "Missing operator: [2] aten::add.out - broadcasting not supported.",
+        RuntimeError,
+    ),
+    "scalar_tensor": (
+        "Missing operator: [2] aten::add.out - broadcasting not supported.",
+        RuntimeError,
+    ),
+    "broadcast_1": (
+        "Missing operator: [2] aten::add.out - broadcasting not supported.",
+        RuntimeError,
+    ),
+    "broadcast_2": (
+        "Missing operator: [2] aten::add.out - broadcasting not supported.",
+        RuntimeError,
+    ),
+    "broadcast_3": (
+        "Missing operator: [2] aten::add.out - broadcasting not supported.",
+        RuntimeError,
+    ),
+    "alpha": (
+        "Expecting kwargs for aten op IR to be empty - alpha arg not supported.",
+        AssertionError,
+    ),
 }
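
Note (not part of the patch): the snippet below is a minimal sketch of how the AoT-computed multiplier/shift pairs from QuantizedOpFusionPass and the SHIFT_INT8 = 20 input pre-shift are intended to compose at runtime, assuming the frexp convention of quantize_multiplier_aot above. The scales, zero points, and the local requantize helper are invented for illustration and only approximate the saturating fixed-point arithmetic of the CMSIS-NN kernel.

import math

import torch

SHIFT_INT8 = 20  # pre-shift applied to the zero-point-adjusted int8 inputs


def quantize_multiplier_aot(scale: float) -> tuple[int, int]:
    # Same frexp-based split as passes_utils.py: scale ~= (multiplier / 2**31) * 2**shift
    if scale == 0.0:
        return 0, 0
    mantissa, shift = math.frexp(scale)
    q_fixed = int(round(mantissa * (1 << 31)))
    if q_fixed == (1 << 31):
        q_fixed //= 2
        shift += 1
    return max(-(1 << 31), min((1 << 31) - 1, q_fixed)), shift


def requantize(x: torch.Tensor, multiplier: int, shift: int) -> torch.Tensor:
    # Fixed-point rescale: x * multiplier * 2**shift / 2**31, rounded to nearest
    prod = x.to(torch.int64) * multiplier  # exact in int64 for the ranges used here
    return torch.round(prod.double() / 2.0 ** (31 - shift)).to(torch.int32)


# Hypothetical per-tensor quantization parameters, chosen only for illustration
s1, zp1 = 0.02, -5
s2, zp2 = 0.05, 3
s_out, zp_out = 0.06, -1

# Mirrors QuantizedOpFusionPass: rescale both inputs to 2 * max(s1, s2) and fold the
# 2**SHIFT_INT8 pre-shift back out through the output multiplier
max_scale_2x = 2 * max(s1, s2)
m1, sh1 = quantize_multiplier_aot(s1 / max_scale_2x)
m2, sh2 = quantize_multiplier_aot(s2 / max_scale_2x)
m_out, sh_out = quantize_multiplier_aot(max_scale_2x / (s_out * (1 << SHIFT_INT8)))

a = torch.randint(-128, 128, (2, 2), dtype=torch.int8)
b = torch.randint(-128, 128, (2, 2), dtype=torch.int8)

# Mirrors the quantized_add reference path: offset, pre-shift, rescale, add, rescale, re-offset
acc = requantize((a.to(torch.int32) - zp1) << SHIFT_INT8, m1, sh1)
acc = acc + requantize((b.to(torch.int32) - zp2) << SHIFT_INT8, m2, sh2)
q_out = torch.clamp(requantize(acc, m_out, sh_out) + zp_out, -128, 127).to(torch.int8)

# Float reference; the fixed-point result should agree to within one quantization step
ref = (a.to(torch.float32) - zp1) * s1 + (b.to(torch.float32) - zp2) * s2
ref_q = torch.clamp(torch.round(ref / s_out) + zp_out, -128, 127).to(torch.int8)
print(q_out)
print(ref_q)

Because both inputs are rescaled to the common 2 * max(s1, s2) scale before the add, the int32 accumulator keeps maximal precision from the 20-bit pre-shift without overflowing, and the output multiplier removes that headroom again while converting to the output scale.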