2 changes: 1 addition & 1 deletion backends/cortex_m/ops/cmsis_scratch_buffer_context.h
@@ -50,7 +50,7 @@ class CMSISScratchBufferContext final {
Tensor& scratch_buffer,
const Tensor& weights,
const Tensor& weight_zero_point,
const ::std::optional<Tensor>& bias)
const torch::executor::optional<Tensor>& bias)
: scratch_ptr_(scratch_buffer.mutable_data_ptr<int8_t>()),
total_size_(scratch_buffer.size(0)),
base_ptr_(reinterpret_cast<uint8_t*>(scratch_ptr_)),
1 change: 1 addition & 0 deletions backends/cortex_m/ops/cortex_m_ops_common.h
@@ -22,6 +22,7 @@ using Tensor = torch::executor::Tensor;
using ScalarType = executorch::aten::ScalarType;
using Scalar = torch::executor::Scalar;
using Error = executorch::runtime::Error;
using IntArrayRef = executorch::aten::ArrayRef<int64_t>;

// From arm_nn_math_types.h
#define ARM_NN_Q31_MAX ((int32_t)(0x7FFFFFFFL))
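For reference, the IntArrayRef alias added here is the type through which the rewritten quantized_linear_out below receives its requantization multipliers and shifts. A minimal usage sketch, assuming ExecuTorch ArrayRef's usual pointer-plus-length constructor; the example function and values are illustrative, not part of this change:

#include <cstdint>

#include "cortex_m_ops_common.h"  // provides the IntArrayRef alias added above

// Illustrative only: build the per-tensor requantization views the way a
// caller of quantized_linear_out might. ArrayRef does not own memory, so the
// backing arrays must outlive the views.
void example_requantize_params() {
  static const int64_t multipliers[] = {1518500250};  // example fixed-point multiplier
  static const int64_t shifts[] = {-8};               // example shift

  const IntArrayRef requantize_multipliers(multipliers, 1);
  const IntArrayRef requantize_shifts(shifts, 1);

  // Read back the per-tensor parameters exactly as the kernel does, via .at(0).
  const int32_t mult = static_cast<int32_t>(requantize_multipliers.at(0));
  const int32_t shift = static_cast<int32_t>(requantize_shifts.at(0));
  (void)mult;
  (void)shift;
}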
197 changes: 68 additions & 129 deletions backends/cortex_m/ops/op_quantized_linear.cpp
@@ -1,12 +1,12 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
* Copyright 2025 Arm Limited and/or its affiliates.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "cmsis_scratch_buffer_context.h"
#include "cortex_m_ops_common.h"

extern "C" {
@@ -20,151 +20,90 @@ using KernelRuntimeContext = torch::executor::KernelRuntimeContext;
Tensor& quantized_linear_out(
KernelRuntimeContext& context,
const Tensor& input,
const Scalar& input_zero_point,
const Scalar& input_multiplier,
const Scalar& input_shift,
const Tensor& weights,
const Tensor& weight_zero_point,
const Tensor& weight_multiplier,
const Tensor& weight_shift,
const ::std::optional<Tensor>& bias,
const Tensor& bias_multiplier,
const Tensor& bias_shift,
const Tensor& scratch_buffer,
const Scalar& output_zero_point,
const Scalar& in_features,
const Scalar& out_features,
const torch::executor::optional<Tensor>& bias,
const torch::executor::optional<Tensor>& kernel_sum,
const Scalar& input_offset,
const Scalar& filter_offset,
const Scalar& output_offset,
const IntArrayRef requantize_multipliers,
const IntArrayRef requantize_shifts,
const Scalar& activation_max,
const Scalar& activation_min,
Tensor& out) {
ET_LOG(Info, "quantized_linear_out: called");
validate_cmsis_nn_tensor_requirements(input, weights, out);

ET_CHECK_MSG(
scratch_buffer.scalar_type() == ScalarType::Char,
"Scratch buffer must be int8");

const int32_t batch_size = input.size(0);
const int32_t in_feat = static_cast<int32_t>(in_features.to<int64_t>());
const int32_t out_feat = static_cast<int32_t>(out_features.to<int64_t>());
const int32_t input_zp = static_cast<int32_t>(input_zero_point.to<int64_t>());
const int32_t output_zp =
static_cast<int32_t>(output_zero_point.to<int64_t>());
const bool is_per_channel = (weight_zero_point.numel() > 1);

const int8_t* input_data = input.const_data_ptr<int8_t>();
const int8_t* weight_data = weights.const_data_ptr<int8_t>();
const int32_t* bias_data =
bias.has_value() ? bias.value().const_data_ptr<int32_t>() : nullptr;
int32_t* kernel_sum_data =
kernel_sum.has_value() ? kernel_sum.value().data_ptr<int32_t>() : nullptr;
int8_t* output_data = out.mutable_data_ptr<int8_t>();
const int32_t* weight_zp_data = weight_zero_point.const_data_ptr<int32_t>();
const int32_t* weight_mult_data = weight_multiplier.const_data_ptr<int32_t>();
const int32_t* weight_shift_data = weight_shift.const_data_ptr<int32_t>();

if (!validate_per_channel_quant_params(
weight_mult_data, weight_shift_data, out_feat)) {
context.fail(Error::InvalidArgument);
return out;
}

// Initialize scratch buffer context (validates early)
CMSISScratchBufferContext scratch_ctx(
const_cast<Tensor&>(scratch_buffer), weights, weight_zero_point, bias);

scratch_ctx.compute_kernel_sums_if_needed();
cmsis_nn_context ctx = scratch_ctx.get_cmsis_ctx();
cmsis_nn_context ctx;
ctx.size = 0; // Not used in CMSIS-NN
ctx.buf = kernel_sum_data;

// Setup CMSIS-NN parameters
cmsis_nn_fc_params fc_params;
fc_params.input_offset = -input_zp;
fc_params.output_offset = output_zp;
fc_params.activation.min = std::numeric_limits<int8_t>::min();
fc_params.activation.max = std::numeric_limits<int8_t>::max();

cmsis_nn_dims input_dims = {1, 1, 1, in_feat};
fc_params.input_offset = static_cast<int32_t>(input_offset.to<int64_t>());
fc_params.filter_offset = static_cast<int32_t>(filter_offset.to<int64_t>());
fc_params.output_offset = static_cast<int32_t>(output_offset.to<int64_t>());
fc_params.activation.min = static_cast<int32_t>(activation_min.to<int64_t>());
fc_params.activation.max = static_cast<int32_t>(activation_max.to<int64_t>());

cmsis_nn_per_tensor_quant_params per_tensor_quant_params;
per_tensor_quant_params.multiplier =
static_cast<int32_t>(requantize_multipliers.at(0));
per_tensor_quant_params.shift = static_cast<int32_t>(requantize_shifts.at(0));

auto in_feat = input.size(input.dim() - 1);
auto out_feat = out.size(out.dim() - 1);
auto batches = 1;
for (size_t i = 0; i < input.dim() - 1; i++) {
batches *= input.size(i);
}
ET_LOG(
Info,
"in features: %d, out_features: %d, batches: %d, kernel_sum_size: %d",
in_feat,
out_feat,
batches,
kernel_sum.has_value() ? kernel_sum.value().numel() : 0);
ET_LOG(
Info,
"kernel_sum[0]: %d, kernel_sum[1]: %d",
kernel_sum_data != nullptr ? kernel_sum_data[0] : -1,
kernel_sum_data != nullptr ? kernel_sum_data[1] : -1);
cmsis_nn_dims input_dims = {batches, 1, 1, in_feat};
cmsis_nn_dims filter_dims = {in_feat, 1, 1, out_feat};
cmsis_nn_dims bias_dims = {1, 1, 1, out_feat};
cmsis_nn_dims output_dims = {1, 1, 1, out_feat};

arm_cmsis_nn_status status;
for (int32_t b = 0; b < batch_size; b++) {
const int8_t* batch_input = input_data + b * in_feat;
int8_t* batch_output = output_data + b * out_feat;

ET_CHECK_MSG(
batch_input != nullptr && weight_data != nullptr,
"Null input pointers");
ET_CHECK_MSG(in_feat > 0 && out_feat > 0, "Invalid dimensions");

if (is_per_channel) {
cmsis_nn_per_channel_quant_params per_channel_quant_params;
per_channel_quant_params.multiplier =
const_cast<int32_t*>(weight_mult_data);
per_channel_quant_params.shift = const_cast<int32_t*>(weight_shift_data);

status = arm_fully_connected_per_channel_s8(
&ctx,
&fc_params,
&per_channel_quant_params,
&input_dims,
batch_input,
&filter_dims,
weight_data,
&bias_dims,
bias_data,
&output_dims,
batch_output);
} else {
fc_params.filter_offset = -weight_zp_data[0];
cmsis_nn_per_tensor_quant_params per_tensor_quant_params;
per_tensor_quant_params.multiplier = weight_mult_data[0];
per_tensor_quant_params.shift = weight_shift_data[0];

status = arm_fully_connected_s8(
&ctx,
&fc_params,
&per_tensor_quant_params,
&input_dims,
batch_input,
&filter_dims,
weight_data,
&bias_dims,
bias_data,
&output_dims,
batch_output);
}

if (status != ARM_CMSIS_NN_SUCCESS) {
ET_LOG(
Error,
"quantized_linear_out: CMSIS-NN failed with status [%d]",
status);
context.fail(Error::Internal);
return out;
}
cmsis_nn_dims output_dims = {batches, 1, 1, out_feat};

arm_cmsis_nn_status status = arm_fully_connected_s8(
&ctx,
&fc_params,
&per_tensor_quant_params,
&input_dims,
input_data,
&filter_dims,
weight_data,
&bias_dims,
bias_data,
&output_dims,
output_data);

if (status != ARM_CMSIS_NN_SUCCESS) {
ET_LOG(
Error,
"quantized_linear_out: CMSIS-NN failed with status [%d]",
status);
context.fail(Error::Internal);
return out;
}
return out;
}

// Functional variant (stub, not used at runtime)
Tensor quantized_linear(
KernelRuntimeContext& context,
const Tensor& input,
const Scalar& input_zero_point,
const Scalar& input_multiplier,
const Scalar& input_shift,
const Tensor& weights,
const Tensor& weight_zero_point,
const Tensor& weight_multiplier,
const Tensor& weight_shift,
const ::std::optional<Tensor>& bias,
const Tensor& bias_multiplier,
const Tensor& bias_shift,
const Tensor& scratch_buffer,
const Scalar& output_zero_point,
const Scalar& in_features,
const Scalar& out_features) {
ET_LOG(Info, "quantized_linear: called");
assert(false);
return const_cast<Tensor&>(input);
return out;
}

} // namespace native
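For context on the new per-tensor parameters (requantize_multipliers/requantize_shifts, output_offset, activation_min/max): each int32 accumulator produced by the fully-connected kernel is scaled back to int8 with a fixed-point multiply and shift before the output offset and activation clamp are applied. The sketch below is a plain reference version of that step, assuming the gemmlowp/TFLite rounding-doubling-high-multiply convention; the helper names and the shift-sign convention (shift <= 0 meaning a right shift) are assumptions to verify against arm_nn_requantize in the CMSIS-NN release being targeted, not code from this PR.

#include <algorithm>
#include <cstdint>
#include <limits>

// Doubling high multiply with rounding, as in gemmlowp/TFLite reference code.
inline int32_t rounding_doubling_high_mul(int32_t a, int32_t b) {
  // Saturate the single overflowing case (INT32_MIN * INT32_MIN).
  if (a == std::numeric_limits<int32_t>::min() &&
      b == std::numeric_limits<int32_t>::min()) {
    return std::numeric_limits<int32_t>::max();
  }
  const int64_t ab = static_cast<int64_t>(a) * static_cast<int64_t>(b);
  const int64_t nudge = ab >= 0 ? (1LL << 30) : (1 - (1LL << 30));
  return static_cast<int32_t>((ab + nudge) / (1LL << 31));
}

// Reference requantization of one accumulator; rounding at the negative
// boundary may differ slightly from CMSIS-NN's arm_nn_divide_by_power_of_two.
inline int8_t requantize_to_int8(
    int32_t acc,            // dot product plus precomputed kernel-sum/bias term
    int32_t multiplier,     // requantize_multipliers.at(0)
    int32_t shift,          // requantize_shifts.at(0); assumed <= 0 (right shift)
    int32_t output_offset,  // fc_params.output_offset
    int32_t act_min,        // fc_params.activation.min
    int32_t act_max) {      // fc_params.activation.max
  int32_t v = rounding_doubling_high_mul(acc, multiplier);
  const int32_t right_shift = -shift;
  if (right_shift > 0) {
    v = (v + (1 << (right_shift - 1))) >> right_shift;  // rounding right shift
  }
  v += output_offset;
  v = std::min(std::max(v, act_min), act_max);  // activation clamp
  return static_cast<int8_t>(v);
}

This per-element step runs inside arm_fully_connected_s8 after the matrix multiply; the sketch is only meant to make the meaning of the new operator arguments concrete.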