Skip to content

Commit 660ec64

Browse files
author
Github Executorch
committed
Summary: Add stateful fully-connected (FC) Cortex-M linear ops
Integrate with CMSIS-NN, including per-channel quantization support. Test Plan: Run the end-to-end test on the FVP simulator: ./examples/arm/run_mcu_models_fvp.sh --target=cortex-m55 --models=qlinear Reviewers: Subscribers: Tasks: Tags:
1 parent 0e9d871 commit 660ec64

File tree

9 files changed

+1161
-38
lines changed

9 files changed

+1161
-38
lines changed

backends/cortex_m/CMakeLists.txt

Lines changed: 49 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,31 +21,59 @@ include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
2121
include(${EXECUTORCH_ROOT}/tools/cmake/Codegen.cmake)
2222
include(FetchContent)
2323

24-
# CMSIS-NN version to download
25-
set(CMSIS_NN_VERSION
26-
"v4.1.0"
27-
CACHE STRING "CMSIS-NN version to download"
28-
)
29-
30-
# Declare CMSIS-NN as a FetchContent project
31-
FetchContent_Declare(
32-
cmsis_nn
33-
GIT_REPOSITORY https://github.com/ARM-software/CMSIS-NN.git
34-
GIT_TAG ${CMSIS_NN_VERSION}
35-
)
36-
37-
# Download and make CMSIS-NN available
38-
FetchContent_MakeAvailable(cmsis_nn)
39-
40-
# Print paths for debugging
41-
message(STATUS "CMSIS-NN source dir: ${cmsis_nn_SOURCE_DIR}")
42-
message(STATUS "CMSIS-NN binary dir: ${cmsis_nn_BINARY_DIR}")
24+
# CMSIS-NN provider selection.
#
# USE_LOCAL_CMSIS_NN=ON  : add the checkout at CMSIS_NN_LOCAL_PATH as a
#                          subdirectory and link the archive expected at
#                          ${CMSIS_NN_LOCAL_LIB}/libcmsis-nn.a.
# USE_LOCAL_CMSIS_NN=OFF : (default) download CMSIS_NN_VERSION from GitHub with
#                          FetchContent.
#
# Either way this block defines CMSIS_NN_INCLUDE_DIR and CMSIS_NN_LIB, and adds
# ARM_MATH_MVEI=1 to the `cmsis-nn` target.
#
# Declared with option() so it can be switched from the command line
# (-DUSE_LOCAL_CMSIS_NN=ON); a plain set() would shadow any user-provided value.
option(USE_LOCAL_CMSIS_NN
       "Build against a local CMSIS-NN checkout instead of fetching it" OFF
)

if(USE_LOCAL_CMSIS_NN)
  # Quote the expansions: an empty/undefined variable would otherwise expand to
  # `if(NOT EXISTS)` and abort with a confusing syntax error instead of the
  # intended diagnostic.
  if(NOT EXISTS "${CMSIS_NN_LOCAL_PATH}")
    message(
      FATAL_ERROR "CMSIS-NN local path does not exist: ${CMSIS_NN_LOCAL_PATH}"
    )
  endif()
  if(NOT EXISTS "${CMSIS_NN_LOCAL_LIB}")
    message(
      FATAL_ERROR "CMSIS-NN local lib does not exist: ${CMSIS_NN_LOCAL_LIB}"
    )
  endif()
  message(STATUS "Using CMSIS-NN from: ${CMSIS_NN_LOCAL_PATH}")
  # An explicit binary dir is given for the CMSIS-NN subdirectory build.
  add_subdirectory("${CMSIS_NN_LOCAL_PATH}" cmsis_nn_build)

  set(CMSIS_NN_INCLUDE_DIR "${CMSIS_NN_LOCAL_PATH}/Include")
  set(CMSIS_NN_LIB "${CMSIS_NN_LOCAL_LIB}/libcmsis-nn.a")
else()
  message(STATUS "Using CMSIS-NN via FetchContent")
  set(CMSIS_NN_VERSION
      "v7.0.0"
      CACHE STRING "CMSIS-NN version to download"
  )
  FetchContent_Declare(
    cmsis_nn
    GIT_REPOSITORY https://github.com/ARM-software/CMSIS-NN.git
    GIT_TAG ${CMSIS_NN_VERSION}
  )
  FetchContent_MakeAvailable(cmsis_nn)

  # Query the populated source/binary directories from FetchContent.
  FetchContent_GetProperties(
    cmsis_nn
    SOURCE_DIR CMSIS_NN_SOURCE_DIR
    BINARY_DIR CMSIS_NN_BINARY_DIR
  )
  set(CMSIS_NN_INCLUDE_DIR "${CMSIS_NN_SOURCE_DIR}/Include")
  set(CMSIS_NN_LIB "${CMSIS_NN_BINARY_DIR}/libcmsis-nn.a")
  message(STATUS "CMSIS-NN source dir: ${CMSIS_NN_SOURCE_DIR}")
  message(STATUS "CMSIS-NN binary dir: ${CMSIS_NN_BINARY_DIR}")
endif()

# Add MVEI define to cmsis-nn target (both branches create that target).
target_compile_definitions(cmsis-nn PUBLIC ARM_MATH_MVEI=1)
4370

4471
# Source files that implement the Cortex-M operator kernels. Listed explicitly
# (one per operator) so additions show up in review.
set(_cortex_m_kernels__srcs
    "${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantize_per_tensor.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/ops/op_dequantize_per_tensor.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_add.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_linear.cpp"
)
5078

5179
# Generate C++ bindings to register kernels into Executorch (for runtime)
@@ -66,13 +94,11 @@ target_include_directories(
6694
cortex_m_kernels
6795
PRIVATE ${EXECUTORCH_ROOT}/..
6896
${EXECUTORCH_ROOT}/runtime/core/portable_type/c10
69-
${cmsis_nn_SOURCE_DIR}/Include
97+
${CMSIS_NN_INCLUDE_DIR}
7098
)
7199

72100
# Link directly to the CMSIS-NN static library file
73-
target_link_libraries(
74-
cortex_m_kernels PUBLIC ${cmsis_nn_BINARY_DIR}/libcmsis-nn.a executorch
75-
)
101+
# CMSIS_NN_LIB is the path to the CMSIS-NN static archive chosen above;
# `executorch` is linked alongside it.
target_link_libraries(
  cortex_m_kernels
  PUBLIC ${CMSIS_NN_LIB}
         executorch
)
76102

77103
# Add dependency to ensure CMSIS-NN builds before we try to link. Use the actual
78104
# CMSIS-NN target name (usually 'cmsis-nn')

backends/cortex_m/ops/cortex_m_ops_common.h

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,19 @@ inline void validate_cmsis_nn_tensor_requirements(
3232
// Basic dtype validation
3333
ET_CHECK_MSG(
3434
input1.scalar_type() == expected_dtype,
35-
"Input1 dtype must be %hhd",
36-
expected_dtype);
35+
"Input1 dtype must be %hhd, got %hhd",
36+
expected_dtype,
37+
input1.scalar_type());
3738
ET_CHECK_MSG(
3839
input2.scalar_type() == expected_dtype,
39-
"Input2 dtype must be %hhd",
40-
expected_dtype);
40+
"Input2 dtype must be %hhd, got %hhd",
41+
expected_dtype,
42+
input2.scalar_type());
4143
ET_CHECK_MSG(
4244
output.scalar_type() == expected_dtype,
43-
"Output dtype must be %hhd",
44-
expected_dtype);
45+
"Output dtype must be %hhd, got %hhd",
46+
expected_dtype,
47+
output.scalar_type());
4548

4649
// Dim order consistency
4750
ET_CHECK_MSG(
@@ -114,6 +117,26 @@ inline void validate_quantization_params(
114117
"Single quant Output");
115118
}
116119

120+
inline void validate_per_channel_quant_params(
121+
const int32_t* multipliers,
122+
const int32_t* shifts,
123+
int num_channels) {
124+
for (int i = 0; i < num_channels; ++i) {
125+
if (multipliers[i] < (1LL << 30) || multipliers[i] > ((1LL << 31) - 1)) {
126+
ET_LOG(
127+
Error,
128+
"weight_multiplier[%d] out of CMSIS-NN range: %d",
129+
i,
130+
multipliers[i]);
131+
return;
132+
}
133+
if (shifts[i] < 0 || shifts[i] > 31) {
134+
ET_LOG(Error, "weight_shift[%d] out of range: %d", i, shifts[i]);
135+
return;
136+
}
137+
}
138+
}
139+
117140
inline Error resize_to_broadcast_target_size(
118141
const Tensor& input1,
119142
const Tensor& input2,
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include "cortex_m_ops_common.h"
10+
11+
// Include CMSIS-NN headers with C linkage
12+
extern "C" {
13+
#include "arm_nnfunctions.h"
14+
}
15+
16+
namespace cortex_m {
17+
namespace native {
18+
using KernelRuntimeContext = torch::executor::KernelRuntimeContext;
19+
20+
// Bookkeeping record that quantized_linear_out overlays on the last
// sizeof(kernel_sum_state) bytes of its scratch buffer. The size of this
// struct therefore determines how much of the scratch buffer is left for the
// CMSIS-NN workspace.
struct kernel_sum_state {
  // True once the kernel sums have been computed for this scratch buffer.
  bool updated{false};
  // Zero-initialised storage reserved inside the state record.
  char buffer[2048]{};
};
24+
25+
/**
 * int8 fully-connected (linear) kernel backed by CMSIS-NN.
 *
 * Runs one CMSIS-NN fully-connected call per row of `input` (row count is
 * input.size(0)). The per-channel kernel is used when `weight_zero_point` has
 * more than one element, the per-tensor kernel otherwise.
 *
 * Scratch buffer layout (flat int8 tensor, size scratch_buffer.size(0)):
 *
 *   [------------------- scratch_buffer -----------------------]
 *   |<- CMSIS-NN workspace ->|<--- kernel_sum_state struct --->|
 *   ^                        ^                                 ^
 *   scratch_ptr              scratch_ptr + workspace_bytes     end
 *
 *  - The workspace starts at the beginning of the buffer; the kernel sums are
 *    written there and the same region is passed to CMSIS-NN as ctx.buf.
 *  - kernel_sum_state occupies the trailing sizeof(kernel_sum_state) bytes.
 *    Its `updated` flag gates the one-time kernel-sum computation.
 *    NOTE(review): this presumes the buffer is zero-initialised and preserved
 *    between invocations — confirm against the pass that allocates it.
 *
 * @param context            Kernel context; fail(Error::Internal) is called if
 *                           CMSIS-NN reports an error.
 * @param input              int8 activations, in_features values per row.
 * @param input_zero_point   Input zero point (negated into input_offset).
 * @param input_multiplier   Unused.
 * @param input_shift        Unused.
 * @param weights            int8 weights.
 * @param weight_zero_point  int32 weight zero point(s); numel > 1 selects the
 *                           per-channel path.
 * @param weight_multiplier  int32 requantization multiplier(s).
 * @param weight_shift       int32 requantization shift(s).
 * @param bias               Optional int32 bias.
 * @param bias_multiplier    Unused.
 * @param bias_shift         Unused.
 * @param scratch_buffer     int8 scratch memory, see layout above.
 * @param output_zero_point  Output zero point (used as output_offset).
 * @param in_features        Number of input features per row.
 * @param out_features       Number of output features per row.
 * @param out                int8 output tensor, out_features values per row.
 * @return `out`.
 */
Tensor& quantized_linear_out(
    KernelRuntimeContext& context,
    const Tensor& input,
    const Scalar& input_zero_point,
    const Scalar& input_multiplier,
    const Scalar& input_shift,
    const Tensor& weights,
    const Tensor& weight_zero_point,
    const Tensor& weight_multiplier,
    const Tensor& weight_shift,
    const torch::executor::optional<Tensor>& bias,
    const Tensor& bias_multiplier, // IGNORE - not used
    const Tensor& bias_shift, // IGNORE - not used
    const Tensor& scratch_buffer,
    const Scalar& output_zero_point,
    const Scalar& in_features,
    const Scalar& out_features,
    Tensor& out) {
  // Parameters that are part of the operator schema but not consumed here.
  (void)input_multiplier;
  (void)input_shift;
  (void)bias_multiplier;
  (void)bias_shift;

  ET_LOG(Info, "quantized_linear_out: called");
  validate_cmsis_nn_tensor_requirements(input, weights, out);

  ET_CHECK_MSG(
      scratch_buffer.scalar_type() == ScalarType::Char,
      "Scratch buffer must be int8");

  // --- Parameter Extraction and Validation ---
  const int32_t batch_size = input.size(0);
  const int32_t in_feat = static_cast<int32_t>(in_features.to<int64_t>());
  const int32_t out_feat = static_cast<int32_t>(out_features.to<int64_t>());
  const int32_t input_zp =
      static_cast<int32_t>(input_zero_point.to<int64_t>());
  const int32_t output_zp =
      static_cast<int32_t>(output_zero_point.to<int64_t>());
  const bool is_per_channel = (weight_zero_point.numel() > 1);

  const int8_t* input_data = input.const_data_ptr<int8_t>();
  const int8_t* weight_data = weights.const_data_ptr<int8_t>();
  const int32_t* bias_data =
      bias.has_value() ? bias.value().const_data_ptr<int32_t>() : nullptr;
  int8_t* output_data = out.mutable_data_ptr<int8_t>();
  int8_t* scratch_ptr = scratch_buffer.mutable_data_ptr<int8_t>();
  const int32_t* weight_zp_data = weight_zero_point.const_data_ptr<int32_t>();
  const int32_t* weight_mult_data = weight_multiplier.const_data_ptr<int32_t>();
  const int32_t* weight_shift_data = weight_shift.const_data_ptr<int32_t>();

  if (!bias.has_value()) {
    ET_LOG(Info, "No bias tensor provided (bias_data is nullptr)");
  }

  // Per-tensor mode only ever reads entry 0 of the multiplier/shift arrays, so
  // validating `out_feat` entries there would read past the end of them.
  const int num_quant_channels = is_per_channel ? static_cast<int>(out_feat) : 1;
  validate_per_channel_quant_params(
      weight_mult_data, weight_shift_data, num_quant_channels);

  cmsis_nn_fc_params fc_params;
  fc_params.input_offset = -input_zp;
  // Always give filter_offset a defined value; the per-channel path does not
  // set it otherwise.
  fc_params.filter_offset = is_per_channel ? 0 : -weight_zp_data[0];
  fc_params.output_offset = output_zp;
  fc_params.activation.min = std::numeric_limits<int8_t>::min();
  fc_params.activation.max = std::numeric_limits<int8_t>::max();

  cmsis_nn_dims input_dims = {1, 1, 1, in_feat};
  cmsis_nn_dims filter_dims = {out_feat, 1, 1, in_feat};
  cmsis_nn_dims bias_dims = {1, 1, 1, out_feat};
  cmsis_nn_dims output_dims = {1, 1, 1, out_feat};

  // --- Scratch buffer partitioning ---
  // Check the buffer can hold the trailing state record *before* deriving a
  // pointer from it; otherwise the subtraction below would wrap around and the
  // state pointer would land in front of the buffer.
  const int64_t scratch_bytes = static_cast<int64_t>(scratch_buffer.size(0));
  const int64_t state_bytes = static_cast<int64_t>(sizeof(kernel_sum_state));
  ET_CHECK_MSG(
      scratch_bytes >= state_bytes,
      "Scratch buffer size %d is too small for kernel_sum_state (%d bytes)",
      static_cast<int>(scratch_bytes),
      static_cast<int>(state_bytes));
  const int64_t workspace_bytes = scratch_bytes - state_bytes;

  kernel_sum_state* state =
      reinterpret_cast<kernel_sum_state*>(scratch_ptr + workspace_bytes);

  if (!state->updated) {
    const int required_bytes =
        arm_fully_connected_s8_get_buffer_size(&filter_dims);
    ET_CHECK_MSG(
        workspace_bytes >= static_cast<int64_t>(required_bytes),
        "Scratch buffer size %zu is not enough for kernel sum buffer size %d",
        static_cast<size_t>(workspace_bytes),
        required_bytes);

    // Compute kernel sums once; they are written to the start of the scratch
    // buffer, i.e. the region later handed to CMSIS-NN as ctx.buf.
    // NOTE(review): the offset passed here is the weight zero point and no
    // bias is folded in — confirm this matches what the fully-connected
    // kernels of the pinned CMSIS-NN version expect from precomputed sums.
    arm_vector_sum_s8(
        reinterpret_cast<int32_t*>(scratch_ptr),
        in_feat,
        out_feat,
        weight_data,
        weight_zp_data[0],
        0, // rhs_offset (int32_t)
        nullptr // bias (const int32_t*)
    );
    state->updated = true;
    ET_LOG(
        Info,
        "Computed kernel sums, stored in state->buffer [required_bytes : %d]",
        required_bytes);
  }

  // CMSIS-NN gets the start of the buffer; the state record is excluded.
  cmsis_nn_context ctx;
  ctx.buf = scratch_ptr;
  ctx.size = static_cast<int32_t>(workspace_bytes);

  // Quantization parameters do not change between rows; build them once.
  cmsis_nn_per_channel_quant_params per_channel_quant_params;
  per_channel_quant_params.multiplier = const_cast<int32_t*>(weight_mult_data);
  per_channel_quant_params.shift = const_cast<int32_t*>(weight_shift_data);

  cmsis_nn_per_tensor_quant_params per_tensor_quant_params;
  per_tensor_quant_params.multiplier = weight_mult_data[0];
  per_tensor_quant_params.shift = weight_shift_data[0];

  arm_cmsis_nn_status status = ARM_CMSIS_NN_SUCCESS;
  for (int32_t b = 0; b < batch_size; b++) {
    const int8_t* batch_input = input_data + b * in_feat;
    int8_t* batch_output = output_data + b * out_feat;

    if (is_per_channel) {
      // Per-channel quantization
      status = arm_fully_connected_per_channel_s8(
          &ctx,
          &fc_params,
          &per_channel_quant_params,
          &input_dims,
          batch_input,
          &filter_dims,
          weight_data,
          &bias_dims,
          bias_data,
          &output_dims,
          batch_output);
    } else {
      // Per-tensor quantization
      status = arm_fully_connected_s8(
          &ctx,
          &fc_params,
          &per_tensor_quant_params,
          &input_dims,
          batch_input,
          &filter_dims,
          weight_data,
          &bias_dims,
          bias_data,
          &output_dims,
          batch_output);
    }

    if (status != ARM_CMSIS_NN_SUCCESS) {
      ET_LOG(
          Error,
          "quantized_linear_out: CMSIS-NN failed with status [%d]",
          status);
      context.fail(Error::Internal);
      return out;
    }
  }
  return out;
}
181+
182+
// Functional variant (stub, not used at runtime)
183+
Tensor quantized_linear(
184+
KernelRuntimeContext& context,
185+
const Tensor& input,
186+
const Scalar& input_zero_point,
187+
const Scalar& input_multiplier,
188+
const Scalar& input_shift,
189+
const Tensor& weights,
190+
const Tensor& weight_zero_point,
191+
const Tensor& weight_multiplier,
192+
const Tensor& weight_shift,
193+
const torch::executor::optional<Tensor>& bias,
194+
const Tensor& bias_multiplier,
195+
const Tensor& bias_shift,
196+
const Tensor& scratch_buffer,
197+
const Scalar& output_zero_point,
198+
const Scalar& in_features,
199+
const Scalar& out_features) {
200+
ET_LOG(Info, "quantized_linear: called");
201+
assert(false);
202+
return const_cast<Tensor&>(input);
203+
}
204+
205+
} // namespace native
206+
} // namespace cortex_m

0 commit comments

Comments
 (0)