
Commit c73334d

sidart authored and Github Executorch committed
Summary: Initial CMSIS-NN integration for the quantized add op.

Test Plan:
a) Set up the Arm FVP and run 'examples/arm/run.sh' (check that there are no regressions in the e2e test scenarios).
b) Add another run.sh iteration with qadd and only the --quantize flag, and verify that the quantized add op is called.
c) cd backends/cortex_m/test/; python test_quantize_add_fusion_pass.py
----------------------------------------------------------------------
Ran 8 tests in 11.128s

OK

Reviewers:
Subscribers:
Tasks:
Tags:
1 parent 310a05d commit c73334d

File tree: 12 files changed, +1170 −107 lines


backends/cortex_m/CMakeLists.txt
Lines changed: 45 additions & 1 deletion

@@ -24,11 +24,41 @@ endif()
 
 include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
 include(${EXECUTORCH_ROOT}/tools/cmake/Codegen.cmake)
+include(ExternalProject)
+
+# Download and build CMSIS-NN from GitHub
+set(CMSIS_NN_VERSION
+    "v4.1.0"
+    CACHE STRING "CMSIS-NN version to download"
+)
+set(CMSIS_NN_ROOT ${CMAKE_CURRENT_BINARY_DIR}/cmsis-nn)
+set(CMSIS_NN_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/cmsis-nn-build)
+set(CMSIS_NN_LIB_PATH ${CMSIS_NN_BINARY_DIR}/libcmsis-nn.a)
+
+set(TARGET_CPU
+    "cortex-m55"
+    CACHE STRING "Target CPU for CMSIS-NN build"
+)
+ExternalProject_Add(
+  cmsis_nn_external
+  GIT_REPOSITORY https://github.com/ARM-software/CMSIS-NN.git
+  GIT_TAG ${CMSIS_NN_VERSION}
+  SOURCE_DIR ${CMSIS_NN_ROOT}
+  BINARY_DIR ${CMSIS_NN_BINARY_DIR}
+  CMAKE_ARGS
+    -DCMAKE_TOOLCHAIN_FILE=${EXECUTORCH_ROOT}/examples/arm/ethos-u-setup/arm-none-eabi-gcc.cmake
+    -DTARGET_CPU=${TARGET_CPU}
+    -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
+  BUILD_COMMAND ${CMAKE_COMMAND} --build <BINARY_DIR> --parallel
+  INSTALL_COMMAND ""
+  BUILD_BYPRODUCTS ${CMSIS_NN_LIB_PATH}
+)
 
 # Cortex-M ops kernel sources
 set(_cortex_m_kernels__srcs
     ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantize_per_tensor.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_dequantize_per_tensor.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_add.cpp
 )
 
 # Generate C++ bindings to register kernels into Executorch (for runtime). Here

@@ -44,9 +74,23 @@ message("Generated files ${gen_command_sources}")
 
 # Build a library for _cortex_m_kernels_srcs
 add_library(cortex_m_kernels ${_cortex_m_kernels__srcs})
-target_link_libraries(cortex_m_kernels PRIVATE executorch)
 target_compile_options(cortex_m_kernels PUBLIC ${_common_compile_options})
 
+# Add dependency on CMSIS-NN external project
+add_dependencies(cortex_m_kernels cmsis_nn_external)
+
+# Set include directories - Include is directly in CMSIS-NN root
+target_include_directories(
+  cortex_m_kernels
+  PRIVATE ${EXECUTORCH_ROOT}/..
+          ${EXECUTORCH_ROOT}/runtime/core/portable_type/c10
+          $<BUILD_INTERFACE:${CMSIS_NN_ROOT}/Include>
+          $<BUILD_INTERFACE:${CMSIS_NN_ROOT}>
+)
+
+# Link against the CMSIS-NN static library directly
+target_link_libraries(cortex_m_kernels PUBLIC ${CMSIS_NN_LIB_PATH} executorch)
+
 # cortex_m_ops_lib: Register Cortex-M ops kernels into Executorch runtime
 gen_operators_lib(
   LIB_NAME "cortex_m_ops_lib" KERNEL_LIBS cortex_m_kernels DEPS executorch
backends/cortex_m/ops/cortex_m_ops_common.h
Lines changed: 39 additions & 0 deletions

@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+// Include CMSIS-NN headers with C linkage
+extern "C" {
+#include "arm_nnfunctions.h"
+}
+
+#define VALIDATE_QUANTIZED_INPUTS(ctx, input1, input2, output)              \
+  do {                                                                      \
+    ET_CHECK_MSG(                                                           \
+        (input1).scalar_type() == ScalarType::Char, "Input1 must be int8"); \
+    ET_CHECK_MSG(                                                           \
+        (input2).scalar_type() == ScalarType::Char, "Input2 must be int8"); \
+    ET_CHECK_MSG(                                                           \
+        (output).scalar_type() == ScalarType::Char, "Output must be int8"); \
+    ET_KERNEL_CHECK(                                                        \
+        ctx,                                                                \
+        torch::executor::resize_to_broadcast_target_size(                   \
+            input1, input2, output) == ::executorch::runtime::Error::Ok,    \
+        InvalidArgument,                                                    \
+        output);                                                            \
+    ET_CHECK_MSG(                                                           \
+        (input1).sizes() == (input2).sizes(),                               \
+        "Input tensors must be the same shape");                            \
+    ET_CHECK_MSG(                                                           \
+        (input1).scalar_type() == (input2).scalar_type(),                   \
+        "Input tensors must be the same dtype");                            \
+  } while (0)
backends/cortex_m/ops/op_quantized_add.cpp
Lines changed: 138 additions & 0 deletions

@@ -0,0 +1,138 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "cortex_m_ops_common.h"
+
+namespace cortex_m {
+namespace native {
+
+using Tensor = torch::executor::Tensor;
+using ScalarType = executorch::aten::ScalarType;
+using Scalar = torch::executor::Scalar;
+using KernelRuntimeContext = torch::executor::KernelRuntimeContext;
+
+Tensor& quantized_add_out(
+    KernelRuntimeContext& context,
+    const Tensor& input1_int8,
+    const Scalar& input1_zero_point,
+    const Scalar& input1_multiplier,
+    const Scalar& input1_shift,
+    const Tensor& input2_int8,
+    const Scalar& input2_zero_point,
+    const Scalar& input2_multiplier,
+    const Scalar& input2_shift,
+    const Scalar& output_zero_point,
+    const Scalar& output_multiplier,
+    const Scalar& output_shift,
+    Tensor& out) {
+  VALIDATE_QUANTIZED_INPUTS(context, input1_int8, input2_int8, out);
+
+  ET_LOG(
+      Info,
+      "quantized_add_out: input1_int8.sizes() = %zu",
+      input1_int8.sizes().size());
+
+  // FIX: Use template types that ExecutorTorch definitely provides
+  // Use to<int64_t>() and to<double>() which are commonly instantiated
+  int32_t zp1 = static_cast<int32_t>(input1_zero_point.to<int64_t>());
+  int32_t input1_mult = static_cast<int32_t>(input1_multiplier.to<int64_t>());
+  int input1_shift_val = static_cast<int>(input1_shift.to<int64_t>());
+
+  int32_t zp2 = static_cast<int32_t>(input2_zero_point.to<int64_t>());
+  int32_t input2_mult = static_cast<int32_t>(input2_multiplier.to<int64_t>());
+  int input2_shift_val = static_cast<int>(input2_shift.to<int64_t>());
+
+  int32_t out_zp = static_cast<int32_t>(output_zero_point.to<int64_t>());
+  int32_t output_mult = static_cast<int32_t>(output_multiplier.to<int64_t>());
+  int output_shift_val = static_cast<int>(output_shift.to<int64_t>());
+
+  // Left shift to maximize precision (tune as needed)
+  const int32_t left_shift = 20;
+  const int32_t activation_min = std::numeric_limits<int8_t>::min();
+  const int32_t activation_max = std::numeric_limits<int8_t>::max();
+
+  // Resize output tensor to match input shape
+  auto err = torch::executor::resize_tensor(out, input1_int8.sizes());
+  if (err != executorch::runtime::Error::Ok) {
+    ET_LOG(
+        Error,
+        "quantized_add_out: resize_tensor failed with error code [%d]",
+        static_cast<int>(err));
+    std::memset(out.mutable_data_ptr<int8_t>(), 0, out.nbytes());
+    return out;
+  }
+
+  ET_LOG(
+      Info,
+      "Using AoT-computed parameters: input1[mult=%d, shift=%d], input2[mult=%d, shift=%d], output[mult=%d, shift=%d]",
+      input1_mult,
+      input1_shift_val,
+      input2_mult,
+      input2_shift_val,
+      output_mult,
+      output_shift_val);
+
+  // Call CMSIS-NN kernel with precomputed parameters
+  arm_cmsis_nn_status status = arm_elementwise_add_s8(
+      input1_int8.const_data_ptr<int8_t>(),
+      input2_int8.const_data_ptr<int8_t>(),
+      static_cast<int32_t>(zp1),
+      input1_mult,
+      input1_shift_val,
+      static_cast<int32_t>(zp2),
+      input2_mult,
+      input2_shift_val,
+      left_shift,
+      out.mutable_data_ptr<int8_t>(),
+      static_cast<int32_t>(out_zp),
+      output_mult,
+      output_shift_val,
+      static_cast<int32_t>(out.numel()),
+      activation_min,
+      activation_max);
+
+  if (status != ARM_CMSIS_NN_SUCCESS) {
+    ET_LOG(
+        Error,
+        "quantized_add_out: arm_elementwise_add_s8 failed with status [%d]",
+        status);
+    std::memset(out.mutable_data_ptr<int8_t>(), 0, out.nbytes());
+  } else {
+    ET_LOG(
+        Info,
+        "quantized_add_out: Successfully completed with AoT-computed parameters!");
+  }
+
+  return out;
+}
+
+// Stub Implementation: Non-out variant for compatibility (functional variant)
+// EXIR/ExecuTorch runs an out-variant pass that converts
+// .default operations to .out variants before memory planning.
+// In the pass we are calling quantized_add's default variant,
+// but ExecuTorch's kernel dispatch mechanism will end up calling the out
+// variant. This stub is to make sure that the compiler doesn't complain.
+Tensor quantized_add(
+    KernelRuntimeContext& context,
+    const Tensor& input1_int8,
+    const Scalar& input1_zero_point,
+    const Scalar& input1_multiplier,
+    const Scalar& input1_shift,
+    const Tensor& input2_int8,
+    const Scalar& input2_zero_point,
+    const Scalar& input2_multiplier,
+    const Scalar& input2_shift,
+    const Scalar& output_zero_point,
+    const Scalar& output_multiplier,
+    const Scalar& output_shift) {
+  ET_LOG(
+      Info,
+      "quantized_add: input1_int8.sizes() = %zu",
+      input1_int8.sizes().size());
+  return const_cast<Tensor&>(input1_int8); // to make the compiler happy
+}
+
+} // namespace native
+} // namespace cortex_m
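The kernel above consumes per-tensor (multiplier, shift) pairs that are meant to be computed ahead of time from the float quantization scales, following the Q31 fixed-point convention used by CMSIS-NN elementwise kernels (scale ≈ multiplier × 2^shift / 2^31, which is also the form the Python reference implementation in operators.py undoes). A minimal Python sketch of that derivation is shown below; the `quantize_multiplier` helper and the scale values are illustrative assumptions, not part of this commit.

import math

def quantize_multiplier(scale: float) -> tuple[int, int]:
    # Decompose scale so that scale ~= multiplier * 2**shift / 2**31,
    # with the Q31 multiplier normalized into [2**30, 2**31).
    if scale == 0.0:
        return 0, 0
    mantissa, shift = math.frexp(scale)       # scale = mantissa * 2**shift, mantissa in [0.5, 1)
    multiplier = round(mantissa * (1 << 31))  # Q31 fixed-point mantissa
    if multiplier == (1 << 31):               # rounding overflow: renormalize
        multiplier //= 2
        shift += 1
    return multiplier, shift

# Illustrative per-tensor scales; the left shift of 20 matches op_quantized_add.cpp.
input1_scale, input2_scale, output_scale = 0.02, 0.03, 0.05
left_shift = 20
twice_max_input_scale = 2.0 * max(input1_scale, input2_scale)
in1_mult, in1_shift = quantize_multiplier(input1_scale / twice_max_input_scale)
in2_mult, in2_shift = quantize_multiplier(input2_scale / twice_max_input_scale)
out_mult, out_shift = quantize_multiplier(
    twice_max_input_scale / ((1 << left_shift) * output_scale)
)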

backends/cortex_m/ops/operators.py
Lines changed: 135 additions & 0 deletions

@@ -96,3 +96,138 @@ def dequantize_per_tensor_impl(
     return exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default(
         input, scale, zero_point, quant_min, quant_max, dtype
     )
+
+
+# Define the operator schema with multipliers and shifts (11 args)
+lib.define(
+    "quantized_add("
+    "Tensor self, Scalar self_zero_point, Scalar self_multiplier, Scalar self_shift, "
+    "Tensor other, Scalar other_zero_point, Scalar other_multiplier, Scalar other_shift, "
+    "Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift) -> Tensor"
+)
+
+
+@register_fake("cortex_m::quantized_add")
+def quantized_add_meta(
+    self: torch.Tensor,
+    self_zero_point: int,
+    self_multiplier: int,
+    self_shift: int,
+    other: torch.Tensor,
+    other_zero_point: int,
+    other_multiplier: int,
+    other_shift: int,
+    output_zero_point: int,
+    output_multiplier: int,
+    output_shift: int,
+) -> torch.Tensor:
+    return torch.empty_like(self, dtype=torch.int8)
+
+
+@impl(lib, "quantized_add", "CompositeExplicitAutograd")
+def quantized_add_impl(
+    self: torch.Tensor,
+    self_zero_point: int,
+    self_multiplier: int,
+    self_shift: int,
+    other: torch.Tensor,
+    other_zero_point: int,
+    other_multiplier: int,
+    other_shift: int,
+    output_zero_point: int,
+    output_multiplier: int,
+    output_shift: int,
+) -> torch.Tensor:
+    # For now, convert back to float, add, and quantize (as placeholder)
+    # Dequantize inputs using multiplier/shift
+    self_fp = (self.float() - self_zero_point) * (
+        self_multiplier / (1 << (31 - self_shift))
+    )
+    other_fp = (other.float() - other_zero_point) * (
+        other_multiplier / (1 << (31 - other_shift))
+    )
+
+    # Add
+    result_fp = self_fp + other_fp
+
+    # Quantize output
+    result_quantized = (
+        result_fp / (output_multiplier / (1 << (31 - output_shift)))
+    ) + output_zero_point
+
+    return result_quantized.clamp(-128, 127).to(torch.int8)
+
+
+# Define the operator schema with multipliers and shifts (11 args + out tensor)
+lib.define(
+    "quantized_add.out("
+    "Tensor self, Scalar self_zero_point, Scalar self_multiplier, Scalar self_shift, "
+    "Tensor other, Scalar other_zero_point, Scalar other_multiplier, Scalar other_shift, "
+    "Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift, "
+    "*, Tensor(a!) out) -> Tensor(a!)"
+)
+
+
+# Fake meta function for shape and dtype inference during compilation
+@register_fake("cortex_m::quantized_add.out")
+def quantized_add_out_meta(
+    self: torch.Tensor,
+    self_zero_point: int,
+    self_multiplier: int,
+    self_shift: int,
+    other: torch.Tensor,
+    other_zero_point: int,
+    other_multiplier: int,
+    other_shift: int,
+    output_zero_point: int,
+    output_multiplier: int,
+    output_shift: int,
+    out: torch.Tensor,
+) -> torch.Tensor:
+    # Validate shape compatibility if needed
+    assert out.shape == self.shape, "Output shape must match input shape"
+    # Output dtype is int8
+    return out
+
+
+# Actual implementation delegating to backend or custom kernel
+@impl(lib, "quantized_add.out", "CompositeExplicitAutograd")
+def quantized_add_out_impl(
+    self: torch.Tensor,
+    self_zero_point: int,
+    self_multiplier: int,
+    self_shift: int,
+    other: torch.Tensor,
+    other_zero_point: int,
+    other_multiplier: int,
+    other_shift: int,
+    output_zero_point: int,
+    output_multiplier: int,
+    output_shift: int,
+    *,
+    out: torch.Tensor,
+) -> torch.Tensor:
+    # Example placeholder implementation:
+    # Dequantize inputs using multiplier and shift
+    self_fp = (self.float() - self_zero_point) * (
+        self_multiplier / (1 << (31 - self_shift))
+    )
+    other_fp = (other.float() - other_zero_point) * (
+        other_multiplier / (1 << (31 - other_shift))
+    )
+
+    # Add in floating point
+    result_fp = self_fp + other_fp
+
+    # Quantize output using multiplier and shift
+    result_quantized = (
+        result_fp / (output_multiplier / (1 << (31 - output_shift)))
+    ) + output_zero_point
+
+    # Clamp and convert to int8
+    result_quantized = result_quantized.clamp(-128, 127).to(torch.int8)
+
+    # Write into the provided output tensor
+    out.copy_(result_quantized)
+
+    return out
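Because both overloads are registered with a CompositeExplicitAutograd reference implementation, the op can be exercised eagerly from Python once this module is imported. Below is a minimal sketch; the import path and the unity-scale parameters (multiplier 2^30 with shift 1, zero points of 0) are illustrative assumptions, under which the reference implementation reduces to a saturating int8 add.

import torch

# Assumed import path: importing operators.py registers the cortex_m ops.
from executorch.backends.cortex_m.ops import operators  # noqa: F401

x = torch.randint(-128, 128, (1, 8), dtype=torch.int8)
y = torch.randint(-128, 128, (1, 8), dtype=torch.int8)

# Unity scales: a multiplier of 2**30 with shift 1 encodes a scale of exactly 1.0.
out = torch.ops.cortex_m.quantized_add(
    x, 0, 1 << 30, 1,   # input1: zero_point, multiplier, shift
    y, 0, 1 << 30, 1,   # input2: zero_point, multiplier, shift
    0, 1 << 30, 1,      # output: zero_point, multiplier, shift
)
assert out.dtype == torch.int8 and out.shape == x.shape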
