/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 * Copyright 2025 Arm Limited and/or its affiliates.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "cmsis_scratch_buffer_context.h"
#include "cortex_m_ops_common.h"

#include <cstdint>
#include <limits>

1212extern " C" {
// NOTE(review): this region was elided in the source snapshot (diff hunk
// header). Reconstructed from surrounding code: the extern "C" block must be
// closed, the cortex_m::native namespaces opened (they are closed at the end
// of the file), and the hunk context line carried the alias below. The CMSIS-NN
// header name is an assumption — confirm against the original file.
#include "arm_nnfunctions.h"
} // extern "C"

namespace cortex_m {
namespace native {

using KernelRuntimeContext = torch::executor::KernelRuntimeContext;
2020Tensor& quantized_linear_out (
2121 KernelRuntimeContext& context,
2222 const Tensor& input,
23+ const Scalar& input_zero_point,
24+ const Scalar& input_multiplier,
25+ const Scalar& input_shift,
2326 const Tensor& weights,
27+ const Tensor& weight_zero_point,
28+ const Tensor& weight_multiplier,
29+ const Tensor& weight_shift,
2430 const torch::executor::optional<Tensor>& bias,
25- const torch::executor::optional<Tensor>& kernel_sum,
26- const Scalar& input_offset,
27- const Scalar& filter_offset,
28- const Scalar& output_offset,
29- const IntArrayRef requantize_multipliers,
30- const IntArrayRef requantize_shifts,
31- const Scalar& activation_max,
32- const Scalar& activation_min,
31+ const Tensor& bias_multiplier,
32+ const Tensor& bias_shift,
33+ const Tensor& scratch_buffer,
34+ const Scalar& output_zero_point,
35+ const Scalar& in_features,
36+ const Scalar& out_features,
3337 Tensor& out) {
3438 ET_LOG (Info, " quantized_linear_out: called" );
39+ validate_cmsis_nn_tensor_requirements (input, weights, out);
40+
41+ ET_CHECK_MSG (
42+ scratch_buffer.scalar_type () == ScalarType::Char,
43+ " Scratch buffer must be int8" );
44+
45+ const int32_t batch_size = input.size (0 );
46+ const int32_t in_feat = static_cast <int32_t >(in_features.to <int64_t >());
47+ const int32_t out_feat = static_cast <int32_t >(out_features.to <int64_t >());
48+ const int32_t input_zp = static_cast <int32_t >(input_zero_point.to <int64_t >());
49+ const int32_t output_zp =
50+ static_cast <int32_t >(output_zero_point.to <int64_t >());
51+ const bool is_per_channel = (weight_zero_point.numel () > 1 );
3552
3653 const int8_t * input_data = input.const_data_ptr <int8_t >();
3754 const int8_t * weight_data = weights.const_data_ptr <int8_t >();
3855 const int32_t * bias_data =
3956 bias.has_value () ? bias.value ().const_data_ptr <int32_t >() : nullptr ;
40- int32_t * kernel_sum_data =
41- kernel_sum.has_value () ? kernel_sum.value ().data_ptr <int32_t >() : nullptr ;
4257 int8_t * output_data = out.mutable_data_ptr <int8_t >();
58+ const int32_t * weight_zp_data = weight_zero_point.const_data_ptr <int32_t >();
59+ const int32_t * weight_mult_data = weight_multiplier.const_data_ptr <int32_t >();
60+ const int32_t * weight_shift_data = weight_shift.const_data_ptr <int32_t >();
61+
62+ if (!validate_per_channel_quant_params (
63+ weight_mult_data, weight_shift_data, out_feat)) {
64+ context.fail (Error::InvalidArgument);
65+ return out;
66+ }
67+
68+ // Initialize scratch buffer context (validates early)
69+ CMSISScratchBufferContext scratch_ctx (
70+ const_cast <Tensor&>(scratch_buffer), weights, weight_zero_point, bias);
4371
44- cmsis_nn_context ctx;
45- ctx.size = 0 ; // Not used in CMSIS-NN
46- ctx.buf = kernel_sum_data;
72+ scratch_ctx.compute_kernel_sums_if_needed ();
73+ cmsis_nn_context ctx = scratch_ctx.get_cmsis_ctx ();
4774
4875 // Setup CMSIS-NN parameters
4976 cmsis_nn_fc_params fc_params;
50- fc_params.input_offset = static_cast <int32_t >(input_offset.to <int64_t >());
51- fc_params.filter_offset = static_cast <int32_t >(filter_offset.to <int64_t >());
52- fc_params.output_offset = static_cast <int32_t >(output_offset.to <int64_t >());
53- fc_params.activation .min = static_cast <int32_t >(activation_min.to <int64_t >());
54- fc_params.activation .max = static_cast <int32_t >(activation_max.to <int64_t >());
55-
56- cmsis_nn_per_tensor_quant_params per_tensor_quant_params;
57- per_tensor_quant_params.multiplier =
58- static_cast <int32_t >(requantize_multipliers.at (0 ));
59- per_tensor_quant_params.shift = static_cast <int32_t >(requantize_shifts.at (0 ));
60-
61- auto in_feat = input.size (input.dim () - 1 );
62- auto out_feat = out.size (out.dim () - 1 );
63- auto batches = 1 ;
64- for (size_t i = 0 ; i < input.dim () - 1 ; i++) {
65- batches *= input.size (i);
66- }
67- ET_LOG (
68- Info,
69- " in features: %d, out_features: %d, batches: %d, kernel_sum_size: %d" ,
70- in_feat,
71- out_feat,
72- batches,
73- kernel_sum.has_value () ? kernel_sum.value ().numel () : 0 );
74- ET_LOG (
75- Info,
76- " kernel_sum[0]: %d, kernel_sum[1]: %d" ,
77- kernel_sum_data != nullptr ? kernel_sum_data[0 ] : -1 ,
78- kernel_sum_data != nullptr ? kernel_sum_data[1 ] : -1 );
79- cmsis_nn_dims input_dims = {batches, 1 , 1 , in_feat};
77+ fc_params.input_offset = -input_zp;
78+ fc_params.output_offset = output_zp;
79+ fc_params.activation .min = std::numeric_limits<int8_t >::min ();
80+ fc_params.activation .max = std::numeric_limits<int8_t >::max ();
81+
82+ cmsis_nn_dims input_dims = {1 , 1 , 1 , in_feat};
8083 cmsis_nn_dims filter_dims = {in_feat, 1 , 1 , out_feat};
8184 cmsis_nn_dims bias_dims = {1 , 1 , 1 , out_feat};
82- cmsis_nn_dims output_dims = {batches, 1 , 1 , out_feat};
83-
84- arm_cmsis_nn_status status = arm_fully_connected_s8 (
85- &ctx,
86- &fc_params,
87- &per_tensor_quant_params,
88- &input_dims,
89- input_data,
90- &filter_dims,
91- weight_data,
92- &bias_dims,
93- bias_data,
94- &output_dims,
95- output_data);
96-
97- if (status != ARM_CMSIS_NN_SUCCESS) {
98- ET_LOG (
99- Error,
100- " quantized_linear_out: CMSIS-NN failed with status [%d]" ,
101- status);
102- context.fail (Error::Internal);
103- return out;
104- }
85+ cmsis_nn_dims output_dims = {1 , 1 , 1 , out_feat};
86+
87+ arm_cmsis_nn_status status;
88+ for (int32_t b = 0 ; b < batch_size; b++) {
89+ const int8_t * batch_input = input_data + b * in_feat;
90+ int8_t * batch_output = output_data + b * out_feat;
10591
92+ ET_CHECK_MSG (
93+ batch_input != nullptr && weight_data != nullptr ,
94+ " Null input pointers" );
95+ ET_CHECK_MSG (in_feat > 0 && out_feat > 0 , " Invalid dimensions" );
96+
97+ if (is_per_channel) {
98+ cmsis_nn_per_channel_quant_params per_channel_quant_params;
99+ per_channel_quant_params.multiplier =
100+ const_cast <int32_t *>(weight_mult_data);
101+ per_channel_quant_params.shift = const_cast <int32_t *>(weight_shift_data);
102+
103+ status = arm_fully_connected_per_channel_s8 (
104+ &ctx,
105+ &fc_params,
106+ &per_channel_quant_params,
107+ &input_dims,
108+ batch_input,
109+ &filter_dims,
110+ weight_data,
111+ &bias_dims,
112+ bias_data,
113+ &output_dims,
114+ batch_output);
115+ } else {
116+ fc_params.filter_offset = -weight_zp_data[0 ];
117+ cmsis_nn_per_tensor_quant_params per_tensor_quant_params;
118+ per_tensor_quant_params.multiplier = weight_mult_data[0 ];
119+ per_tensor_quant_params.shift = weight_shift_data[0 ];
120+
121+ status = arm_fully_connected_s8 (
122+ &ctx,
123+ &fc_params,
124+ &per_tensor_quant_params,
125+ &input_dims,
126+ batch_input,
127+ &filter_dims,
128+ weight_data,
129+ &bias_dims,
130+ bias_data,
131+ &output_dims,
132+ batch_output);
133+ }
134+
135+ if (status != ARM_CMSIS_NN_SUCCESS) {
136+ ET_LOG (
137+ Error,
138+ " quantized_linear_out: CMSIS-NN failed with status [%d]" ,
139+ status);
140+ context.fail (Error::Internal);
141+ return out;
142+ }
143+ }
106144 return out;
107145}
108146
147+ // Functional variant (stub, not used at runtime)
148+ Tensor quantized_linear (
149+ KernelRuntimeContext& context,
150+ const Tensor& input,
151+ const Scalar& input_zero_point,
152+ const Scalar& input_multiplier,
153+ const Scalar& input_shift,
154+ const Tensor& weights,
155+ const Tensor& weight_zero_point,
156+ const Tensor& weight_multiplier,
157+ const Tensor& weight_shift,
158+ const torch::executor::optional<Tensor>& bias,
159+ const Tensor& bias_multiplier,
160+ const Tensor& bias_shift,
161+ const Tensor& scratch_buffer,
162+ const Scalar& output_zero_point,
163+ const Scalar& in_features,
164+ const Scalar& out_features) {
165+ ET_LOG (Info, " quantized_linear: called" );
166+ assert (false );
167+ return const_cast <Tensor&>(input);
168+ }
169+
109170} // namespace native
110171} // namespace cortex_m