
Commit 0e18b9f

Cortex_m backend: Fix add implementation (pytorch#15100)
- Call the CMSIS-NN kernel with the correct argument order and signs
- Change the Python implementation to reflect CMSIS-NN behaviour
- Fix scale calculations
- Remove broken broadcasting support
- Add a pass to lower scalar versions of ops
- Remove unused definitions/implementations in operators.py, operators.yaml and op_quantized_add.cpp

Note: arm_elementwise_add_s8 does not natively support broadcasting, so simply resizing the output tensor will not work. Enabling this efficiently is not straightforward, so these ops are not fused for now to avoid breaking graphs.

cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai

---------

Signed-off-by: Adrian Lundell <[email protected]>
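For context, the scale fix revolves around CMSIS-NN's fixed-point requantization, which encodes a floating-point scale as a Q31 multiplier plus a power-of-two shift. A minimal sketch of the identity the corrected code relies on (illustrative numbers; see the passes_utils.py changes below for the actual helpers):

```python
import math

scale = 0.0078125  # illustrative: 2**-7
mantissa, shift = math.frexp(scale)       # scale == mantissa * 2**shift
multiplier = round(mantissa * (1 << 31))  # Q31 encoding of the mantissa

# Requantizing x as round(x * multiplier / 2**(31 - shift)) reproduces x * scale:
x = 1000
assert round(x * multiplier / 2 ** (31 - shift)) == round(x * scale)
```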
1 parent c1c5a71 commit 0e18b9f

8 files changed (+145, -152 lines)

backends/cortex_m/ops/cortex_m_ops_common.h (7 additions, 0 deletions)

```diff
@@ -1,6 +1,7 @@
 /*
  * Copyright (c) Meta Platforms, Inc. and affiliates.
  * All rights reserved.
+ * Copyright 2025 Arm Limited and/or its affiliates.
  *
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
@@ -49,6 +50,12 @@ inline void validate_cmsis_nn_tensor_requirements(
       "Output dtype must be %hhd, got %hhd",
       expected_dtype,
       output.scalar_type());
+  ET_CHECK_MSG(
+      input1.sizes() == input2.sizes(),
+      "Input1 and Input2 must have the same sizes");
+  ET_CHECK_MSG(
+      output.sizes() == input1.sizes(),
+      "Output must have the same sizes as inputs");
 
   // Dim order consistency
   ET_CHECK_MSG(
```

backends/cortex_m/ops/op_quantized_add.cpp (6 additions, 39 deletions)

```diff
@@ -1,6 +1,7 @@
 /*
  * Copyright (c) Meta Platforms, Inc. and affiliates.
  * All rights reserved.
+ * Copyright 2025 Arm Limited and/or its affiliates.
  *
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
@@ -47,13 +48,6 @@ Tensor& quantized_add_out(
       output_shift,
       out);
 
-  // Broadcast if needed
-  auto result = resize_to_broadcast_target_size(input1_int8, input2_int8, out);
-  ET_CHECK_MSG(
-      (result == Error::Ok),
-      "Failed to resize output tensor. Status: [%d]",
-      result);
-
   ET_LOG(
       Info,
       "quantized_add_out: input1_int8.sizes() = %zu",
@@ -69,7 +63,7 @@ Tensor& quantized_add_out(
   int32_t output_mult = extractScalarToInt32(output_multiplier);
   int output_shift_val = extractScalarToInt(output_shift);
 
-  // Left shift to maximize precision (tune as needed)
+  // Left shift to maximize precision
   const int32_t left_shift = 20;
   const int32_t activation_min = std::numeric_limits<int8_t>::min();
   const int32_t activation_max = std::numeric_limits<int8_t>::max();
@@ -88,20 +82,20 @@ Tensor& quantized_add_out(
   arm_cmsis_nn_status status = arm_elementwise_add_s8(
       input1_int8.const_data_ptr<int8_t>(),
       input2_int8.const_data_ptr<int8_t>(),
-      static_cast<int32_t>(zp1),
+      -static_cast<int32_t>(zp1),
       input1_mult,
       input1_shift_val,
-      static_cast<int32_t>(zp2),
+      -static_cast<int32_t>(zp2),
       input2_mult,
       input2_shift_val,
       left_shift,
       out.mutable_data_ptr<int8_t>(),
       static_cast<int32_t>(out_zp),
       output_mult,
       output_shift_val,
-      static_cast<int32_t>(out.numel()),
       activation_min,
-      activation_max);
+      activation_max,
+      static_cast<int32_t>(out.numel()));
 
   if (status != ARM_CMSIS_NN_SUCCESS) {
     ET_LOG(
@@ -119,32 +113,5 @@ Tensor& quantized_add_out(
   return out;
 }
 
-// Stub Implementation: Non-out variant for compatibility (functional variant)
-// EXIR/ExecuTorch runs an out-variant pass that converts
-// .default operations to .out variants before memory planning.
-// In the pass we are calling quantized_add's default variant
-// but ExecuTorch's kernel dispatch mechanism will end up calling the out
-// variant. This stub is to make sure that compiler doesn't complain.
-Tensor quantized_add(
-    KernelRuntimeContext& context,
-    const Tensor& input1_int8,
-    const Scalar& input1_zero_point,
-    const Scalar& input1_multiplier,
-    const Scalar& input1_shift,
-    const Tensor& input2_int8,
-    const Scalar& input2_zero_point,
-    const Scalar& input2_multiplier,
-    const Scalar& input2_shift,
-    const Scalar& output_zero_point,
-    const Scalar& output_multiplier,
-    const Scalar& output_shift) {
-  ET_LOG(Info, "quantized_add: input1_int8.sizes() = %zu", input1_int8.sizes());
-
-  // Crash on Debug builds if invoked
-  assert(False);
-  // This is to make sure compiler doesn't complain.
-  return const_cast<Tensor&>(input1_int8);
-}
-
 } // namespace native
 } // namespace cortex_m
```

backends/cortex_m/ops/operators.py (17 additions, 76 deletions)

```diff
@@ -1,13 +1,14 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
 import torch
 from executorch.backends.cortex_m.passes.passes_utils import (
-    dequantize_per_tensor_cmsis,
-    quantize_per_tensor_cmsis,
+    requantize_cmsis,
+    SHIFT_INT8,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
 
@@ -111,52 +112,6 @@ def dequantize_per_tensor_impl(
     "Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift) -> Tensor"
 )
 
-
-@register_fake("cortex_m::quantized_add")
-def quantized_add_meta(
-    self: torch.Tensor,
-    self_zero_point: int,
-    self_multiplier: int,
-    self_shift: int,
-    other: torch.Tensor,
-    other_zero_point: int,
-    other_multiplier: int,
-    other_shift: int,
-    output_zero_point: int,
-    output_multiplier: int,
-    output_shift: int,
-) -> torch.Tensor:
-    broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape)
-    return torch.empty(broadcasted_shape, dtype=torch.int8, device=self.device)
-
-
-@impl(lib, "quantized_add", "CompositeExplicitAutograd")
-def quantized_add_impl(
-    self: torch.Tensor,
-    self_zero_point: int,
-    self_multiplier: int,
-    self_shift: int,
-    other: torch.Tensor,
-    other_zero_point: int,
-    other_multiplier: int,
-    other_shift: int,
-    output_zero_point: int,
-    output_multiplier: int,
-    output_shift: int,
-) -> torch.Tensor:
-    self_fp = dequantize_per_tensor_cmsis(
-        self, self_zero_point, self_multiplier, self_shift
-    )
-    other_fp = dequantize_per_tensor_cmsis(
-        other, other_zero_point, other_multiplier, other_shift
-    )
-    result_fp = self_fp + other_fp
-    result_quantized = quantize_per_tensor_cmsis(
-        result_fp, output_zero_point, output_multiplier, output_shift
-    )
-    return result_quantized
-
-
 # Define the operator schema with multipliers and shifts (11 args + out tensor)
 lib.define(
     "quantized_add.out("
@@ -167,9 +122,8 @@ def quantized_add_impl(
 )
 
 
-# Fake meta function for shape and dtype inference during compilation
-@register_fake("cortex_m::quantized_add.out")
-def quantized_add_out_meta(
+@register_fake("cortex_m::quantized_add")
+def quantized_add_meta(
     self: torch.Tensor,
     self_zero_point: int,
     self_multiplier: int,
@@ -181,19 +135,13 @@ def quantized_add_out_meta(
     output_zero_point: int,
     output_multiplier: int,
     output_shift: int,
-    out: torch.Tensor,
 ) -> torch.Tensor:
-    # Validate against correct broadcasted shape
-    expected_shape = torch.broadcast_shapes(self.shape, other.shape)
-    assert (
-        out.shape == expected_shape
-    ), f"Output shape {out.shape} must match broadcasted shape {expected_shape}"
-    return out
+    broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape)
+    return torch.empty(broadcasted_shape, dtype=torch.int8, device=self.device)
 
 
-# Actual implementation delegating to backend or custom kernel
-@impl(lib, "quantized_add.out", "CompositeExplicitAutograd")
-def quantized_add_out_impl(
+@impl(lib, "quantized_add", "CompositeExplicitAutograd")
+def quantized_add_impl(
     self: torch.Tensor,
     self_zero_point: int,
     self_multiplier: int,
@@ -205,24 +153,17 @@ def quantized_add_out_impl(
     output_zero_point: int,
     output_multiplier: int,
     output_shift: int,
-    *,
-    out: torch.Tensor,
 ) -> torch.Tensor:
-    self_fp = dequantize_per_tensor_cmsis(
-        self, self_zero_point, self_multiplier, self_shift
-    )
-    other_fp = dequantize_per_tensor_cmsis(
-        other, other_zero_point, other_multiplier, other_shift
-    )
-    result_fp = self_fp + other_fp
-    result_quantized = quantize_per_tensor_cmsis(
-        result_fp, output_zero_point, output_multiplier, output_shift
-    )
+    self_shifted = (self.to(torch.int32) - self_zero_point) << SHIFT_INT8
+    self_fp = requantize_cmsis(self_shifted, self_multiplier, self_shift)
 
-    # Write into the provided output tensor
-    out.copy_(result_quantized)
+    other_shifted = (other.to(torch.int32) - other_zero_point) << SHIFT_INT8
+    other_fp = requantize_cmsis(other_shifted, other_multiplier, other_shift)
 
-    return out
+    result_fp = self_fp + other_fp
+    result_quantized = requantize_cmsis(result_fp, output_multiplier, output_shift)
+    result = torch.clamp(result_quantized + output_zero_point, -128, 127).to(torch.int8)
+    return result
 
 
 # ===================================================================
```
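A quick way to sanity-check the new reference implementation is to call the registered op directly. The multiplier/shift pairs below are hand-derived and illustrative: the inputs use scale 1.0, and the output requantization additionally undoes the 2**20 left shift. This is a sketch, not a test from the commit:

```python
import torch
import executorch.backends.cortex_m.ops.operators  # noqa: F401  (registers the ops)

a = torch.tensor([10, -5, 100], dtype=torch.int8)
b = torch.tensor([20, 15, 27], dtype=torch.int8)

in_mult, in_shift = 1 << 30, 1      # frexp(1.0) -> (0.5, 1); round(0.5 * 2**31) = 2**30
out_mult, out_shift = 1 << 30, -19  # effective scale 2**-20 undoes SHIFT_INT8 = 20

out = torch.ops.cortex_m.quantized_add(
    a, 0, in_mult, in_shift,
    b, 0, in_mult, in_shift,
    0, out_mult, out_shift,
)
print(out)  # expected: tensor([ 30,  10, 127], dtype=torch.int8), saturated at 127
```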

backends/cortex_m/ops/operators.yaml (1 addition, 6 deletions)

```diff
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -16,12 +17,6 @@
     - arg_meta: null
       kernel_name: cortex_m::dequantize_per_tensor_out
 
-- func: cortex_m::quantized_add(Tensor self, Scalar self_zero_point, Scalar self_multiplier, Scalar self_shift, Tensor other, Scalar other_zero_point, Scalar other_multiplier, Scalar other_shift, Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift) -> Tensor
-  variants: function
-  kernels:
-    - arg_meta: null
-      kernel_name: cortex_m::quantized_add
-
 - func: cortex_m::quantized_add.out(Tensor self, Scalar self_zero_point, Scalar self_multiplier, Scalar self_shift, Tensor other, Scalar other_zero_point, Scalar other_multiplier, Scalar other_shift, Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:
```

backends/cortex_m/passes/cortex_m_pass_manager.py (4 additions, 0 deletions)

```diff
@@ -9,13 +9,17 @@
     QuantizedOpFusionPass,
     ReplaceQuantNodesPass,
 )
+from executorch.backends.transforms.replace_scalar_with_tensor import (
+    ReplaceScalarWithTensorArgPass,
+)
 from executorch.backends.xnnpack._passes import XNNPACKPassManager
 from executorch.exir.pass_base import ExportPass
 
 
 class CortexMPassManager(XNNPACKPassManager):
 
     pass_list: list[ExportPass] = [
+        ReplaceScalarWithTensorArgPass,
         ReplaceQuantNodesPass,
         QuantizedOpFusionPass,
         QuantizedLinearFusionPass,
```
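ReplaceScalarWithTensorArgPass runs first so that scalar-argument op variants are rewritten into tensor-argument variants before fusion; per the commit message, this is the pass added to lower scalar-version ops. A conceptual illustration of the rewrite's effect (not the pass implementation itself):

```python
import torch

x = torch.randn(4)

# add.Scalar carries the constant inline; the pass lowers it to add.Tensor
# with the scalar materialized as a tensor, so later passes match one form.
y_scalar = torch.ops.aten.add.Scalar(x, 2.0)
y_tensor = torch.ops.aten.add.Tensor(x, torch.tensor(2.0))
assert torch.allclose(y_scalar, y_tensor)
```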

backends/cortex_m/passes/passes_utils.py (24 additions, 4 deletions)

```diff
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -12,6 +13,9 @@
 
 from torch.fx import Node
 
+# Left-shift value used in CMSIS-NN for int8 operations
+SHIFT_INT8 = 20
+
 
 def dequantize_per_tensor_cmsis(
     qtensor: torch.Tensor, zero_point: int, multiplier: int, shift: int
@@ -41,6 +45,21 @@ def quantize_per_tensor_cmsis(
     return quantized.clamp(qmin, qmax).to(torch.int8)
 
 
+def requantize_cmsis(
+    tensor: torch.Tensor,
+    multiplier: int,
+    shift: int,
+) -> torch.Tensor:
+    """
+    Simulate CMSIS-NN fixed-point requantization:
+        result = round(tensor * multiplier / 2 ** (31 - shift))
+    with double rounding.
+    """
+    multiplied = torch.round(tensor.to(torch.int64) * multiplier)
+    shifted = torch.round(multiplied / (2 ** (31 - shift)))
+    return shifted.to(torch.int32)
+
+
 def extract_scalar_value(node_arg) -> float:
     """
     Extract scalar value from various PyTorch scalar representations.
@@ -83,13 +102,14 @@ def is_qualified_int8_node(args) -> bool:
 def quantize_multiplier_aot(scale: float) -> tuple[int, int]:
     if scale == 0.0:
         return 0, 0
-    mantissa, exponent = math.frexp(scale)
-    shift = -exponent
+    mantissa, shift = math.frexp(scale)
     q_fixed = int(round(mantissa * (1 << 31)))
     if q_fixed == (1 << 31):
         q_fixed //= 2
-        shift -= 1
-    multiplier = max(-2147483648, min(2147483647, q_fixed))
+        shift += 1
+    multiplier = max(
+        torch.iinfo(torch.int32).min, min(torch.iinfo(torch.int32).max, q_fixed)
+    )
     return multiplier, shift
```
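To see why the old decomposition was wrong: math.frexp returns (mantissa, exponent) with scale = mantissa * 2**exponent, and requantize_cmsis divides by 2**(31 - shift), so the shift passed along must be the exponent itself. Negating it, as the old code did, applies roughly the reciprocal scale. A quick check with illustrative numbers:

```python
import math

scale = 0.25
mantissa, exponent = math.frexp(scale)         # (0.5, -1): 0.5 * 2**-1 == 0.25
multiplier = int(round(mantissa * (1 << 31)))  # 2**30

x = 1024
good = x * multiplier / 2 ** (31 - exponent)   # shift = exponent -> 256.0 == x * scale
bad = x * multiplier / 2 ** (31 + exponent)    # old shift = -exponent -> 1024.0
assert good == x * scale and bad != x * scale
```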
