Skip to content

Commit f9d4254

Browse files
authored
Merge branch 'main' into export-D87519599
2 parents c406eab + d2c011e commit f9d4254

27 files changed

+1264
-169
lines changed

backends/cortex_m/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ set(_cortex_m_kernels__srcs
5656
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantize_per_tensor.cpp
5757
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_dequantize_per_tensor.cpp
5858
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_add.cpp
59+
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_conv2d.cpp
5960
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_linear.cpp
6061
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_mul.cpp
6162
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_minimum.cpp
Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
/*
2+
* Copyright 2025 Arm Limited and/or its affiliates.
3+
*
4+
* This source code is licensed under the BSD-style license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
#include "cortex_m_ops_common.h"
9+
10+
extern "C" {
11+
#include "arm_nnfunctions.h"
12+
}
13+
14+
namespace cortex_m {
15+
namespace native {
16+
17+
using KernelRuntimeContext = torch::executor::KernelRuntimeContext;
18+
19+
namespace {
20+
constexpr int64_t kConvDim = 4;
21+
22+
bool validate_conv2d_arguments(
23+
KernelRuntimeContext& context,
24+
const Tensor& input,
25+
const Tensor& weight,
26+
const torch::executor::optional<Tensor>& bias,
27+
const Tensor& output,
28+
const IntArrayRef& stride,
29+
const IntArrayRef& padding,
30+
const IntArrayRef& dilation,
31+
const Tensor& requantize_multipliers,
32+
const Tensor& requantize_shifts) {
33+
if (input.dim() != kConvDim || weight.dim() != kConvDim ||
34+
output.dim() != kConvDim) {
35+
ET_LOG(Error, "quantized_conv2d_out: tensors must be 4-D");
36+
context.fail(Error::InvalidArgument);
37+
return false;
38+
}
39+
40+
// Check for channels_last dim_order (NHWC: 0, 2, 3, 1)
41+
// Skip check if channels == 1, as dim_order is ambiguous in that case
42+
constexpr executorch::aten::DimOrderType kChannelsLastDimOrder[] = {
43+
0, 2, 3, 1};
44+
executorch::aten::ArrayRef<executorch::aten::DimOrderType>
45+
channels_last_order(kChannelsLastDimOrder, 4);
46+
47+
if (input.size(1) > 1 && input.dim_order() != channels_last_order) {
48+
ET_LOG(
49+
Error,
50+
"quantized_conv2d_out: input must have channels_last dim_order (NHWC)");
51+
context.fail(Error::InvalidArgument);
52+
return false;
53+
}
54+
55+
if (output.size(1) > 1 && output.dim_order() != channels_last_order) {
56+
ET_LOG(
57+
Error,
58+
"quantized_conv2d_out: output must have channels_last dim_order (NHWC)");
59+
context.fail(Error::InvalidArgument);
60+
return false;
61+
}
62+
63+
if (input.scalar_type() != ScalarType::Char ||
64+
output.scalar_type() != ScalarType::Char) {
65+
ET_LOG(Error, "quantized_conv2d_out: input and output must be int8");
66+
context.fail(Error::InvalidArgument);
67+
return false;
68+
}
69+
70+
if (weight.scalar_type() != ScalarType::Char) {
71+
ET_LOG(Error, "quantized_conv2d_out: weight must be int8");
72+
context.fail(Error::InvalidArgument);
73+
return false;
74+
}
75+
76+
if (bias.has_value() && bias.value().scalar_type() != ScalarType::Int) {
77+
ET_LOG(Error, "quantized_conv2d_out: bias must be int32 if provided");
78+
context.fail(Error::InvalidArgument);
79+
return false;
80+
}
81+
82+
if (stride.size() != 2 || padding.size() != 2 || dilation.size() != 2) {
83+
ET_LOG(
84+
Error,
85+
"quantized_conv2d_out: stride, padding, and dilation must have length 2");
86+
context.fail(Error::InvalidArgument);
87+
return false;
88+
}
89+
90+
const int64_t out_channels = output.size(1);
91+
if (requantize_multipliers.size(0) != out_channels ||
92+
requantize_shifts.size(0) != out_channels) {
93+
ET_LOG(
94+
Error,
95+
"quantized_conv2d_out: per-channel params must match output channels (%zd)",
96+
out_channels);
97+
context.fail(Error::InvalidArgument);
98+
return false;
99+
}
100+
101+
return true;
102+
}
103+
} // namespace
104+
105+
// Quantized 2-D convolution (int8, NHWC) backed by CMSIS-NN's
// arm_convolve_wrapper_s8, which selects the best conv kernel variant for
// the given parameters.
//
// Args:
//   context: kernel runtime context, used for temp allocation and failure
//       reporting.
//   input: int8 NHWC input tensor of shape (N, C_in, H, W) logical sizes
//       with channels_last dim_order.
//   weight: int8 filter tensor; size(0)=C_out, size(1)=kH, size(2)=kW,
//       size(3)=C_in (per the dims handed to CMSIS-NN below).
//   bias: optional int32 per-output-channel bias.
//   stride/padding/dilation: 2-element (H, W) arrays.
//   input_offset/output_offset: zero-point offsets for the s8 kernel.
//   requantize_multipliers/requantize_shifts: per-output-channel int32
//       requantization parameters.
//   activation_min/activation_max: output clamp range.
//   out: int8 NHWC output tensor, written in place and returned.
//
// On any validation or kernel failure the context is failed and `out` is
// returned unmodified (or partially written on kernel failure).
Tensor& quantized_conv2d_out(
    KernelRuntimeContext& context,
    const Tensor& input,
    const Tensor& weight,
    const torch::executor::optional<Tensor>& bias,
    const IntArrayRef stride,
    const IntArrayRef padding,
    const IntArrayRef dilation,
    const int64_t input_offset,
    const int64_t output_offset,
    const Tensor& requantize_multipliers,
    const Tensor& requantize_shifts,
    const int64_t activation_min,
    const int64_t activation_max,
    Tensor& out) {
  if (!validate_conv2d_arguments(
          context,
          input,
          weight,
          bias,
          out,
          stride,
          padding,
          dilation,
          requantize_multipliers,
          requantize_shifts)) {
    return out;
  }

  const int32_t batch = static_cast<int32_t>(input.size(0));
  const int32_t input_channels = static_cast<int32_t>(input.size(1));
  const int32_t input_height = static_cast<int32_t>(input.size(2));
  const int32_t input_width = static_cast<int32_t>(input.size(3));

  const int32_t kernel_output_channels = static_cast<int32_t>(weight.size(0));
  const int32_t kernel_height = static_cast<int32_t>(weight.size(1));
  const int32_t kernel_width = static_cast<int32_t>(weight.size(2));
  const int32_t kernel_input_channels = static_cast<int32_t>(weight.size(3));

  const int32_t output_channels = static_cast<int32_t>(out.size(1));
  const int32_t output_height = static_cast<int32_t>(out.size(2));
  const int32_t output_width = static_cast<int32_t>(out.size(3));

  // CMSIS-NN dims are ordered {n, h, w, c}.
  const cmsis_nn_dims input_dims{
      batch, input_height, input_width, input_channels};
  const cmsis_nn_dims filter_dims{
      kernel_output_channels,
      kernel_height,
      kernel_width,
      kernel_input_channels};
  const cmsis_nn_dims output_dims{
      batch, output_height, output_width, output_channels};
  const cmsis_nn_dims bias_dims{1, 1, 1, output_channels};

  cmsis_nn_conv_params conv_params;
  conv_params.input_offset = static_cast<int32_t>(input_offset);
  conv_params.output_offset = static_cast<int32_t>(output_offset);
  conv_params.stride.h = static_cast<int32_t>(stride[0]);
  conv_params.stride.w = static_cast<int32_t>(stride[1]);
  conv_params.padding.h = static_cast<int32_t>(padding[0]);
  conv_params.padding.w = static_cast<int32_t>(padding[1]);
  conv_params.dilation.h = static_cast<int32_t>(dilation[0]);
  conv_params.dilation.w = static_cast<int32_t>(dilation[1]);
  conv_params.activation.min = static_cast<int32_t>(activation_min);
  conv_params.activation.max = static_cast<int32_t>(activation_max);

  cmsis_nn_per_channel_quant_params quant_params;
  quant_params.multiplier = requantize_multipliers.data_ptr<int32_t>();
  quant_params.shift = requantize_shifts.data_ptr<int32_t>();

  const int8_t* input_data = input.const_data_ptr<int8_t>();
  const int8_t* weight_data = weight.const_data_ptr<int8_t>();
  int8_t* output_data = out.mutable_data_ptr<int8_t>();
  const int32_t* bias_data =
      bias.has_value() ? bias.value().const_data_ptr<int32_t>() : nullptr;

  cmsis_nn_context cmsis_context;
  cmsis_context.buf = nullptr;
  cmsis_context.size = 0;

  // Size the scratch buffer for the wrapper that is actually invoked below:
  // arm_convolve_wrapper_s8 may dispatch to a kernel variant (e.g. 1x1)
  // with different buffer requirements than plain arm_convolve_s8, so the
  // matching arm_convolve_wrapper_s8_get_buffer_size must be used.
  const size_t buffer_bytes =
      static_cast<size_t>(arm_convolve_wrapper_s8_get_buffer_size(
          &conv_params, &input_dims, &filter_dims, &output_dims));
  if (buffer_bytes > 0) {
    auto buffer_or_error = context.allocate_temp(buffer_bytes, alignof(int16_t));
    if (!buffer_or_error.ok()) {
      if (buffer_or_error.error() != Error::NotFound) {
        ET_LOG(
            Error,
            "quantized_conv2d_out: failed to allocate scratch buffer (%d)",
            static_cast<int>(buffer_or_error.error()));
        context.fail(buffer_or_error.error());
        return out;
      }
      // Error::NotFound: no temp allocator configured; fall through with a
      // null scratch buffer. NOTE(review): assumes the selected CMSIS-NN
      // kernel tolerates ctx.buf == nullptr on this target — confirm.
    } else {
      cmsis_context.buf = buffer_or_error.get();
      cmsis_context.size = buffer_bytes;
    }
  }

  const arm_cmsis_nn_status status = arm_convolve_wrapper_s8(
      &cmsis_context,
      &conv_params,
      &quant_params,
      &input_dims,
      input_data,
      &filter_dims,
      weight_data,
      &bias_dims,
      bias_data,
      &output_dims,
      output_data);

  if (status != ARM_CMSIS_NN_SUCCESS) {
    ET_LOG(
        Error,
        "quantized_conv2d_out: arm_convolve_wrapper_s8 failed with status %d",
        status);
    context.fail(Error::Internal);
  }

  return out;
}
234+
235+
} // namespace native
236+
} // namespace cortex_m

0 commit comments

Comments
 (0)