Commit 4daab85

Revert "Cortex_m backend: Simplify add + linear fusion passes (#15526)"

This reverts commit 9843222.

1 parent 3405317 · commit 4daab85

9 files changed, +1257 −368 lines

backends/cortex_m/ops/cortex_m_ops_common.h
Lines changed: 0 additions & 1 deletion

@@ -22,7 +22,6 @@ using Tensor = torch::executor::Tensor;
 using ScalarType = executorch::aten::ScalarType;
 using Scalar = torch::executor::Scalar;
 using Error = executorch::runtime::Error;
-using IntArrayRef = executorch::aten::ArrayRef<int64_t>;
 
 // From arm_nn_math_types.h
 #define ARM_NN_Q31_MAX ((int32_t)(0x7FFFFFFFL))
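The deleted IntArrayRef alias is consumed only by the operator signature removed in the next file, which received its requantize multipliers and shifts as array views; the restored signature passes them as Tensors instead. For illustration, a minimal, hypothetical stand-in for the view semantics (not the real executorch type):

#include <cstddef>
#include <cstdint>

// Hypothetical sketch of a non-owning (pointer, length) view, the role that
// executorch::aten::ArrayRef<int64_t> plays in the reverted signature below.
// The reverted kernel read per-tensor requantize params from it via .at(0).
struct Int64ArrayRefSketch {
  const int64_t* data = nullptr;
  size_t length = 0;
  int64_t at(size_t i) const {
    return data[i]; // no bounds checking in this sketch
  }
};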
Lines changed: 128 additions & 67 deletions

@@ -1,12 +1,12 @@
 /*
  * Copyright (c) Meta Platforms, Inc. and affiliates.
  * All rights reserved.
- * Copyright 2025 Arm Limited and/or its affiliates.
  *
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
  */
 
+#include "cmsis_scratch_buffer_context.h"
 #include "cortex_m_ops_common.h"
 
 extern "C" {
@@ -20,91 +20,152 @@ using KernelRuntimeContext = torch::executor::KernelRuntimeContext;
 Tensor& quantized_linear_out(
     KernelRuntimeContext& context,
     const Tensor& input,
+    const Scalar& input_zero_point,
+    const Scalar& input_multiplier,
+    const Scalar& input_shift,
     const Tensor& weights,
+    const Tensor& weight_zero_point,
+    const Tensor& weight_multiplier,
+    const Tensor& weight_shift,
     const torch::executor::optional<Tensor>& bias,
-    const torch::executor::optional<Tensor>& kernel_sum,
-    const Scalar& input_offset,
-    const Scalar& filter_offset,
-    const Scalar& output_offset,
-    const IntArrayRef requantize_multipliers,
-    const IntArrayRef requantize_shifts,
-    const Scalar& activation_max,
-    const Scalar& activation_min,
+    const Tensor& bias_multiplier,
+    const Tensor& bias_shift,
+    const Tensor& scratch_buffer,
+    const Scalar& output_zero_point,
+    const Scalar& in_features,
+    const Scalar& out_features,
     Tensor& out) {
   ET_LOG(Info, "quantized_linear_out: called");
+  validate_cmsis_nn_tensor_requirements(input, weights, out);
+
+  ET_CHECK_MSG(
+      scratch_buffer.scalar_type() == ScalarType::Char,
+      "Scratch buffer must be int8");
+
+  const int32_t batch_size = input.size(0);
+  const int32_t in_feat = static_cast<int32_t>(in_features.to<int64_t>());
+  const int32_t out_feat = static_cast<int32_t>(out_features.to<int64_t>());
+  const int32_t input_zp = static_cast<int32_t>(input_zero_point.to<int64_t>());
+  const int32_t output_zp =
+      static_cast<int32_t>(output_zero_point.to<int64_t>());
+  const bool is_per_channel = (weight_zero_point.numel() > 1);
 
   const int8_t* input_data = input.const_data_ptr<int8_t>();
   const int8_t* weight_data = weights.const_data_ptr<int8_t>();
   const int32_t* bias_data =
       bias.has_value() ? bias.value().const_data_ptr<int32_t>() : nullptr;
-  int32_t* kernel_sum_data =
-      kernel_sum.has_value() ? kernel_sum.value().data_ptr<int32_t>() : nullptr;
   int8_t* output_data = out.mutable_data_ptr<int8_t>();
+  const int32_t* weight_zp_data = weight_zero_point.const_data_ptr<int32_t>();
+  const int32_t* weight_mult_data = weight_multiplier.const_data_ptr<int32_t>();
+  const int32_t* weight_shift_data = weight_shift.const_data_ptr<int32_t>();
+
+  if (!validate_per_channel_quant_params(
+          weight_mult_data, weight_shift_data, out_feat)) {
+    context.fail(Error::InvalidArgument);
+    return out;
+  }
+
+  // Initialize scratch buffer context (validates early)
+  CMSISScratchBufferContext scratch_ctx(
+      const_cast<Tensor&>(scratch_buffer), weights, weight_zero_point, bias);
 
-  cmsis_nn_context ctx;
-  ctx.size = 0; // Not used in CMSIS-NN
-  ctx.buf = kernel_sum_data;
+  scratch_ctx.compute_kernel_sums_if_needed();
+  cmsis_nn_context ctx = scratch_ctx.get_cmsis_ctx();
 
   // Setup CMSIS-NN parameters
   cmsis_nn_fc_params fc_params;
-  fc_params.input_offset = static_cast<int32_t>(input_offset.to<int64_t>());
-  fc_params.filter_offset = static_cast<int32_t>(filter_offset.to<int64_t>());
-  fc_params.output_offset = static_cast<int32_t>(output_offset.to<int64_t>());
-  fc_params.activation.min = static_cast<int32_t>(activation_min.to<int64_t>());
-  fc_params.activation.max = static_cast<int32_t>(activation_max.to<int64_t>());
-
-  cmsis_nn_per_tensor_quant_params per_tensor_quant_params;
-  per_tensor_quant_params.multiplier =
-      static_cast<int32_t>(requantize_multipliers.at(0));
-  per_tensor_quant_params.shift = static_cast<int32_t>(requantize_shifts.at(0));
-
-  auto in_feat = input.size(input.dim() - 1);
-  auto out_feat = out.size(out.dim() - 1);
-  auto batches = 1;
-  for (size_t i = 0; i < input.dim() - 1; i++) {
-    batches *= input.size(i);
-  }
-  ET_LOG(
-      Info,
-      "in features: %d, out_features: %d, batches: %d, kernel_sum_size: %d",
-      in_feat,
-      out_feat,
-      batches,
-      kernel_sum.has_value() ? kernel_sum.value().numel() : 0);
-  ET_LOG(
-      Info,
-      "kernel_sum[0]: %d, kernel_sum[1]: %d",
-      kernel_sum_data != nullptr ? kernel_sum_data[0] : -1,
-      kernel_sum_data != nullptr ? kernel_sum_data[1] : -1);
-  cmsis_nn_dims input_dims = {batches, 1, 1, in_feat};
+  fc_params.input_offset = -input_zp;
+  fc_params.output_offset = output_zp;
+  fc_params.activation.min = std::numeric_limits<int8_t>::min();
+  fc_params.activation.max = std::numeric_limits<int8_t>::max();
+
+  cmsis_nn_dims input_dims = {1, 1, 1, in_feat};
   cmsis_nn_dims filter_dims = {in_feat, 1, 1, out_feat};
   cmsis_nn_dims bias_dims = {1, 1, 1, out_feat};
-  cmsis_nn_dims output_dims = {batches, 1, 1, out_feat};
-
-  arm_cmsis_nn_status status = arm_fully_connected_s8(
-      &ctx,
-      &fc_params,
-      &per_tensor_quant_params,
-      &input_dims,
-      input_data,
-      &filter_dims,
-      weight_data,
-      &bias_dims,
-      bias_data,
-      &output_dims,
-      output_data);
-
-  if (status != ARM_CMSIS_NN_SUCCESS) {
-    ET_LOG(
-        Error,
-        "quantized_linear_out: CMSIS-NN failed with status [%d]",
-        status);
-    context.fail(Error::Internal);
-    return out;
-  }
+  cmsis_nn_dims output_dims = {1, 1, 1, out_feat};
+
+  arm_cmsis_nn_status status;
+  for (int32_t b = 0; b < batch_size; b++) {
+    const int8_t* batch_input = input_data + b * in_feat;
+    int8_t* batch_output = output_data + b * out_feat;
 
+    ET_CHECK_MSG(
+        batch_input != nullptr && weight_data != nullptr,
+        "Null input pointers");
+    ET_CHECK_MSG(in_feat > 0 && out_feat > 0, "Invalid dimensions");
+
+    if (is_per_channel) {
+      cmsis_nn_per_channel_quant_params per_channel_quant_params;
+      per_channel_quant_params.multiplier =
+          const_cast<int32_t*>(weight_mult_data);
+      per_channel_quant_params.shift = const_cast<int32_t*>(weight_shift_data);
+
+      status = arm_fully_connected_per_channel_s8(
+          &ctx,
+          &fc_params,
+          &per_channel_quant_params,
+          &input_dims,
+          batch_input,
+          &filter_dims,
+          weight_data,
+          &bias_dims,
+          bias_data,
+          &output_dims,
+          batch_output);
+    } else {
+      fc_params.filter_offset = -weight_zp_data[0];
+      cmsis_nn_per_tensor_quant_params per_tensor_quant_params;
+      per_tensor_quant_params.multiplier = weight_mult_data[0];
+      per_tensor_quant_params.shift = weight_shift_data[0];
+
+      status = arm_fully_connected_s8(
+          &ctx,
+          &fc_params,
+          &per_tensor_quant_params,
+          &input_dims,
+          batch_input,
+          &filter_dims,
+          weight_data,
+          &bias_dims,
+          bias_data,
+          &output_dims,
+          batch_output);
+    }
+
+    if (status != ARM_CMSIS_NN_SUCCESS) {
+      ET_LOG(
+          Error,
+          "quantized_linear_out: CMSIS-NN failed with status [%d]",
+          status);
+      context.fail(Error::Internal);
+      return out;
+    }
+  }
   return out;
 }
 
+// Functional variant (stub, not used at runtime)
+Tensor quantized_linear(
+    KernelRuntimeContext& context,
+    const Tensor& input,
+    const Scalar& input_zero_point,
+    const Scalar& input_multiplier,
+    const Scalar& input_shift,
+    const Tensor& weights,
+    const Tensor& weight_zero_point,
+    const Tensor& weight_multiplier,
+    const Tensor& weight_shift,
+    const torch::executor::optional<Tensor>& bias,
+    const Tensor& bias_multiplier,
+    const Tensor& bias_shift,
+    const Tensor& scratch_buffer,
+    const Scalar& output_zero_point,
+    const Scalar& in_features,
+    const Scalar& out_features) {
+  ET_LOG(Info, "quantized_linear: called");
+  assert(false);
+  return const_cast<Tensor&>(input);
+}
+
 } // namespace native
 } // namespace cortex_m
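To make the restored batch loop concrete, here is a minimal reference sketch, not the CMSIS-NN kernel itself, of what each iteration computes under the usual s8 fully-connected semantics: accumulate (x + input_offset) * w plus bias, rescale, add output_offset, and clamp to int8. The float scale is a deliberate simplification of the fixed-point multiplier/shift pair, and all names here are illustrative:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Reference-only sketch of one pass over a [batch, in_feat] int8 activation
// buffer, using the same row slicing as batch_input/batch_output above.
void quantized_fc_reference(
    const int8_t* input,   // [batch, in_feat], row-major
    const int8_t* weights, // [out_feat, in_feat], row-major
    const int32_t* bias,   // [out_feat], may be nullptr
    int8_t* output,        // [batch, out_feat]
    int32_t batch, int32_t in_feat, int32_t out_feat,
    int32_t input_offset, int32_t output_offset, float scale) {
  for (int32_t b = 0; b < batch; ++b) {
    const int8_t* x = input + b * in_feat;  // cf. batch_input in the diff
    int8_t* y = output + b * out_feat;      // cf. batch_output in the diff
    for (int32_t o = 0; o < out_feat; ++o) {
      int32_t acc = (bias != nullptr) ? bias[o] : 0;
      for (int32_t i = 0; i < in_feat; ++i) {
        acc += (static_cast<int32_t>(x[i]) + input_offset) *
            static_cast<int32_t>(weights[o * in_feat + i]);
      }
      const int32_t q =
          static_cast<int32_t>(std::lround(acc * scale)) + output_offset;
      y[o] = static_cast<int8_t>(std::clamp<int32_t>(q, -128, 127));
    }
  }
}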

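The sign convention visible above (fc_params.input_offset = -input_zp and fc_params.filter_offset = -weight_zp_data[0]) follows from affine quantization, real = scale * (q - zp): since the kernel adds its offset to each quantized value, passing the negated zero point recenters the data. A one-liner to make that explicit (illustrative, not part of the backend):

// real = scale * (q - zp)  ==  scale * (q + offset)  with  offset = -zp
inline float dequantize_s8(int8_t q, float scale, int32_t zp) {
  return scale * static_cast<float>(static_cast<int32_t>(q) - zp);
}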