Commit 1348f0c

Cortex_m backend: Add mul op

Signed-off-by: Adrian Lundell <[email protected]>
Change-Id: Ic116e5294d9362f3a43655629d2a3c0f338a2fd5

1 parent 80c9040 commit 1348f0c

File tree: 11 files changed, +233 −26 lines changed

backends/cortex_m/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -57,6 +57,7 @@ set(_cortex_m_kernels__srcs
   ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_dequantize_per_tensor.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_add.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_linear.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_mul.cpp
 )

 # Generate C++ bindings to register kernels into Executorch
backends/cortex_m/ops/op_quantized_mul.cpp

Lines changed: 87 additions & 0 deletions

@@ -0,0 +1,87 @@
+/*
+ * Copyright 2025 Arm Limited and/or its affiliates.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "cortex_m_ops_common.h"
+
+// Include CMSIS-NN headers with C linkage
+extern "C" {
+#include "arm_nnfunctions.h"
+}
+
+namespace cortex_m {
+namespace native {
+namespace {
+
+constexpr int32_t kInt8ActivationMin = std::numeric_limits<int8_t>::min();
+constexpr int32_t kInt8ActivationMax = std::numeric_limits<int8_t>::max();
+
+} // namespace
+
+using KernelRuntimeContext = torch::executor::KernelRuntimeContext;
+
+Tensor& quantized_mul_out(
+    KernelRuntimeContext& context,
+    const Tensor& input1_int8,
+    const Scalar& input1_zero_point,
+    const Tensor& input2_int8,
+    const Scalar& input2_zero_point,
+    const Scalar& output_zero_point,
+    const Scalar& output_multiplier,
+    const Scalar& output_shift,
+    Tensor& out) {
+  // Validate tensor types and quantization parameters
+  validate_cmsis_nn_tensor_requirements(input1_int8, input2_int8, out);
+
+  const Scalar kIdentityMultiplier(/*value=*/1);
+  const Scalar kZeroShift(/*value=*/0);
+  validate_quantization_params(
+      input1_zero_point,
+      kIdentityMultiplier,
+      kZeroShift,
+      input2_zero_point,
+      kIdentityMultiplier,
+      kZeroShift,
+      output_zero_point,
+      output_multiplier,
+      output_shift,
+      out);
+
+  // Extract quantization parameters
+  const int32_t zp1 = extractScalarToInt32(input1_zero_point);
+  const int32_t zp2 = extractScalarToInt32(input2_zero_point);
+  const int32_t out_zp = extractScalarToInt32(output_zero_point);
+  const int32_t output_mult = extractScalarToInt32(output_multiplier);
+  const int32_t output_shift_val = extractScalarToInt(output_shift);
+
+  // Call CMSIS-NN elementwise multiply kernel
+  arm_cmsis_nn_status status = arm_elementwise_mul_s8(
+      input1_int8.const_data_ptr<int8_t>(),
+      input2_int8.const_data_ptr<int8_t>(),
+      -static_cast<int32_t>(zp1),
+      -static_cast<int32_t>(zp2),
+      out.mutable_data_ptr<int8_t>(),
+      static_cast<int32_t>(out_zp),
+      output_mult,
+      output_shift_val,
+      kInt8ActivationMin,
+      kInt8ActivationMax,
+      static_cast<int32_t>(out.numel()));
+
+  if (status != ARM_CMSIS_NN_SUCCESS) {
+    ET_LOG(
+        Error,
+        "quantized_mul_out: arm_elementwise_mul_s8 failed with status [%d]",
+        status);
+    context.fail(Error::Internal);
+    return out;
+  }
+
+  return out;
+}

+} // namespace native
+} // namespace cortex_m
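Note that the zero points are passed negated: arm_elementwise_mul_s8 adds its offset arguments to the raw int8 inputs before multiplying. A minimal Python sketch of the per-element arithmetic (a hypothetical helper; it rounds in floating point rather than reproducing CMSIS-NN's bit-exact fixed-point path, which the requantize_cmsis change below models):

import torch

def mul_s8_reference(x1, x2, offset1, offset2, out_zp, mult, shift):
    # offset1/offset2 are the NEGATED input zero points, as in the call
    # above; mult is a Q31 multiplier and a positive shift means left shift.
    acc = (x1.to(torch.int64) + offset1) * (x2.to(torch.int64) + offset2)
    rescaled = torch.round(acc * mult * (2.0 ** shift) / 2 ** 31)
    return torch.clamp(rescaled.to(torch.int64) + out_zp, -128, 127).to(torch.int8)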

backends/cortex_m/ops/operators.py

Lines changed: 54 additions & 0 deletions

@@ -168,6 +168,60 @@ def quantized_add_impl(
     return result


+# ===================================================================
+# QUANTIZED MUL OPERATION DEFINITION
+# ===================================================================
+lib.define(
+    "quantized_mul("
+    "Tensor self, Scalar self_zero_point, "
+    "Tensor other, Scalar other_zero_point, "
+    "Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift) -> Tensor"
+)
+lib.define(
+    "quantized_mul.out("
+    "Tensor self, Scalar self_zero_point, "
+    "Tensor other, Scalar other_zero_point, "
+    "Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift, "
+    "*, Tensor(a!) out) -> Tensor(a!)"
+)
+
+
+@register_fake("cortex_m::quantized_mul")
+def quantized_mul_meta(
+    self: torch.Tensor,
+    self_zero_point: int,
+    other: torch.Tensor,
+    other_zero_point: int,
+    output_zero_point: int,
+    output_multiplier: int,
+    output_shift: int,
+) -> torch.Tensor:
+    # Broadcast to output shape
+    broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape)
+    return torch.empty(broadcasted_shape, dtype=torch.int8, device=self.device)
+
+
+@impl(lib, "quantized_mul", "CompositeExplicitAutograd")
+def quantized_mul_impl(
+    self: torch.Tensor,
+    self_zero_point: int,
+    other: torch.Tensor,
+    other_zero_point: int,
+    output_zero_point: int,
+    output_multiplier: int,
+    output_shift: int,
+) -> torch.Tensor:
+    # CMSIS-NN kernel multiplies raw int8 tensors (after zero-point offset) and
+    # only uses the output multiplier/shift for rescaling. Mirror that here to
+    # keep the composite implementation numerically aligned with the backend.
+    self_int = self.to(torch.int32) - self_zero_point
+    other_int = other.to(torch.int32) - other_zero_point
+    result_fp = self_int * other_int
+    result_quantized = requantize_cmsis(result_fp, output_multiplier, output_shift)
+    result = torch.clamp(result_quantized + output_zero_point, -128, 127).to(torch.int8)
+    return result
+
+
 # ===================================================================
 # QUANTIZED LINEAR OPERATION DEFINITION
 # ===================================================================
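A quick smoke test of the composite path (a sketch: it assumes this module is importable as executorch.backends.cortex_m.ops.operators so the library above is registered, and uses multiplier 1 << 30, i.e. 0.5 in Q31, with zero shift and zero-valued zero points):

import torch
import executorch.backends.cortex_m.ops.operators  # noqa: F401  (registers cortex_m ops)

a = torch.tensor([10, -20], dtype=torch.int8)
b = torch.tensor([3, 4], dtype=torch.int8)
# (10 * 3) * 0.5 = 15 and (-20 * 4) * 0.5 = -40, so:
out = torch.ops.cortex_m.quantized_mul(a, 0, b, 0, 0, 1 << 30, 0)
print(out)  # expected: tensor([ 15, -40], dtype=torch.int8)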

backends/cortex_m/ops/operators.yaml

Lines changed: 7 additions & 1 deletion

@@ -23,8 +23,14 @@
     - arg_meta: null
       kernel_name: cortex_m::quantized_add_out

+- func: cortex_m::quantized_mul.out(Tensor self, Scalar self_zero_point, Tensor other, Scalar other_zero_point, Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: cortex_m::quantized_mul_out
+
 - func: cortex_m::quantized_linear.out(Tensor input, Tensor weights, Tensor? bias, Tensor? kernel_sum, Scalar input_offset, Scalar filter_offset, Scalar output_offset, int[] requantize_multipliers, int[] requantize_shifts, Scalar activation_max, Scalar activation_min, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:
     - arg_meta: null
-      kernel_name: cortex_m::quantized_linear_out
+      kernel_name: cortex_m::quantized_linear_out

backends/cortex_m/passes/cortex_m_pass_manager.py

Lines changed: 0 additions & 2 deletions

@@ -5,7 +5,6 @@


 from executorch.backends.arm._passes import (
-    DecorateFp32toInt32CastingPass,
     FoldAndAnnotateQParamsPass,
     ScalarsToAttributePass,
 )
@@ -29,7 +28,6 @@ class CortexMPassManager(XNNPACKPassManager):
         ReplaceQuantNodesPass,
         QuantizedOpFusionPass,
         QuantizedLinearFusionPass,
-        DecorateFp32toInt32CastingPass,
     ]

     pass_list_transform_for_annotation: list[ExportPass] = [

backends/cortex_m/passes/passes_utils.py

Lines changed: 26 additions & 8 deletions

@@ -50,14 +50,32 @@ def requantize_cmsis(
     multiplier: int,
     shift: int,
 ) -> torch.Tensor:
-    """
-    Simulate CMSIS-NN fixed-point requantization:
-    result = round(tensor * multiplier / (2 ^ shift))
-    with double rounding
-    """
-    multiplied = torch.round(tensor.to(torch.int64) * multiplier)
-    shifted = torch.round(multiplied / (2 ** (31 - shift)))
-    return shifted.to(torch.int32)
+    """Simulate CMSIS-NN's arm_nn_requantize helper."""
+
+    tensor_64 = tensor.to(torch.int64)
+    left_shift = max(shift, 0)
+    right_shift = max(-shift, 0)
+
+    # Equivalent to val * (1 << LEFT_SHIFT(shift))
+    value = tensor_64 << left_shift
+
+    # arm_nn_doubling_high_mult_no_sat(value, multiplier)
+    product = value * int(multiplier)
+    product = product + (1 << 30)
+    result = product >> 31
+
+    if right_shift:
+        remainder_mask = (1 << right_shift) - 1
+        remainder = torch.bitwise_and(result, remainder_mask)
+        result = result >> right_shift
+        threshold = remainder_mask >> 1
+        threshold_tensor = torch.full_like(result, threshold, dtype=torch.int64)
+        threshold_tensor = torch.where(
+            result < 0, threshold_tensor + 1, threshold_tensor
+        )
+        result = result + torch.where(remainder > threshold_tensor, 1, 0)
+
+    return result.to(torch.int32)


 def extract_scalar_value(node_arg) -> float:
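A worked example of the new rounding behaviour (a sketch, assuming the import path below): multiplier 1 << 30 encodes 0.5 in Q31, and a negative shift adds a rounding right shift on top of the Q31 scaling.

import torch
from executorch.backends.cortex_m.passes.passes_utils import requantize_cmsis

acc = torch.tensor([3000, -3000], dtype=torch.int64)
print(requantize_cmsis(acc, multiplier=1 << 30, shift=0))   # tensor([ 1500, -1500], dtype=torch.int32)
print(requantize_cmsis(acc, multiplier=1 << 30, shift=-1))  # tensor([ 750, -750], dtype=torch.int32)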

backends/cortex_m/passes/quantized_op_fusion_pass.py

Lines changed: 27 additions & 0 deletions

@@ -64,6 +64,31 @@ def _get_add_replacement(self, args, meta):

         return exir_ops.edge.cortex_m.quantized_add.default, args

+    def _get_mul_replacement(self, args, meta):
+
+        # Extract values
+        scale1 = meta["input_qparams"][0].scale
+        zero_point1 = meta["input_qparams"][0].zp
+        scale2 = meta["input_qparams"][1].scale
+        zero_point2 = meta["input_qparams"][1].zp
+        output_scale = meta["output_qparams"][0].scale
+        output_zero_point = meta["output_qparams"][0].zp
+
+        scale_factor = (scale1 * scale2) / output_scale
+        output_mult, output_shift = quantize_multiplier_aot(scale_factor)
+
+        args = (
+            args[0],
+            zero_point1,
+            args[1],
+            zero_point2,
+            output_zero_point,
+            output_mult,
+            output_shift,
+        )
+
+        return exir_ops.edge.cortex_m.quantized_mul.default, args
+
     def call_operator(
         self,
         op: EdgeOpOverload,
@@ -80,6 +105,8 @@ def call_operator(
         match op:
             case exir_ops.edge.aten.add.Tensor:
                 op, args = self._get_add_replacement(args, meta)
+            case exir_ops.edge.aten.mul.Tensor:
+                op, args = self._get_mul_replacement(args, meta)
             case _:
                 pass
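Here quantize_multiplier_aot factors the float scale_factor into the Q31 multiplier/shift pair consumed by requantize_cmsis and the CMSIS-NN kernel. A sketch of the standard decomposition, in the spirit of TFLite's QuantizeMultiplier (the backend's actual helper may differ in edge-case handling):

import math

def quantize_multiplier_sketch(scale: float) -> tuple[int, int]:
    # Hypothetical stand-in for quantize_multiplier_aot: write
    # scale = m * 2**shift with m in [0.5, 1), then encode m in Q31.
    if scale == 0.0:
        return 0, 0
    m, shift = math.frexp(scale)
    multiplier = round(m * (1 << 31))
    if multiplier == (1 << 31):  # rounding pushed m up to 1.0
        multiplier //= 2
        shift += 1
    return multiplier, shift  # positive shift = left shift

# e.g. quantize_multiplier_sketch(0.5) == (1 << 30, 0)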

backends/cortex_m/quantizer/operator_configs.py

Lines changed: 1 addition & 0 deletions

@@ -17,6 +17,7 @@
 # ----------------- OPERATOR PATTERN PRESETS -----------------
 BINARY_OP_PATTERNS = [
     [torch.ops.aten.add.Tensor],
+    [torch.ops.aten.mul.Tensor],
 ]

 LINEAR_OP_PATTERNS = [

backends/cortex_m/quantizer/quantizer.py

Lines changed: 2 additions & 3 deletions

@@ -6,13 +6,12 @@

 from typing import Callable, List, Optional

-import torch
-
 from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor

 from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
 from executorch.backends.cortex_m.passes.cortex_m_pass_manager import CortexMPassManager
 from executorch.backends.cortex_m.quantizer.operator_configs import (
+    BINARY_OP_PATTERNS,
     INT8_BINARY_OPS_OPERATOR_CONFIG,
     INT8_LINEAR_OPERATOR_CONFIG,
 )
@@ -37,7 +36,6 @@ def broadcasting_filter(self, node: Optional[Node]) -> bool:
         """
         if node is None:
             return False
-        if node.target not in [torch.ops.aten.add.Tensor]:
+        if [node.target] not in BINARY_OP_PATTERNS:
             return False

         if len(node.all_input_nodes) == 2:

backends/cortex_m/test/ops/test_mul.py

Lines changed: 27 additions & 12 deletions

@@ -4,7 +4,6 @@
 # LICENSE file in the root directory of this source tree.


-import pytest
 import torch
 from executorch.backends.arm.test.common import parametrize
 from executorch.backends.cortex_m.test.tester import (
@@ -60,6 +59,16 @@ class CortexMTensorMul(Model):
     }


+class CortexMTensorMulBroadCast(Model):
+    ops_before_transforms = {
+        "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1,
+    }
+
+    ops_after_transforms = {
+        "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1,
+    }
+
+
 test_cases = {
     "self_scalar": McuTestCase(
         CortexMSelfMul(),
@@ -91,22 +100,22 @@ class CortexMTensorMul(Model):
     ),
     "tensor_scalar": McuTestCase(
         CortexMScalarMul(),
-        (torch.ones(2, 2), 1.0),
+        (torch.ones(1), 1.0),
     ),
     "scalar_tensor": McuTestCase(
         CortexMScalarMul(),
-        (1000.0, torch.ones(2, 2)),
+        (1000.0, torch.ones(1)),
     ),
     "broadcast_1": McuTestCase(
-        CortexMTensorMul(),
+        CortexMTensorMulBroadCast(),
         (torch.ones(1), torch.ones(2, 2, 2, 2)),
     ),
     "broadcast_2": McuTestCase(
-        CortexMTensorMul(),
+        CortexMTensorMulBroadCast(),
         (torch.ones((2, 1, 1, 1)), torch.ones(1)),
     ),
     "broadcast_3": McuTestCase(
-        CortexMTensorMul(),
+        CortexMTensorMulBroadCast(),
         (
             ramp_tensor(-2, 2, (2, 1, 2, 1)),
             ramp_tensor(-5, 5, (1, 2, 1, 2)),
@@ -115,17 +124,23 @@ class CortexMTensorMul(Model):
 }


-@pytest.mark.skip(reason="Not implemented yet")
-@parametrize("test_case", test_cases)
+xfail_cases = {
+    "self_scalar": "lift_constant_tensor_pass assumes fake tensors for scalars",
+    "scalar_scalar": "lift_constant_tensor_pass assumes fake tensors for scalars",
+}
+
+
+@parametrize("test_case", test_cases, xfails=xfail_cases)
 def test_dialect_mul(test_case):
     tester = CortexMTester(test_case.model, test_case.example_inputs)
     tester.test_dialect(
-        test_case.model.ops_before_transforms, test_case.model.ops_after_transforms
+        test_case.model.ops_before_transforms,
+        test_case.model.ops_after_transforms,
+        qtol=1,
     )


-@pytest.mark.skip(reason="Not implemented yet")
-@parametrize("test_case", test_cases)
+@parametrize("test_case", test_cases, xfails=xfail_cases)
 def test_implementation_mul(test_case):
     tester = CortexMTester(test_case.model, test_case.example_inputs)
-    tester.test_implementation()
+    tester.test_implementation(qtol=1)
