|  | 
|  | 1 | +/* | 
|  | 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates. | 
|  | 3 | + * All rights reserved. | 
|  | 4 | + * | 
|  | 5 | + * This source code is licensed under the BSD-style license found in the | 
|  | 6 | + * LICENSE file in the root directory of this source tree. | 
|  | 7 | + */ | 
|  | 8 | + | 
#include <executorch/runtime/kernel/kernel_includes.h>
#include <algorithm>
#include <cinttypes>
#include <cmath>
#include <limits>
|  | 13 | + | 
|  | 14 | +// Check for Helium/MVE support | 
|  | 15 | +#if defined(__ARM_FEATURE_MVE) && (__ARM_FEATURE_MVE & 1) | 
|  | 16 | +#include <arm_mve.h> | 
|  | 17 | +#define HAS_HELIUM_SIMD 1 | 
|  | 18 | +#endif | 
|  | 19 | + | 
|  | 20 | +namespace cortex_m { | 
|  | 21 | +namespace native { | 
|  | 22 | + | 
|  | 23 | +using Tensor = executorch::aten::Tensor; | 
|  | 24 | +using ScalarType = executorch::aten::ScalarType; | 
|  | 25 | +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; | 
|  | 26 | + | 
|  | 27 | +namespace { | 
|  | 28 | + | 
|  | 29 | +/** | 
|  | 30 | + * Asserts that the parameters are valid for float to int8 quantization. | 
|  | 31 | + */ | 
|  | 32 | +void check_quantize_args( | 
|  | 33 | +    const Tensor& input, | 
|  | 34 | +    int64_t quant_min, | 
|  | 35 | +    int64_t quant_max, | 
|  | 36 | +    ScalarType dtype, | 
|  | 37 | +    Tensor& out) { | 
|  | 38 | +  // Ensure input is float type | 
|  | 39 | +  ET_CHECK_MSG( | 
|  | 40 | +      input.scalar_type() == ScalarType::Float, | 
|  | 41 | +      "input.scalar_type() %" PRId8 " is not float type", | 
|  | 42 | +      static_cast<int8_t>(input.scalar_type())); | 
|  | 43 | + | 
|  | 44 | +  // Check output dtype is int8 (Char) | 
|  | 45 | +  ET_CHECK_MSG( | 
|  | 46 | +      out.scalar_type() == ScalarType::Char, | 
|  | 47 | +      "out.scalar_type() %" PRId8 " is not int8 (Char)", | 
|  | 48 | +      static_cast<int8_t>(out.scalar_type())); | 
|  | 49 | + | 
|  | 50 | +  // Check dtype is int8 (Char) | 
|  | 51 | +  ET_CHECK_MSG( | 
|  | 52 | +      dtype == ScalarType::Char, | 
|  | 53 | +      "dtype %" PRId8 " is not int8 (Char)", | 
|  | 54 | +      static_cast<int8_t>(dtype)); | 
|  | 55 | + | 
|  | 56 | +  // Validate quant_min and quant_max for int8 | 
|  | 57 | +  int32_t quant_min_lower_bound = std::numeric_limits<int8_t>::min(); | 
|  | 58 | +  int32_t quant_max_upper_bound = std::numeric_limits<int8_t>::max(); | 
|  | 59 | + | 
|  | 60 | +  ET_CHECK_MSG( | 
|  | 61 | +      quant_min >= quant_min_lower_bound, | 
|  | 62 | +      "quant_min out of bound for int8, expected quant_min_lower_bound: %" PRId32 | 
|  | 63 | +      " actual quant_min: %" PRId64, | 
|  | 64 | +      quant_min_lower_bound, | 
|  | 65 | +      quant_min); | 
|  | 66 | + | 
|  | 67 | +  ET_CHECK_MSG( | 
|  | 68 | +      quant_max <= quant_max_upper_bound, | 
|  | 69 | +      "quant_max out of bound for int8, expected quant_max_upper_bound: %" PRId32 | 
|  | 70 | +      " actual quant_max: %" PRId64, | 
|  | 71 | +      quant_max_upper_bound, | 
|  | 72 | +      quant_max); | 
|  | 73 | +} | 
|  | 74 | + | 
/**
 * Quantizes a single value: q = clamp(zero_point + round(value / scale)).
 *
 * @param inv_scale  Reciprocal of the quantization scale (1 / scale).
 * @param zero_point Quantization zero point, added after rounding.
 * @param value      The floating-point value to quantize.
 * @param quant_min  Inclusive lower clamp bound.
 * @param quant_max  Inclusive upper clamp bound.
 * @return The quantized value cast to T.
 */
template <typename T, typename K>
T quantize_val(
    float inv_scale,
    int32_t zero_point,
    K value,
    int64_t quant_min,
    int64_t quant_max) {
  // std::nearbyint honors the current rounding mode (round-to-nearest-even
  // by default), matching the reference quantization behavior.
  const int32_t rounded =
      static_cast<int32_t>(std::nearbyint(inv_scale * value));
  const int32_t clamped = std::clamp(
      zero_point + rounded,
      static_cast<int32_t>(quant_min),
      static_cast<int32_t>(quant_max));
  return static_cast<T>(clamped);
}
|  | 91 | + | 
|  | 92 | +} // namespace | 
|  | 93 | + | 
|  | 94 | +Tensor& quantize_per_tensor_out( | 
|  | 95 | +    KernelRuntimeContext& context, | 
|  | 96 | +    const Tensor& input, | 
|  | 97 | +    double scale, | 
|  | 98 | +    int64_t zero_point, | 
|  | 99 | +    int64_t quant_min, | 
|  | 100 | +    int64_t quant_max, | 
|  | 101 | +    ScalarType dtype, | 
|  | 102 | +    Tensor& out) { | 
|  | 103 | +  // Ignore context for now | 
|  | 104 | +  (void)context; | 
|  | 105 | + | 
|  | 106 | +  // Resize output tensor to match input dimensions | 
|  | 107 | +  torch::executor::Error err = resize_tensor(out, input.sizes()); | 
|  | 108 | +  ET_CHECK_MSG( | 
|  | 109 | +      err == torch::executor::Error::Ok, | 
|  | 110 | +      "Failed to resize out Tensor in quantize_per_tensor_out"); | 
|  | 111 | + | 
|  | 112 | +  // Validate input parameters | 
|  | 113 | +  check_quantize_args(input, quant_min, quant_max, dtype, out); | 
|  | 114 | + | 
|  | 115 | +  // Pre-compute inverse scale for better performance | 
|  | 116 | +  float inv_scale = 1.0f / static_cast<float>(scale); | 
|  | 117 | +  int32_t zp = static_cast<int32_t>(zero_point); | 
|  | 118 | +  int32_t qmin = static_cast<int32_t>(quant_min); | 
|  | 119 | +  int32_t qmax = static_cast<int32_t>(quant_max); | 
|  | 120 | + | 
|  | 121 | +  // Get pointers to input and output data | 
|  | 122 | +  const float* input_data = input.const_data_ptr<float>(); | 
|  | 123 | +  int8_t* out_data = out.mutable_data_ptr<int8_t>(); | 
|  | 124 | +  const size_t numel = input.numel(); | 
|  | 125 | + | 
|  | 126 | +#if defined(HAS_HELIUM_SIMD) | 
|  | 127 | +// Helium MVE implementation for float32 to int8 quantization | 
|  | 128 | +#Error "Implement MVE version!" | 
|  | 129 | +#else | 
|  | 130 | +  // Scalar implementation for float32 to int8 quantization | 
|  | 131 | +  for (size_t i = 0; i < numel; i++) { | 
|  | 132 | +    out_data[i] = | 
|  | 133 | +        quantize_val<int8_t, float>(inv_scale, zp, input_data[i], qmin, qmax); | 
|  | 134 | +  } | 
|  | 135 | +#endif | 
|  | 136 | + | 
|  | 137 | +  return out; | 
|  | 138 | +} | 
|  | 139 | + | 
|  | 140 | +Tensor& quantize_per_tensor_out( | 
|  | 141 | +    const Tensor& input, | 
|  | 142 | +    double scale, | 
|  | 143 | +    int64_t zero_point, | 
|  | 144 | +    int64_t quant_min, | 
|  | 145 | +    int64_t quant_max, | 
|  | 146 | +    ScalarType dtype, | 
|  | 147 | +    Tensor& out) { | 
|  | 148 | +  KernelRuntimeContext context; | 
|  | 149 | +  return quantize_per_tensor_out( | 
|  | 150 | +      context, input, scale, zero_point, quant_min, quant_max, dtype, out); | 
|  | 151 | +} | 
|  | 152 | + | 
|  | 153 | +} // namespace native | 
|  | 154 | +} // namespace cortex_m | 
0 commit comments