/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/runtime/kernel/kernel_includes.h>

#include <cinttypes>
#include <limits> // for std::numeric_limits

// Check for Helium/MVE support
#if defined(__ARM_FEATURE_MVE) && (__ARM_FEATURE_MVE & 1)
#include <arm_mve.h>
#define HAS_HELIUM_SIMD 1
#endif
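// Per the Arm C Language Extensions, __ARM_FEATURE_MVE is a bitmask:
// bit 0 signals integer MVE and bit 1 the floating-point extension, so
// the check above requires only the integer subset.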

namespace cortex_m {
namespace native {

using Tensor = executorch::aten::Tensor;
using ScalarType = executorch::aten::ScalarType;
using KernelRuntimeContext = torch::executor::KernelRuntimeContext;

namespace {

/**
 * Asserts that the parameters are valid for int8-to-float dequantization.
 */
void check_dequantize_args(
    const Tensor& input,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    Tensor& out) {
  // Ensure input is char (int8) type
  ET_CHECK_MSG(
      input.scalar_type() == ScalarType::Char,
      "input.scalar_type() %" PRId8 " is not char type",
      static_cast<int8_t>(input.scalar_type()));

  // Check output dtype is float
  ET_CHECK_MSG(
      out.scalar_type() == ScalarType::Float,
      "out.scalar_type() %" PRId8 " is not float",
      static_cast<int8_t>(out.scalar_type()));

  // Check dtype is int8 (Char)
  ET_CHECK_MSG(
      dtype == ScalarType::Char,
      "dtype %" PRId8 " is not int8 (Char)",
      static_cast<int8_t>(dtype));

  // Validate quant_min and quant_max for int8
  int32_t quant_min_lower_bound = std::numeric_limits<int8_t>::min();
  int32_t quant_max_upper_bound = std::numeric_limits<int8_t>::max();

  ET_CHECK_MSG(
      quant_min >= quant_min_lower_bound,
      "quant_min out of bound for int8, expected quant_min_lower_bound: %" PRId32
      " actual quant_min: %" PRId64,
      quant_min_lower_bound,
      quant_min);

  ET_CHECK_MSG(
      quant_max <= quant_max_upper_bound,
      "quant_max out of bound for int8, expected quant_max_upper_bound: %" PRId32
      " actual quant_max: %" PRId64,
      quant_max_upper_bound,
      quant_max);
}
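// Example: for the standard int8 scheme, quant_min = -128 and
// quant_max = 127; a call with quant_min = -256 would trip the first
// bound check above.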

/**
 * Scalar implementation of dequantization for a single value.
 * quant_min/quant_max are accepted for interface parity but unused here:
 * the stored int8 value is already within the quantized range.
 */
template <typename K, typename T>
T dequantize_val(
    float scale,
    int32_t zero_point,
    K value,
    int64_t quant_min,
    int64_t quant_max) {
  (void)quant_min;
  (void)quant_max;
  return static_cast<T>((static_cast<int32_t>(value) - zero_point) * scale);
}
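// Worked example: with scale = 0.5f, zero_point = 10, and a stored int8
// value of 20, the dequantized result is (20 - 10) * 0.5f = 5.0f.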

} // namespace

Tensor& dequantize_per_tensor_out(
    KernelRuntimeContext& context,
    const Tensor& input,
    double scale,
    int64_t zero_point,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    Tensor& out) {
  // Ignore context for now
  (void)context;

  // Resize output tensor to match input dimensions
  torch::executor::Error err = resize_tensor(out, input.sizes());
  ET_CHECK_MSG(
      err == torch::executor::Error::Ok,
      "Failed to resize out Tensor in dequantize_per_tensor_out");

  // Validate input parameters
  check_dequantize_args(input, quant_min, quant_max, dtype, out);

  // Narrow the quantization parameters once, outside the loop
  float scale_f = static_cast<float>(scale);
  int32_t zp = static_cast<int32_t>(zero_point);
  int32_t qmin = static_cast<int32_t>(quant_min);
  int32_t qmax = static_cast<int32_t>(quant_max);

  // Get pointers to input and output data
  const int8_t* input_data = input.const_data_ptr<int8_t>();
  float* out_data = out.mutable_data_ptr<float>();
  const size_t numel = input.numel();

#if defined(HAS_HELIUM_SIMD)
// Helium MVE implementation for int8 to float32 dequantization
#error "Implement MVE version!"
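// A tail-predicated MVE loop for this dequantization could look roughly
// like the sketch below (untested; intrinsic choices are assumptions
// based on <arm_mve.h>, widening four int8 lanes to 32 bits per pass):
//
//   const int8_t* in = input_data;
//   float* out_ptr = out_data;
//   for (int32_t rem = static_cast<int32_t>(numel); rem > 0; rem -= 4) {
//     mve_pred16_t p = vctp32q(static_cast<uint32_t>(rem));
//     int32x4_t q = vldrbq_z_s32(in, p);  // load + sign-extend int8 -> int32
//     q = vsubq_n_s32(q, zp);             // subtract zero point
//     float32x4_t f = vcvtq_f32_s32(q);   // convert to float
//     f = vmulq_n_f32(f, scale_f);        // apply scale
//     vstrwq_p_f32(out_ptr, f, p);        // predicated store handles the tail
//     in += 4;
//     out_ptr += 4;
//   }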
#else
  // Scalar implementation for int8 to float32 dequantization
  for (size_t i = 0; i < numel; i++) {
    out_data[i] =
        dequantize_val<int8_t, float>(scale_f, zp, input_data[i], qmin, qmax);
  }
#endif

  return out;
}

Tensor& dequantize_per_tensor_out(
    const Tensor& input,
    double scale,
    int64_t zero_point,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    Tensor& out) {
  KernelRuntimeContext context;
  return dequantize_per_tensor_out(
      context, input, scale, zero_point, quant_min, quant_max, dtype, out);
}
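
// A minimal invocation sketch, assuming ExecuTorch's TensorFactory test
// helper (runtime/core/exec_aten/testing_util/tensor_factory.h) is
// available; the values are illustrative only:
//
//   torch::executor::testing::TensorFactory<ScalarType::Char> tf_in;
//   torch::executor::testing::TensorFactory<ScalarType::Float> tf_out;
//   Tensor input = tf_in.make({4}, {-128, 0, 10, 127});
//   Tensor out = tf_out.zeros({4});
//   // scale = 0.5, zero_point = 0, quant range = [-128, 127]
//   dequantize_per_tensor_out(
//       input, 0.5, 0, -128, 127, ScalarType::Char, out);
//   // out now holds {-64.0f, 0.0f, 5.0f, 63.5f}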

} // namespace native
} // namespace cortex_m