#include <algorithm>
#include <cinttypes>
#include <cmath>
+#if defined(__aarch64__) || defined(__ARM_NEON)
+#include <arm_neon.h>
+#endif

/**
 * For an input tensor, use the scale and zero_point arguments to dequantize it.
@@ -22,6 +25,8 @@ namespace native {
using Tensor = exec_aten::Tensor;
using Scalar = exec_aten::Scalar;
using ScalarType = exec_aten::ScalarType;
+using StridesType = exec_aten::StridesType;
+using SizesType = exec_aten::SizesType;

namespace {

@@ -61,6 +66,171 @@ void check_dequantize_per_tensor_args(
      quant_max);
}

+/**
+ * Useful to apply a function `fn` over a tensor `in` along a given dimension
+ * `dim`. `fn` should have the following signature:
+ *   void fn(size_t inner_size, size_t outer_idx, size_t unpacked_dim_idx)
+ * where `inner_size` is the number of trailing elements in each slice,
+ * `outer_idx` indexes the leading dimensions, and `unpacked_dim_idx` is the
+ * current index along `dim`.
+ */
+template <typename Fn>
+void apply_over_unpacked_dim(
+    const Fn& fn,
+    const exec_aten::Tensor& in,
+    const int64_t& dim) {
+  if (in.numel() == 0) {
+    return;
+  }
+
+  ET_CHECK_MSG(in.dim() > 0, "Input tensor must have at least one dimension");
+  ET_CHECK_VALID_DIM(dim, in.dim());
+
+  const size_t d = ET_NORMALIZE_IX(dim, in.dim());
+  const size_t dim_size = in.size(d);
+  const size_t outer_size = getLeadingDims(in, d);
+  const size_t inner_size = getTrailingDims(in, d);
+  // Loop through all outer dimensions
+  for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
+    // Loop through dim
+    for (size_t unpacked_dim_idx = 0; unpacked_dim_idx < dim_size;
+         ++unpacked_dim_idx) {
+      fn(inner_size, outer_idx, unpacked_dim_idx);
+    }
+  }
+}
+
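+// Dequantizes `numel` int8 values that share a single (scale, zero_point)
+// pair: out[i] = (in[i] - zero_point) * scale. On AArch64/NEON builds the
+// bulk of the work is vectorized 16 elements at a time; the scalar loop at
+// the end handles the remainder (and all elements on other targets).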
+void dequantize_optimized(
+    const int8_t* in,
+    const double scale,
+    const int64_t zero_point,
+    float* out,
+    int64_t quant_min,
+    int64_t quant_max,
+    size_t numel) {
+  ET_CHECK_MSG(
+      zero_point >= quant_min,
+      "zero_point %" PRId64 " must be >= quant_min %" PRId64,
+      zero_point,
+      quant_min);
+  ET_CHECK_MSG(
+      zero_point <= quant_max,
+      "zero_point %" PRId64 " must be <= quant_max %" PRId64,
+      zero_point,
+      quant_max);
+  size_t i = 0;
+#if defined(__aarch64__) || defined(__ARM_NEON)
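+  // Broadcast the zero_point and scale so they can be applied lane-wise.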
+  int8x8_t zero_point_vec = vdup_n_s8(static_cast<int8_t>(zero_point));
+  float32x4_t scales = vdupq_n_f32(static_cast<float>(scale));
+  constexpr int32_t kVecSize = 16;
+  const size_t num_vecs = numel / kVecSize;
+  const int8_t* in_copy = in;
+  float* out_copy = out;
+  for (; i < num_vecs; i++) {
+    int8x16_t in_vec = vld1q_s8(in_copy);
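+    // Low 8 lanes: subtract zero_point while widening to int16, widen again
+    // to int32, then convert to float and multiply by the scale.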
+    int16x8_t sub_vec_0_7 = vsubl_s8(vget_low_s8(in_vec), zero_point_vec);
+    int32x4_t sub_vec_0_3 = vmovl_s16(vget_low_s16(sub_vec_0_7));
+    int32x4_t sub_vec_4_7 = vmovl_s16(vget_high_s16(sub_vec_0_7));
+    float32x4_t out_vec_0_3 = vmulq_f32(vcvtq_f32_s32(sub_vec_0_3), scales);
+    float32x4_t out_vec_4_7 = vmulq_f32(vcvtq_f32_s32(sub_vec_4_7), scales);
+
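+    // High 8 lanes: same widen/convert/scale sequence.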
+    int16x8_t sub_vec_8_15 = vsubl_s8(vget_high_s8(in_vec), zero_point_vec);
+    int32x4_t sub_vec_8_11 = vmovl_s16(vget_low_s16(sub_vec_8_15));
+    int32x4_t sub_vec_12_15 = vmovl_s16(vget_high_s16(sub_vec_8_15));
+    float32x4_t out_vec_8_11 = vmulq_f32(vcvtq_f32_s32(sub_vec_8_11), scales);
+    float32x4_t out_vec_12_15 = vmulq_f32(vcvtq_f32_s32(sub_vec_12_15), scales);
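+    // Store all 16 dequantized float results and advance both pointers.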
+    vst1q_f32(out_copy + 0, out_vec_0_3);
+    vst1q_f32(out_copy + 4, out_vec_4_7);
+    vst1q_f32(out_copy + 8, out_vec_8_11);
+    vst1q_f32(out_copy + 12, out_vec_12_15);
+    in_copy += kVecSize;
+    out_copy += kVecSize;
+  }
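+  // Convert the vector count back to an element index so the scalar loop
+  // below continues where the vectorized loop stopped.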
+  i = i * kVecSize;
+#endif
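+  // Scalar tail (and full fallback on non-NEON targets).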
+  for (; i < numel; i++) {
+    out[i] = (in[i] - zero_point) * scale;
+  }
+}
+
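+// The optimized per-channel path requires a contiguous (in dim order) int8
+// input and a float output (out_dtype either unset or Float).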
+bool can_use_optimized_dequantize_per_channel(
+    const Tensor& in,
+    const ScalarType in_dtype,
+    exec_aten::optional<ScalarType>& out_dtype) {
+  bool is_contiguous = false;
+#ifdef USE_ATEN_LIB
+  is_contiguous = in.is_contiguous();
+#else
+  is_contiguous = executorch::runtime::is_contiguous_dim_order(
+      in.dim_order().data(), in.dim());
+#endif
+  if (!is_contiguous || (in_dtype != ScalarType::Char) ||
+      (out_dtype.has_value() && out_dtype.value() != ScalarType::Float)) {
+    return false;
+  }
+  return true;
+}
+
+void dequantize_per_channel_optimized(
+    const Tensor& in,
+    const Tensor& scales,
+    const optional<Tensor>& opt_zero_points,
+    Tensor& out,
+    int64_t axis,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType in_dtype,
+    exec_aten::optional<ScalarType>& out_dtype) {
+  check_dequantize_per_tensor_args(
+      in, quant_min, quant_max, in_dtype, out_dtype, out);
+  ET_CHECK_MSG(
+      in_dtype == ScalarType::Char,
+      "in.scalar_type() %" PRId8 " is not supported:",
+      static_cast<int8_t>(in.scalar_type()));
+  if (out_dtype.has_value()) {
+    ET_CHECK_MSG(
+        out_dtype.value() == ScalarType::Float,
+        "Only float output is supported");
+  }
+  const int8_t* in_data = in.const_data_ptr<int8_t>();
+  float* out_data = out.mutable_data_ptr<float>();
+  const int64_t* zero_points_data = nullptr;
+  if (opt_zero_points.has_value()) {
+    zero_points_data = opt_zero_points.value().const_data_ptr<int64_t>();
+  }
+  const double* scales_data = scales.const_data_ptr<double>();
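+  // `axis_stride` is the step between consecutive indices along `axis`;
+  // `outer_stride` is the step between consecutive outer slices, each of
+  // which spans the full axis.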
+  const StridesType axis_stride = in.strides()[axis];
+  const StridesType outer_stride = in.size(axis) * axis_stride;
+  apply_over_unpacked_dim(
+      [in_data,
+       out_data,
+       scales_data,
+       zero_points_data,
+       axis_stride,
+       outer_stride,
+       quant_min,
+       quant_max](
+          SizesType numel, SizesType outer_idx, SizesType unpacked_dim_idx) {
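+        // Each invocation covers `numel` contiguous elements of the channel
+        // at `unpacked_dim_idx`, which all share one scale/zero_point pair.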
+        const int8_t* in_data_local =
+            in_data + outer_idx * outer_stride + unpacked_dim_idx * axis_stride;
+        const double scale = scales_data[unpacked_dim_idx];
+        const int64_t zero_point = zero_points_data != nullptr
+            ? zero_points_data[unpacked_dim_idx]
+            : 0;
+        float* out_data_local = out_data + outer_idx * outer_stride +
+            unpacked_dim_idx * axis_stride;
+        dequantize_optimized(
+            in_data_local,
+            scale,
+            zero_point,
+            out_data_local,
+            quant_min,
+            quant_max,
+            numel);
+      },
+      in,
+      axis);
+}
+
} // namespace

/**
@@ -225,6 +395,20 @@ Tensor& dequantize_per_channel_out(
  check_dequantize_per_tensor_args(
      input, quant_min, quant_max, dtype, out_dtype, out);

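+  // Fast path: a contiguous int8 input with float output is handled by the
+  // optimized per-channel kernel above (NEON-accelerated on AArch64).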
+  if (can_use_optimized_dequantize_per_channel(input, dtype, out_dtype)) {
+    dequantize_per_channel_optimized(
+        input,
+        scale,
+        opt_zero_points,
+        out,
+        axis,
+        quant_min,
+        quant_max,
+        dtype,
+        out_dtype);
+    return out;
+  }
+
  // a list containing all dimensions except axis
  int64_t dims[kTensorDimensionLimit];
  for (int64_t i = 0; i < input.dim() - 1; i++) {