#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/core/portable_type/tensor.h> // for torch::executor::Tensor
#include <executorch/runtime/core/portable_type/scalar.h> // for torch::executor::Scalar

#include <vector>
#include <algorithm>
#include <cmath>
#include <cstdint>

extern "C" {
#include "Include/arm_nnfunctions.h"
}

namespace cortex_m {
namespace native {

using Tensor = torch::executor::Tensor;
using KernelRuntimeContext = torch::executor::KernelRuntimeContext;

// Determine a quantization scale from fp32 data (dynamic, per-call).
// Maps the observed [min, max] range onto the 256 levels of int8 ([-128, 127]).
float determine_input_scale(const float* data, int size) {
  float min_val = *std::min_element(data, data + size);
  float max_val = *std::max_element(data, data + size);
  float scale = (max_val - min_val) / 255.0f;
  // Guard against constant input (max == min), which would otherwise produce a
  // zero scale and a division by zero in quantize_tensor(). Any non-zero scale
  // works here, since all values then quantize to the same level.
  return scale > 0.0f ? scale : 1.0f;
}
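// Illustrative example (assumed values, not from a real model): input data
// spanning [-1.0f, 1.0f] gives scale = 2.0f / 255.0f ≈ 0.00784f, i.e. one
// int8 step corresponds to roughly 0.0078 in fp32.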
// Quantize fp32 to int8
void quantize_tensor(const float* input, int8_t* output, int size,
                     float scale, int32_t zero_point) {
  for (int i = 0; i < size; i++) {
    int32_t quantized =
        static_cast<int32_t>(std::round(input[i] / scale)) + zero_point;
    // Clamp to the int8_t range [-128, 127] before narrowing.
    output[i] = static_cast<int8_t>(
        std::clamp(quantized, static_cast<int32_t>(-128), static_cast<int32_t>(127)));
  }
}
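// Illustrative example (assumed values): with scale ≈ 0.00784f and
// zero_point = 0, an input of 0.5f quantizes to round(0.5 / 0.00784) = 64,
// while 1.0f quantizes to round(127.5) = 128, which the clamp caps at 127.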
// Dequantize int8 to fp32
void dequantize_tensor(const int8_t* input, float* output, int size,
                       float scale, int32_t zero_point) {
  for (int i = 0; i < size; i++) {
    output[i] = (input[i] - zero_point) * scale;
  }
}
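// Illustrative round trip (assumed values): q = 64 with scale ≈ 0.00784f and
// zero_point = 0 dequantizes to (64 - 0) * 0.00784 ≈ 0.502f, close to the
// original 0.5f; the difference is ordinary quantization error.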

// Converts a floating-point scale to a CMSIS-NN fixed-point multiplier and shift.
// scale: the floating-point scale factor from ExecuTorch quantization
// multiplier: output fixed-point multiplier (Q31 format)
// shift: output left shift amount (positive means left shift)
// diff_min: output minimum difference threshold (usually -128 for int8)
void convert_scale_to_cmsis_params(float scale, int32_t* multiplier, int32_t* shift, int32_t* diff_min) {
  if (scale == 0.0f) {
    *multiplier = 0;
    *shift = 0;
    *diff_min = -128;
    return;
  }
  // Decompose scale into mantissa and exponent: scale = mantissa * 2^exponent
  int exponent;
  float mantissa = std::frexp(scale, &exponent); // mantissa in [0.5, 1)
  // Convert mantissa to Q31 fixed-point format
  int64_t q_fixed = static_cast<int64_t>(std::round(mantissa * (1ll << 31)));
  // Rounding can push the mantissa to exactly 2^31, which does not fit in a
  // signed 32-bit multiplier; halve it and bump the exponent instead.
  if (q_fixed == (1ll << 31)) {
    q_fixed /= 2;
    exponent++;
  }
  *multiplier = static_cast<int32_t>(q_fixed);
  // With a Q31 multiplier, scale = multiplier * 2^(exponent - 31), so the
  // exponent itself is the left shift (negative values mean a right shift).
  *shift = exponent;
  // Typical diff_min for int8 softmax
  *diff_min = -128;
}
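// Worked example (assumed scale, not taken from the code above): scale = 0.05f
// decomposes via frexp into mantissa 0.8 and exponent -4 (0.8 * 2^-4 = 0.05),
// so multiplier = round(0.8 * 2^31) = 1717986918 and shift = -4, i.e. a right
// shift of 4 after the Q31 multiply.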

torch::executor::Tensor& aten_softmax(
    KernelRuntimeContext& context,
    const Tensor& self,
    int64_t dim,
    bool half_to_float,
    Tensor& out) {

  ET_LOG(Info, "CMSIS-NN quantized softmax kernel called");

  // Step 1: Extract fp32 data
  const float* input_data_fp32 = self.data_ptr<float>();
  float* output_data_fp32 = out.data_ptr<float>();

  // Step 2: Get tensor dimensions. This kernel currently assumes a 2D input
  // and applies softmax along the last dimension; `dim` and `half_to_float`
  // are not consulted.
  int rows = self.sizes()[0];
  int cols = self.sizes()[1];

  // Step 3: Quantize input (fp32 -> int8)
  // Determine an appropriate scale dynamically from this input
  float input_scale = determine_input_scale(input_data_fp32, rows * cols);

  // A zero point of 0 is a reasonable default for symmetric int8 quantization,
  // especially when the input data is centered around zero; strongly
  // asymmetric inputs would need a computed zero point.
  int32_t input_zero_point = 0;

  std::vector<int8_t> input_quantized(rows * cols);
  quantize_tensor(input_data_fp32, input_quantized.data(),
                  rows * cols, input_scale, input_zero_point);

  // Step 4: Convert to CMSIS-NN parameters
  int32_t input_mult, input_shift, diff_min;
  convert_scale_to_cmsis_params(input_scale, &input_mult, &input_shift, &diff_min);

  // Step 5: Call CMSIS-NN kernel
  std::vector<int8_t> output_quantized(rows * cols);
  arm_softmax_s8(input_quantized.data(), rows, cols,
                 input_mult, input_shift, diff_min,
                 output_quantized.data());

  // Step 6: Dequantize output (int8 -> fp32). arm_softmax_s8 writes its result
  // using the fixed softmax output quantization of the TFLite int8 kernels
  // (scale 1/256, zero point -128), so dequantize with those parameters rather
  // than the input's.
  dequantize_tensor(output_quantized.data(), output_data_fp32,
                    rows * cols, 1.0f / 256.0f, -128);

  return out;
}

} // namespace native
} // namespace cortex_m
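// Registration note (an assumption, not part of this code): in ExecuTorch a
// kernel like this is typically bound to its operator through the kernel
// registration flow, e.g. a kernels YAML entry along the lines of
//
//   - op: _softmax.out
//     kernels:
//       - arg_meta: null
//         kernel_name: cortex_m::native::aten_softmax
//
// The exact op selector and namespace depend on how this backend's kernel
// library is configured.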