
Commit 8c09c9a

Cortex_m backend: Fix add implementation
- Call CMSIS-NN kernel with correct argument order and signs
- Change Python implementation to reflect CMSIS-NN behaviour
- Fix scale calculations
- Remove broken broadcasting support
- Add pass to lower scalar-version ops
- Remove unused definitions/implementations in operators.py, operators.yaml and op_quantized_add.cpp

Note: arm_elementwise_add_s8 does not natively support broadcasting, so simply resizing the output tensor will not work. Enabling this efficiently is not straightforward, so avoid fusing these ops for now to avoid breaking graphs.

Signed-off-by: Adrian Lundell <[email protected]>
Change-Id: Id76db13848f2ce67d7527f40d31c06db663af8fa
1 parent 57a7903 commit 8c09c9a
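
For context on the argument-order and sign changes: CMSIS-NN's arm_elementwise_add_s8 takes input offsets that are added to the raw int8 values, so the int8 zero points have to be passed negated (offset = -zero_point), while the output zero point is passed as-is; the element count (block_size) is the last parameter, after the activation limits. Roughly, the kernel computes per element

  a   = requant((q1 + off1) << left_shift, m1, s1)
  b   = requant((q2 + off2) << left_shift, m2, s2)
  out = clamp(requant(a + b, m_out, s_out) + off_out, act_min, act_max)

with requant(v, m, s) ≈ v * m * 2^s / 2^31. The updated Python reference in operators.py and the scale calculations in the fusion pass mirror this behaviour.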

7 files changed: +138 −152 lines changed


backends/cortex_m/ops/op_quantized_add.cpp

Lines changed: 6 additions & 39 deletions
@@ -1,6 +1,7 @@
 /*
  * Copyright (c) Meta Platforms, Inc. and affiliates.
  * All rights reserved.
+ * Copyright 2025 Arm Limited and/or its affiliates.
  *
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
@@ -47,13 +48,6 @@ Tensor& quantized_add_out(
       output_shift,
       out);
 
-  // Broadcast if needed
-  auto result = resize_to_broadcast_target_size(input1_int8, input2_int8, out);
-  ET_CHECK_MSG(
-      (result == Error::Ok),
-      "Failed to resize output tensor. Status: [%d]",
-      result);
-
   ET_LOG(
       Info,
       "quantized_add_out: input1_int8.sizes() = %zu",
@@ -69,7 +63,7 @@ Tensor& quantized_add_out(
   int32_t output_mult = extractScalarToInt32(output_multiplier);
   int output_shift_val = extractScalarToInt(output_shift);
 
-  // Left shift to maximize precision (tune as needed)
+  // Left shift to maximize precision
   const int32_t left_shift = 20;
   const int32_t activation_min = std::numeric_limits<int8_t>::min();
   const int32_t activation_max = std::numeric_limits<int8_t>::max();
@@ -88,20 +82,20 @@ Tensor& quantized_add_out(
   arm_cmsis_nn_status status = arm_elementwise_add_s8(
       input1_int8.const_data_ptr<int8_t>(),
       input2_int8.const_data_ptr<int8_t>(),
-      static_cast<int32_t>(zp1),
+      -static_cast<int32_t>(zp1),
       input1_mult,
       input1_shift_val,
-      static_cast<int32_t>(zp2),
+      -static_cast<int32_t>(zp2),
       input2_mult,
       input2_shift_val,
       left_shift,
       out.mutable_data_ptr<int8_t>(),
       static_cast<int32_t>(out_zp),
       output_mult,
       output_shift_val,
-      static_cast<int32_t>(out.numel()),
       activation_min,
-      activation_max);
+      activation_max,
+      static_cast<int32_t>(out.numel()));
 
   if (status != ARM_CMSIS_NN_SUCCESS) {
     ET_LOG(
@@ -119,32 +113,5 @@ Tensor& quantized_add_out(
   return out;
 }
 
-// Stub Implementation: Non-out variant for compatibility (functional variant)
-// EXIR/ExecuTorch runs an out-variant pass that converts
-// .default operations to .out variants before memory planning.
-// In the pass we are calling quantized_add's default variant
-// but ExecuTorch's kernel dispatch mechanism will end up calling the out
-// variant. This stub is to make sure that compiler doesn't complain.
-Tensor quantized_add(
-    KernelRuntimeContext& context,
-    const Tensor& input1_int8,
-    const Scalar& input1_zero_point,
-    const Scalar& input1_multiplier,
-    const Scalar& input1_shift,
-    const Tensor& input2_int8,
-    const Scalar& input2_zero_point,
-    const Scalar& input2_multiplier,
-    const Scalar& input2_shift,
-    const Scalar& output_zero_point,
-    const Scalar& output_multiplier,
-    const Scalar& output_shift) {
-  ET_LOG(Info, "quantized_add: input1_int8.sizes() = %zu", input1_int8.sizes());
-
-  // Crash on Debug builds if invoked
-  assert(False);
-  // This is to make sure compiler doesn't complain.
-  return const_cast<Tensor&>(input1_int8);
-}
-
 } // namespace native
 } // namespace cortex_m

backends/cortex_m/ops/operators.py

Lines changed: 17 additions & 76 deletions
@@ -1,13 +1,14 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
 import torch
 from executorch.backends.cortex_m.passes.passes_utils import (
-    dequantize_per_tensor_cmsis,
-    quantize_per_tensor_cmsis,
+    requantize_cmsis,
+    SHIFT_INT8,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
 
@@ -111,52 +112,6 @@ def dequantize_per_tensor_impl(
     "Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift) -> Tensor"
 )
 
-
-@register_fake("cortex_m::quantized_add")
-def quantized_add_meta(
-    self: torch.Tensor,
-    self_zero_point: int,
-    self_multiplier: int,
-    self_shift: int,
-    other: torch.Tensor,
-    other_zero_point: int,
-    other_multiplier: int,
-    other_shift: int,
-    output_zero_point: int,
-    output_multiplier: int,
-    output_shift: int,
-) -> torch.Tensor:
-    broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape)
-    return torch.empty(broadcasted_shape, dtype=torch.int8, device=self.device)
-
-
-@impl(lib, "quantized_add", "CompositeExplicitAutograd")
-def quantized_add_impl(
-    self: torch.Tensor,
-    self_zero_point: int,
-    self_multiplier: int,
-    self_shift: int,
-    other: torch.Tensor,
-    other_zero_point: int,
-    other_multiplier: int,
-    other_shift: int,
-    output_zero_point: int,
-    output_multiplier: int,
-    output_shift: int,
-) -> torch.Tensor:
-    self_fp = dequantize_per_tensor_cmsis(
-        self, self_zero_point, self_multiplier, self_shift
-    )
-    other_fp = dequantize_per_tensor_cmsis(
-        other, other_zero_point, other_multiplier, other_shift
-    )
-    result_fp = self_fp + other_fp
-    result_quantized = quantize_per_tensor_cmsis(
-        result_fp, output_zero_point, output_multiplier, output_shift
-    )
-    return result_quantized
-
-
 # Define the operator schema with multipliers and shifts (11 args + out tensor)
 lib.define(
     "quantized_add.out("
@@ -167,9 +122,8 @@ def quantized_add_impl(
 )
 
 
-# Fake meta function for shape and dtype inference during compilation
-@register_fake("cortex_m::quantized_add.out")
-def quantized_add_out_meta(
+@register_fake("cortex_m::quantized_add")
+def quantized_add_meta(
     self: torch.Tensor,
     self_zero_point: int,
     self_multiplier: int,
@@ -181,19 +135,13 @@ def quantized_add_out_meta(
     output_zero_point: int,
     output_multiplier: int,
     output_shift: int,
-    out: torch.Tensor,
 ) -> torch.Tensor:
-    # Validate against correct broadcasted shape
-    expected_shape = torch.broadcast_shapes(self.shape, other.shape)
-    assert (
-        out.shape == expected_shape
-    ), f"Output shape {out.shape} must match broadcasted shape {expected_shape}"
-    return out
+    broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape)
+    return torch.empty(broadcasted_shape, dtype=torch.int8, device=self.device)
 
 
-# Actual implementation delegating to backend or custom kernel
-@impl(lib, "quantized_add.out", "CompositeExplicitAutograd")
-def quantized_add_out_impl(
+@impl(lib, "quantized_add", "CompositeExplicitAutograd")
+def quantized_add_impl(
     self: torch.Tensor,
     self_zero_point: int,
     self_multiplier: int,
@@ -205,24 +153,17 @@ def quantized_add_out_impl(
     output_zero_point: int,
     output_multiplier: int,
     output_shift: int,
-    *,
-    out: torch.Tensor,
 ) -> torch.Tensor:
-    self_fp = dequantize_per_tensor_cmsis(
-        self, self_zero_point, self_multiplier, self_shift
-    )
-    other_fp = dequantize_per_tensor_cmsis(
-        other, other_zero_point, other_multiplier, other_shift
-    )
-    result_fp = self_fp + other_fp
-    result_quantized = quantize_per_tensor_cmsis(
-        result_fp, output_zero_point, output_multiplier, output_shift
-    )
+    self_shifted = (self.to(torch.int32) - self_zero_point) << SHIFT_INT8
+    self_fp = requantize_cmsis(self_shifted, self_multiplier, self_shift)
 
-    # Write into the provided output tensor
-    out.copy_(result_quantized)
+    other_shifted = (other.to(torch.int32) - other_zero_point) << SHIFT_INT8
+    other_fp = requantize_cmsis(other_shifted, other_multiplier, other_shift)
 
-    return out
+    result_fp = self_fp + other_fp
+    result_quantized = requantize_cmsis(result_fp, output_multiplier, output_shift)
+    result = torch.clamp(result_quantized + output_zero_point, -128, 127).to(torch.int8)
+    return result
 
 
 # ===================================================================

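As a sanity check of the composite behaviour above, the new implementation can be compared against a plain float reference. The following standalone sketch re-creates the helpers locally (it does not import the cortex_m ops) and uses made-up scales and zero points; agreement to within one least-significant bit is expected:

import math
import torch

SHIFT_INT8 = 20  # left shift used by the kernel and the composite impl

def quantize_multiplier(scale: float) -> tuple[int, int]:
    # Same Q31 mantissa/exponent split as quantize_multiplier_aot
    mantissa, shift = math.frexp(scale)
    q_fixed = int(round(mantissa * (1 << 31)))
    if q_fixed == (1 << 31):
        q_fixed //= 2
        shift += 1
    return q_fixed, shift

def requantize(t: torch.Tensor, mult: int, shift: int) -> torch.Tensor:
    # Models requantize_cmsis: t * mult * 2**shift / 2**31, rounded
    return torch.round(t.to(torch.float64) * mult / 2 ** (31 - shift)).to(torch.int32)

# Made-up quantization parameters
s1, zp1 = 0.02, 5
s2, zp2 = 0.03, -3
so, zpo = 0.05, 10

max_scale_2x = 2 * max(s1, s2)
m1, sh1 = quantize_multiplier(s1 / max_scale_2x)
m2, sh2 = quantize_multiplier(s2 / max_scale_2x)
mo, sho = quantize_multiplier(max_scale_2x / (so * (1 << SHIFT_INT8)))

q1 = torch.tensor([-20, 0, 17], dtype=torch.int8)
q2 = torch.tensor([35, -8, 2], dtype=torch.int8)

# Same steps as quantized_add_impl
a = requantize((q1.to(torch.int32) - zp1) << SHIFT_INT8, m1, sh1)
b = requantize((q2.to(torch.int32) - zp2) << SHIFT_INT8, m2, sh2)
q_add = torch.clamp(requantize(a + b, mo, sho) + zpo, -128, 127).to(torch.int8)

# Float reference: dequantize, add, quantize
f = (q1.float() - zp1) * s1 + (q2.float() - zp2) * s2
q_ref = torch.clamp(torch.round(f / so) + zpo, -128, 127).to(torch.int8)

print(q_add.tolist(), q_ref.tolist())
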
backends/cortex_m/ops/operators.yaml

Lines changed: 1 addition & 6 deletions
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -16,12 +17,6 @@
   - arg_meta: null
     kernel_name: cortex_m::dequantize_per_tensor_out
 
-- func: cortex_m::quantized_add(Tensor self, Scalar self_zero_point, Scalar self_multiplier, Scalar self_shift, Tensor other, Scalar other_zero_point, Scalar other_multiplier, Scalar other_shift, Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift) -> Tensor
-  variants: function
-  kernels:
-    - arg_meta: null
-      kernel_name: cortex_m::quantized_add
-
 - func: cortex_m::quantized_add.out(Tensor self, Scalar self_zero_point, Scalar self_multiplier, Scalar self_shift, Tensor other, Scalar other_zero_point, Scalar other_multiplier, Scalar other_shift, Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:

backends/cortex_m/passes/cortex_m_pass_manager.py

Lines changed: 4 additions & 0 deletions
@@ -9,13 +9,17 @@
     QuantizedOpFusionPass,
     ReplaceQuantNodesPass,
 )
+from executorch.backends.transforms.replace_scalar_with_tensor import (
+    ReplaceScalarWithTensorArgPass,
+)
 from executorch.backends.xnnpack._passes import XNNPACKPassManager
 from executorch.exir.pass_base import ExportPass
 
 
 class CortexMPassManager(XNNPACKPassManager):
 
     pass_list: list[ExportPass] = [
+        ReplaceScalarWithTensorArgPass,
         ReplaceQuantNodesPass,
         QuantizedOpFusionPass,
         QuantizedLinearFusionPass,

backends/cortex_m/passes/passes_utils.py

Lines changed: 24 additions & 4 deletions
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -12,6 +13,9 @@
 
 from torch.fx import Node
 
+# L-shift value used in CMSIS-NN for int8 operations
+SHIFT_INT8 = 20
+
 
 def dequantize_per_tensor_cmsis(
     qtensor: torch.Tensor, zero_point: int, multiplier: int, shift: int
@@ -41,6 +45,21 @@ def quantize_per_tensor_cmsis(
     return quantized.clamp(qmin, qmax).to(torch.int8)
 
 
+def requantize_cmsis(
+    tensor: torch.Tensor,
+    multiplier: int,
+    shift: int,
+) -> torch.Tensor:
+    """
+    Simulate CMSIS-NN fixed-point requantization:
+    result = round(tensor * multiplier / (2 ^ shift))
+    with double rounding
+    """
+    multiplied = torch.round(tensor.to(torch.int64) * multiplier)
+    shifted = torch.round(multiplied / (2 ** (31 - shift)))
+    return shifted.to(torch.int32)
+
+
 def extract_scalar_value(node_arg) -> float:
     """
     Extract scalar value from various PyTorch scalar representations.
@@ -83,13 +102,14 @@ def is_qualified_int8_node(args) -> bool:
 def quantize_multiplier_aot(scale: float) -> tuple[int, int]:
     if scale == 0.0:
         return 0, 0
-    mantissa, exponent = math.frexp(scale)
-    shift = -exponent
+    mantissa, shift = math.frexp(scale)
     q_fixed = int(round(mantissa * (1 << 31)))
    if q_fixed == (1 << 31):
         q_fixed //= 2
-        shift -= 1
-    multiplier = max(-2147483648, min(2147483647, q_fixed))
+        shift += 1
+    multiplier = max(
+        torch.iinfo(torch.int32).min, min(torch.iinfo(torch.int32).max, q_fixed)
+    )
     return multiplier, shift
 
 
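The frexp change aligns the (multiplier, shift) pair with the convention requantize_cmsis applies (multiply by multiplier, then divide by 2**(31 - shift)). A small worked example, assuming scale = 0.3:

import math

mantissa, exponent = math.frexp(0.3)        # 0.6, -1  (0.3 == 0.6 * 2**-1)
multiplier = round(mantissa * (1 << 31))    # 1288490189, the Q31 encoding of 0.6

# New convention: shift = exponent, so the pair reconstructs the original scale
print(multiplier / (1 << 31) * 2.0 ** exponent)   # ~0.3

# The previous shift = -exponent would reconstruct 0.6 * 2**1 = 1.2 under the
# same formula instead of 0.3.
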
backends/cortex_m/passes/quantized_op_fusion_pass.py

Lines changed: 18 additions & 5 deletions
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -13,6 +14,7 @@
 from executorch.backends.cortex_m.passes.passes_utils import (
     extract_scalar_value,
     quantize_multiplier_aot,
+    SHIFT_INT8,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass
@@ -58,7 +60,16 @@ def _get_quant_targets(self) -> Set:
 
     def _is_supported_binary_op(self, node: torch.fx.Node) -> bool:
         """Check if node is a supported binary operation."""
-        return node.op == "call_function" and node.target in self.SUPPORTED_OPS_MAPPING
+        is_supported = (
+            node.op == "call_function" and node.target in self.SUPPORTED_OPS_MAPPING
+        )
+        if not is_supported:
+            return False
+
+        shape1 = node.args[0].meta["val"].shape
+        shape2 = node.args[1].meta["val"].shape
+        is_broadcast = shape1 != shape2
+        return not is_broadcast
 
     def _is_dequant_node(self, node: torch.fx.Node) -> bool:
         """Check if node is a dequantize operation."""
@@ -163,16 +174,18 @@ def _fuse_quantized_binary_patterns(
         zp2_val = int(extract_scalar_value(zero_point2))
         output_zp_val = int(extract_scalar_value(output_zero_point))
 
+        max_scale_2x = 2 * max(scale1_val, scale2_val)
         # AoT COMPUTATION: Calculate multipliers and shifts
+
         input1_mult, input1_shift = quantize_multiplier_aot(
-            scale1_val / output_scale_val
+            scale1_val / max_scale_2x
         )
         input2_mult, input2_shift = quantize_multiplier_aot(
-            scale2_val / output_scale_val
+            scale2_val / max_scale_2x
         )
         output_mult, output_shift = quantize_multiplier_aot(
-            1.0
-        ) # Output multiplier is 1
+            max_scale_2x / (output_scale_val * (1 << SHIFT_INT8))
+        )
 
         logger.info("AoT computed parameters:")
         logger.info(f"  Input1: mult={input1_mult}, shift={input1_shift}")

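Taken together, the new AoT parameters make the fixed-point pipeline reproduce the float addition. Writing s1, s2, s_out for the scales and M = max_scale_2x = 2 * max(s1, s2), the fused op effectively computes, per element,

  acc   = ((q1 - zp1) << 20) * (s1 / M) + ((q2 - zp2) << 20) * (s2 / M)
  q_out = acc * M / (s_out * 2^20) + zp_out
        = ((q1 - zp1) * s1 + (q2 - zp2) * s2) / s_out + zp_out

so the 2^20 left shift and the common factor M cancel and only rounding error remains. Dividing both input scales by the same M (rather than by the output scale with a unit output multiplier, as before) also keeps each input multiplier at or below 0.5, so the rescaled int32 terms and their sum stay well within int32 range.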