 #include <algorithm>
 #include <cinttypes>
 #include <cmath>
+#if defined(__aarch64__) || defined(__ARM_NEON)
+#include <arm_neon.h>
+#endif

 /**
  * For an input tensor, use the scale and zero_point arguments to dequantize it.
@@ -22,6 +25,8 @@ namespace native {
 using Tensor = exec_aten::Tensor;
 using Scalar = exec_aten::Scalar;
 using ScalarType = exec_aten::ScalarType;
+using StridesType = exec_aten::StridesType;
+using SizesType = exec_aten::SizesType;

 namespace {

@@ -63,6 +68,183 @@ void check_dequantize_per_tensor_args(
       quant_max);
 }

+/**
+ * Apply the function `fn` to a tensor `in`, visiting every index along the
+ * unpacked dimension `dim`. `fn` should have the following signature:
+ * void fn(const size_t inner_size, const size_t outer_idx, const size_t dim_idx)
+ * where `inner_size` is the number of contiguous trailing elements covered by
+ * one call, `outer_idx` indexes the leading dimensions, and `dim_idx` is the
+ * index along `dim`. A usage sketch follows the function definition below.
+ */
+template <typename Fn>
+void apply_over_unpacked_dim(
+    const Fn& fn,
+    const exec_aten::Tensor& in,
+    const int64_t& dim) {
+  if (in.numel() == 0) {
+    return;
+  }
+
+  ET_CHECK_MSG(in.dim() > 0, "Input tensor must have at least one dimension");
+  ET_CHECK_VALID_DIM(dim, in.dim());
+
+  const size_t d = ET_NORMALIZE_IX(dim, in.dim());
+  const size_t dim_size = in.size(d);
+  const size_t outer_size = getLeadingDims(in, d);
+  const size_t inner_size = getTrailingDims(in, d);
+  // Loop through all outer dimensions
+  for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
+    // Loop through dim
+    for (size_t unpacked_dim_idx = 0; unpacked_dim_idx < dim_size;
+         ++unpacked_dim_idx) {
+      fn(inner_size, outer_idx, unpacked_dim_idx);
+    }
+  }
+}
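A minimal usage sketch (editorial note, not part of this diff): `t` is a hypothetical contiguous exec_aten::Tensor of shape [2, 3, 4]. With `dim == 1`, `outer_size` is 2, `dim_size` is 3, and `inner_size` is 4, so `fn` runs six times and each call covers one contiguous run of 4 elements.

// Sketch only: count how many contiguous slices fn visits for `t` ([2, 3, 4]).
size_t num_calls = 0;
apply_over_unpacked_dim(
    [&num_calls](size_t inner_size, size_t outer_idx, size_t dim_idx) {
      (void)inner_size; // 4 contiguous elements per call
      (void)outer_idx;  // 0 or 1
      (void)dim_idx;    // 0, 1, or 2
      ++num_calls;
    },
    t,
    /*dim=*/1);
// num_calls == 6 here (outer_size * dim_size).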
+
+void dequantize_optimized(
+    const int8_t* in,
+    const double scale,
+    const int64_t zero_point,
+    float* out,
+    int64_t quant_min,
+    int64_t quant_max,
+    size_t numel) {
+  ET_CHECK_MSG(
+      zero_point >= quant_min,
+      "zero_point %" PRId64 " must be >= quant_min %" PRId64,
+      zero_point,
+      quant_min);
+  ET_CHECK_MSG(
+      zero_point <= quant_max,
+      "zero_point %" PRId64 " must be <= quant_max %" PRId64,
+      zero_point,
+      quant_max);
+  size_t i = 0;
+#if defined(__aarch64__) || defined(__ARM_NEON)
+  int8x8_t zero_point_vec = vdup_n_s8(zero_point);
+  float32x4_t scales = vdupq_n_f32(static_cast<float>(scale));
+  constexpr int32_t kVecSize = 16;
+  const size_t num_vecs = numel / kVecSize;
+  const int8_t* in_copy = in;
+  float* out_copy = out;
+  // Process 16 int8 values per iteration: widen to int16 and then int32,
+  // subtract the zero point, convert to float, and multiply by the scale.
+  for (; i < num_vecs; i++) {
+    int8x16_t in_vec = vld1q_s8(in_copy);
+    int16x8_t sub_vec_0_7 = vsubl_s8(vget_low_s8(in_vec), zero_point_vec);
+    int32x4_t sub_vec_0_3 = vmovl_s16(vget_low_s16(sub_vec_0_7));
+    int32x4_t sub_vec_4_7 = vmovl_s16(vget_high_s16(sub_vec_0_7));
+    float32x4_t out_vec_0_3 = vmulq_f32(vcvtq_f32_s32(sub_vec_0_3), scales);
+    float32x4_t out_vec_4_7 = vmulq_f32(vcvtq_f32_s32(sub_vec_4_7), scales);
+
+    int16x8_t sub_vec_8_15 = vsubl_s8(vget_high_s8(in_vec), zero_point_vec);
+    int32x4_t sub_vec_8_11 = vmovl_s16(vget_low_s16(sub_vec_8_15));
+    int32x4_t sub_vec_12_15 = vmovl_s16(vget_high_s16(sub_vec_8_15));
+    float32x4_t out_vec_8_11 = vmulq_f32(vcvtq_f32_s32(sub_vec_8_11), scales);
+    float32x4_t out_vec_12_15 = vmulq_f32(vcvtq_f32_s32(sub_vec_12_15), scales);
+    vst1q_f32(out_copy + 0, out_vec_0_3);
+    vst1q_f32(out_copy + 4, out_vec_4_7);
+    vst1q_f32(out_copy + 8, out_vec_8_11);
+    vst1q_f32(out_copy + 12, out_vec_12_15);
+    in_copy += kVecSize;
+    out_copy += kVecSize;
+  }
+  i = i * kVecSize;
+#endif
+  // Scalar path, also used for the tail elements left over by the vector loop.
+  for (; i < numel; i++) {
+    out[i] = (in[i] - zero_point) * scale;
+  }
+}
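A worked example of the affine dequantization above (editorial sketch, not part of this diff); with numel below the 16-element vector width the scalar loop handles everything, so the result is identical with or without NEON.

// Sketch: dequantize four int8 values with scale = 0.5 and zero_point = 2.
const int8_t qvals[4] = {10, 2, -6, 127};
float dqvals[4] = {};
dequantize_optimized(
    qvals, /*scale=*/0.5, /*zero_point=*/2, dqvals,
    /*quant_min=*/-128, /*quant_max=*/127, /*numel=*/4);
// dqvals == {4.0f, 0.0f, -4.0f, 62.5f}, i.e. (q - 2) * 0.5 for each element.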
+
+float get_scale(const Tensor& scale, size_t channel_ix) {
+  ET_CHECK_MSG(
+      (scale.scalar_type() == ScalarType::Double) ||
+          (scale.scalar_type() == ScalarType::Float),
+      "scale.scalar_type() %" PRId8 " is not double or float type",
+      static_cast<int8_t>(scale.scalar_type()));
+  if (scale.scalar_type() == ScalarType::Double) {
+    return static_cast<float>(scale.const_data_ptr<double>()[channel_ix]);
+  } else {
+    return scale.const_data_ptr<float>()[channel_ix];
+  }
+}
+
+bool can_use_optimized_dequantize_per_channel(
+    const Tensor& in,
+    const ScalarType in_dtype,
+    exec_aten::optional<ScalarType>& out_dtype) {
+  bool is_contiguous = false;
+#ifdef USE_ATEN_LIB
+  is_contiguous = in.is_contiguous();
+#else
+  is_contiguous = executorch::runtime::is_contiguous_dim_order(
+      in.dim_order().data(), in.dim());
+#endif
+  if (!is_contiguous || (in_dtype != ScalarType::Char) ||
+      (out_dtype.has_value() && out_dtype.value() != ScalarType::Float)) {
+    return false;
+  }
+  return true;
+}
+
+void dequantize_per_channel_optimized(
+    const Tensor& in,
+    const Tensor& scales,
+    const optional<Tensor>& opt_zero_points,
+    Tensor& out,
+    int64_t axis,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType in_dtype,
+    exec_aten::optional<ScalarType>& out_dtype) {
+  check_dequantize_per_tensor_args(
+      in, quant_min, quant_max, in_dtype, out_dtype, out);
+  ET_CHECK_MSG(
+      in_dtype == ScalarType::Char,
+      "in.scalar_type() %" PRId8 " is not supported",
+      static_cast<int8_t>(in.scalar_type()));
+  if (out_dtype.has_value()) {
+    ET_CHECK_MSG(
+        out_dtype.value() == ScalarType::Float,
+        "Only float output is supported");
+  }
+  const int8_t* in_data = in.const_data_ptr<int8_t>();
+  float* out_data = out.mutable_data_ptr<float>();
+  const int64_t* zero_points_data = nullptr;
+  if (opt_zero_points.has_value()) {
+    zero_points_data = opt_zero_points.value().const_data_ptr<int64_t>();
+  }
+  const StridesType axis_stride = in.strides()[axis];
+  const StridesType outer_stride = in.size(axis) * axis_stride;
+  apply_over_unpacked_dim(
+      [in_data,
+       out_data,
+       &scales,
+       zero_points_data,
+       axis_stride,
+       outer_stride,
+       quant_min,
+       quant_max](
+          SizesType numel, SizesType outer_idx, SizesType unpacked_dim_idx) {
+        const int8_t* in_data_local =
+            in_data + outer_idx * outer_stride + unpacked_dim_idx * axis_stride;
+        const double scale = get_scale(scales, unpacked_dim_idx);
+        const int64_t zero_point = zero_points_data != nullptr
+            ? zero_points_data[unpacked_dim_idx]
+            : 0;
+        float* out_data_local = out_data + outer_idx * outer_stride +
+            unpacked_dim_idx * axis_stride;
+        dequantize_optimized(
+            in_data_local,
+            scale,
+            zero_point,
+            out_data_local,
+            quant_min,
+            quant_max,
+            numel);
+      },
+      in,
+      axis);
+}
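To make the stride arithmetic above concrete, a worked example with hypothetical sizes (editorial, not part of this diff): for a contiguous int8 tensor dequantized over axis == 1, each lambda invocation hands dequantize_optimized one whole channel plane together with that channel's scale and zero point.

// Hypothetical shape [2, 3, 4, 5], axis == 1 (channels):
//   axis_stride  = 4 * 5  = 20   elements between consecutive channels
//   outer_stride = 3 * 20 = 60   elements between consecutive batch entries
// The lambda runs 2 * 3 = 6 times; the call for outer_idx == 1 and
// unpacked_dim_idx == 2 starts at element 1 * 60 + 2 * 20 = 100 and
// dequantizes numel == 20 contiguous elements using get_scale(scales, 2)
// and zero_points_data[2].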
+
 } // namespace

 /**
@@ -172,19 +354,6 @@ Tensor& dequantize_per_tensor_tensor_args_out(
   return out;
 }

-float get_scale(const Tensor& scale, size_t channel_ix) {
-  ET_CHECK_MSG(
-      (scale.scalar_type() == ScalarType::Double) ||
-          (scale.scalar_type() == ScalarType::Float),
-      "scale.scalar_type() %" PRId8 " is not double or float type",
-      static_cast<int8_t>(scale.scalar_type()));
-  if (scale.scalar_type() == ScalarType::Double) {
-    return static_cast<float>(scale.const_data_ptr<double>()[channel_ix]);
-  } else {
-    return scale.const_data_ptr<float>()[channel_ix];
-  }
-}
-
 Tensor& dequantize_per_channel_out(
     const Tensor& input,
     const Tensor& scale,
@@ -229,6 +398,20 @@ Tensor& dequantize_per_channel_out(
   check_dequantize_per_tensor_args(
       input, quant_min, quant_max, dtype, out_dtype, out);

+  if (can_use_optimized_dequantize_per_channel(input, dtype, out_dtype)) {
+    dequantize_per_channel_optimized(
+        input,
+        scale,
+        opt_zero_points,
+        out,
+        axis,
+        quant_min,
+        quant_max,
+        dtype,
+        out_dtype);
+    return out;
+  }
+
   // a list contains all dimensions except axis
   int64_t dims[kTensorDimensionLimit];
   for (int64_t i = 0; i < input.dim() - 1; i++) {