 #include <algorithm>
 #include <cinttypes>
 #include <cmath>
+#if defined(__aarch64__) || defined(__ARM_NEON)
+#include <arm_neon.h>
+#endif

 /**
  * For an input tensor, use the scale and zero_point arguments to quantize it.
@@ -22,6 +25,8 @@ namespace native {
 using Tensor = exec_aten::Tensor;
 using Scalar = exec_aten::Scalar;
 using ScalarType = exec_aten::ScalarType;
+using StridesType = exec_aten::StridesType;
+using SizesType = exec_aten::SizesType;

 namespace {

@@ -62,6 +67,183 @@ void check_dequantize_per_tensor_args(
       quant_max);
 }

+/**
+ * Applies `fn` over the tensor `in`, visiting every index of the given
+ * dimension `dim` for each combination of the leading dimensions. `fn` must
+ * be callable with the following signature:
+ *   void fn(size_t inner_size, size_t outer_idx, size_t dim_idx)
+ * where `inner_size` is the number of trailing elements per step along `dim`,
+ * `outer_idx` indexes the leading dimensions, and `dim_idx` is the current
+ * index along `dim`.
+ */
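+// For example, for an input of shape [2, 3, 4] and dim == 1, `fn` is invoked
+// six times with inner_size == 4, once per (outer_idx, dim_idx) pair.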
+template <typename Fn>
+void apply_over_unpacked_dim(
+    const Fn& fn,
+    const exec_aten::Tensor& in,
+    const int64_t& dim) {
+  if (in.numel() == 0) {
+    return;
+  }
+
+  ET_CHECK_MSG(in.dim() > 0, "Input tensor must have at least one dimension");
+  ET_CHECK_VALID_DIM(dim, in.dim());
+
+  const size_t d = ET_NORMALIZE_IX(dim, in.dim());
+  const size_t dim_size = in.size(d);
+  const size_t outer_size = getLeadingDims(in, d);
+  const size_t inner_size = getTrailingDims(in, d);
+  // Loop through all outer dimensions
+  for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
+    // Loop through dim
+    for (size_t unpacked_dim_idx = 0; unpacked_dim_idx < dim_size;
+         ++unpacked_dim_idx) {
+      fn(inner_size, outer_idx, unpacked_dim_idx);
+    }
+  }
+}
+
+void dequantize_optimized(
+    const int8_t* in,
+    const double scale,
+    const int64_t zero_point,
+    float* out,
+    int64_t quant_min,
+    int64_t quant_max,
+    size_t numel) {
+  ET_CHECK_MSG(
+      zero_point >= quant_min,
+      "zero_point %" PRId64 " must be >= quant_min %" PRId64,
+      zero_point,
+      quant_min);
+  ET_CHECK_MSG(
+      zero_point <= quant_max,
+      "zero_point %" PRId64 " must be <= quant_max %" PRId64,
+      zero_point,
+      quant_max);
+  size_t i = 0;
+#if defined(__aarch64__) || defined(__ARM_NEON)
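+  // NEON path: dequantize 16 int8 elements per iteration; any remainder is
+  // handled by the scalar loop below.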
+  int8x8_t zero_point_vec = vdup_n_s8(zero_point);
+  float32x4_t scales = vdupq_n_f32(static_cast<float>(scale));
+  constexpr int32_t kVecSize = 16;
+  const size_t num_vecs = numel / kVecSize;
+  const int8_t* in_copy = in;
+  float* out_copy = out;
+  for (; i < num_vecs; i++) {
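+    // vsubl_s8 subtracts the zero point while widening int8 -> int16; each
+    // half is then widened to int32, converted to float, and scaled.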
+    int8x16_t in_vec = vld1q_s8(in_copy);
+    int16x8_t sub_vec_0_7 = vsubl_s8(vget_low_s8(in_vec), zero_point_vec);
+    int32x4_t sub_vec_0_3 = vmovl_s16(vget_low_s16(sub_vec_0_7));
+    int32x4_t sub_vec_4_7 = vmovl_s16(vget_high_s16(sub_vec_0_7));
+    float32x4_t out_vec_0_3 = vmulq_f32(vcvtq_f32_s32(sub_vec_0_3), scales);
+    float32x4_t out_vec_4_7 = vmulq_f32(vcvtq_f32_s32(sub_vec_4_7), scales);
+
+    int16x8_t sub_vec_8_15 = vsubl_s8(vget_high_s8(in_vec), zero_point_vec);
+    int32x4_t sub_vec_8_11 = vmovl_s16(vget_low_s16(sub_vec_8_15));
+    int32x4_t sub_vec_12_15 = vmovl_s16(vget_high_s16(sub_vec_8_15));
+    float32x4_t out_vec_8_11 = vmulq_f32(vcvtq_f32_s32(sub_vec_8_11), scales);
+    float32x4_t out_vec_12_15 = vmulq_f32(vcvtq_f32_s32(sub_vec_12_15), scales);
+    vst1q_f32(out_copy + 0, out_vec_0_3);
+    vst1q_f32(out_copy + 4, out_vec_4_7);
+    vst1q_f32(out_copy + 8, out_vec_8_11);
+    vst1q_f32(out_copy + 12, out_vec_12_15);
+    in_copy += kVecSize;
+    out_copy += kVecSize;
+  }
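+  // Convert the vector count back into an element index so the scalar tail
+  // loop below resumes where the NEON loop stopped.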
+  i = i * kVecSize;
+#endif
+  for (; i < numel; i++) {
+    out[i] = (in[i] - zero_point) * scale;
+  }
+}
+
+float get_scale(const Tensor& scale, size_t channel_ix) {
+  ET_CHECK_MSG(
+      (scale.scalar_type() == ScalarType::Double) ||
+          (scale.scalar_type() == ScalarType::Float),
+      "scale.scalar_type() %" PRId8 " is not double or float type",
+      static_cast<int8_t>(scale.scalar_type()));
+  if (scale.scalar_type() == ScalarType::Double) {
+    return static_cast<float>(scale.const_data_ptr<double>()[channel_ix]);
+  } else {
+    return scale.const_data_ptr<float>()[channel_ix];
+  }
+}
+
+bool can_use_optimized_dequantize_per_channel(
+    const Tensor& in,
+    const ScalarType in_dtype,
+    exec_aten::optional<ScalarType>& out_dtype) {
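+  // The optimized kernel assumes a contiguous dim order, int8 (Char) input,
+  // and float output; anything else falls back to the generic path.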
+  bool is_contiguous = false;
+#ifdef USE_ATEN_LIB
+  is_contiguous = in.is_contiguous();
+#else
+  is_contiguous = executorch::runtime::is_contiguous_dim_order(
+      in.dim_order().data(), in.dim());
+#endif
+  if (!is_contiguous || (in_dtype != ScalarType::Char) ||
+      (out_dtype.has_value() && out_dtype.value() != ScalarType::Float)) {
+    return false;
+  }
+  return true;
+}
+
+void dequantize_per_channel_optimized(
+    const Tensor& in,
+    const Tensor& scales,
+    const optional<Tensor>& opt_zero_points,
+    Tensor& out,
+    int64_t axis,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType in_dtype,
+    exec_aten::optional<ScalarType>& out_dtype) {
+  check_dequantize_per_tensor_args(
+      in, quant_min, quant_max, in_dtype, out_dtype, out);
+  ET_CHECK_MSG(
+      in_dtype == ScalarType::Char,
+      "in.scalar_type() %" PRId8 " is not supported:",
+      static_cast<int8_t>(in.scalar_type()));
+  if (out_dtype.has_value()) {
+    ET_CHECK_MSG(
+        out_dtype.value() == ScalarType::Float,
+        "Only float output is supported");
+  }
+  const int8_t* in_data = in.const_data_ptr<int8_t>();
+  float* out_data = out.mutable_data_ptr<float>();
+  const int64_t* zero_points_data = nullptr;
+  if (opt_zero_points.has_value()) {
+    zero_points_data = opt_zero_points.value().const_data_ptr<int64_t>();
+  }
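+  // In the contiguous layout assumed above, consecutive elements along `axis`
+  // are `axis_stride` apart and each outer slice spans `outer_stride` elements.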
+  const StridesType axis_stride = in.strides()[axis];
+  const StridesType outer_stride = in.size(axis) * axis_stride;
+  apply_over_unpacked_dim(
+      [in_data,
+       out_data,
+       &scales,
+       zero_points_data,
+       axis_stride,
+       outer_stride,
+       quant_min,
+       quant_max](
+          SizesType numel, SizesType outer_idx, SizesType unpacked_dim_idx) {
+        const int8_t* in_data_local =
+            in_data + outer_idx * outer_stride + unpacked_dim_idx * axis_stride;
+        const double scale = get_scale(scales, unpacked_dim_idx);
+        const int64_t zero_point = zero_points_data != nullptr
+            ? zero_points_data[unpacked_dim_idx]
+            : 0;
+        float* out_data_local = out_data + outer_idx * outer_stride +
+            unpacked_dim_idx * axis_stride;
+        dequantize_optimized(
+            in_data_local,
+            scale,
+            zero_point,
+            out_data_local,
+            quant_min,
+            quant_max,
+            numel);
+      },
+      in,
+      axis);
+}
+
 } // namespace

 /**
@@ -170,19 +352,6 @@ Tensor& dequantize_per_tensor_tensor_args_out(
   return out;
 }

-float get_scale(const Tensor& scale, size_t channel_ix) {
-  ET_CHECK_MSG(
-      (scale.scalar_type() == ScalarType::Double) ||
-          (scale.scalar_type() == ScalarType::Float),
-      "scale.scalar_type() %" PRId8 " is not double or float type",
-      static_cast<int8_t>(scale.scalar_type()));
-  if (scale.scalar_type() == ScalarType::Double) {
-    return static_cast<float>(scale.const_data_ptr<double>()[channel_ix]);
-  } else {
-    return scale.const_data_ptr<float>()[channel_ix];
-  }
-}
-
 Tensor& dequantize_per_channel_out(
     const Tensor& input,
     const Tensor& scale,
@@ -227,6 +396,20 @@ Tensor& dequantize_per_channel_out(
   check_dequantize_per_tensor_args(
       input, quant_min, quant_max, dtype, out_dtype, out);

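+  // Fast path: contiguous int8 input dequantized to float is dispatched to
+  // the optimized (NEON-accelerated on ARM) per-channel kernel.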
+  if (can_use_optimized_dequantize_per_channel(input, dtype, out_dtype)) {
+    dequantize_per_channel_optimized(
+        input,
+        scale,
+        opt_zero_points,
+        out,
+        axis,
+        quant_min,
+        quant_max,
+        dtype,
+        out_dtype);
+    return out;
+  }
+
   // a list contains all dimensions except axis
   int64_t dims[kTensorDimensionLimit];
   for (int64_t i = 0; i < input.dim() - 1; i++) {