@@ -41,13 +41,13 @@ void check_quantize_args(
       "input.scalar_type() %" PRId8 " is not float type",
       static_cast<int8_t>(input.scalar_type()));
 
-  // Check output dtype is int8 (Char)
+  // Check output dtype is int8
   ET_CHECK_MSG(
       out.scalar_type() == ScalarType::Char,
       "out.scalar_type() %" PRId8 " is not int8 (Char)",
       static_cast<int8_t>(out.scalar_type()));
 
-  // Check dtype is int8 (Char)
+  // Check dtype is int8
   ET_CHECK_MSG(
       dtype == ScalarType::Char,
       "dtype %" PRId8 " is not int8 (Char)",
@@ -75,18 +75,18 @@ void check_quantize_args(
 /**
  * Scalar implementation of quantization for a single value.
  */
-template <typename T, typename K>
-T quantize_val(
-    float inv_scale,
+template <typename Q, typename F>
+Q quantize_val(
+    F inv_scale,
     int32_t zero_point,
-    K value,
+    F value,
     int64_t quant_min,
     int64_t quant_max) {
   int32_t qvalue =
       zero_point + static_cast<int32_t>(std::nearbyint(inv_scale * value));
   qvalue = std::max<int32_t>(qvalue, static_cast<int32_t>(quant_min));
   qvalue = std::min<int32_t>(qvalue, static_cast<int32_t>(quant_max));
-  return static_cast<T>(qvalue);
+  return static_cast<Q>(qvalue);
 }
 
 } // namespace
@@ -123,16 +123,97 @@ Tensor& quantize_per_tensor_out(
   int8_t* out_data = out.mutable_data_ptr<int8_t>();
   const size_t numel = input.numel();
 
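+  // Elements [0, i) are handled by the vectorized loop below (when Helium MVE
+  // is available); the scalar loop at the end quantizes the remaining tail.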
+  size_t i = 0;
+
 #if defined(HAS_HELIUM_SIMD)
-// Helium MVE implementation for float32 to int8 quantization
-#Error "Implement MVE version!"
-#else
-  // Scalar implementation for float32 to int8 quantization
-  for (size_t i = 0; i < numel; i++) {
-    out_data[i] =
-        quantize_val<int8_t, float>(inv_scale, zp, input_data[i], qmin, qmax);
+  // Helium MVE implementation for float32 to int8 quantization
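+  // The narrowing steps below interleave lanes, leaving the 16 results in the
+  // order 0,8,4,C,1,9,5,D,2,A,6,E,3,B,7,F. Storing through these byte offsets
+  // scatters each result back to its original position.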
+  static uint8x16_t voffset{
+      0x0,
+      0x8,
+      0x4,
+      0xC,
+      0x1,
+      0x9,
+      0x5,
+      0xD,
+      0x2,
+      0xA,
+      0x6,
+      0xE,
+      0x3,
+      0xB,
+      0x7,
+      0xF};
+
+  float32x4_t inv_scale_vec = vdupq_n_f32(inv_scale);
+
+  // Magic number for float to int conversion, round to nearest even integer
+  // int magic_round(float f): interpret_as_int32(f + magic_float) - magic_int
+  // where,
+  //    magic_float = 12582912.0f = (2 ** 23 + 2 ** 22) = (1.5 * 2 ** 23)
+  //    magic_int = 1262485504 = 0x4B400000 = bit_pattern_as_int32(magic_float)
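+  //
+  // e.g. magic_round(2.3f): 2.3f + 12582912.0f rounds to the representable
+  // float 12582914.0f, whose bit pattern is 0x4B400002; subtracting
+  // magic_int = 0x4B400000 yields 2.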
+
+  float magic_float = 12582912.0f;
+  int32_t magic_int = 1262485504;
+
+  float32x4_t vmagic_float = vdupq_n_f32(magic_float);
+  int32x4_t vmagic_int_less_zp =
+      vdupq_n_s32(magic_int - static_cast<int32_t>(zp));
+
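+  // For int8 output qmin/qmax always fit in int16, so clamping is done on the
+  // 16-bit intermediates after the first narrowing step.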
+  int16x8_t vqmin = vdupq_n_s16(qmin);
+  int16x8_t vqmax = vdupq_n_s16(qmax);
+
+  // TODO: Measure performance, we are spilling
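+  // Each iteration quantizes 16 elements: four float32x4_t loads, magic-number
+  // rounding, clamping, and narrowing down to a single int8x16_t store.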
+  for (; i + 15 < numel; i += 16) {
+    float32x4_t in_0123 = vldrwq_f32(input_data + 0);
+    float32x4_t in_4567 = vldrwq_f32(input_data + 4);
+    float32x4_t in_89AB = vldrwq_f32(input_data + 8);
+    float32x4_t in_CDEF = vldrwq_f32(input_data + 12);
+
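+    // Fused multiply-add computes in * inv_scale + magic_float, applying the
+    // quantization scale and starting the magic-number rounding in one step.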
+    float32x4_t outf_0123 = vfmaq_f32(vmagic_float, in_0123, inv_scale_vec);
+    float32x4_t outf_4567 = vfmaq_f32(vmagic_float, in_4567, inv_scale_vec);
+    float32x4_t outf_89AB = vfmaq_f32(vmagic_float, in_89AB, inv_scale_vec);
+    float32x4_t outf_CDEF = vfmaq_f32(vmagic_float, in_CDEF, inv_scale_vec);
+
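+    // Reinterpreting the float bits and subtracting (magic_int - zp) completes
+    // the magic-number rounding and adds the zero point in a single subtract.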
+    int32x4_t out_0123 =
+        vsubq_s32(vreinterpretq_s32_f32(outf_0123), vmagic_int_less_zp);
+    int32x4_t out_4567 =
+        vsubq_s32(vreinterpretq_s32_f32(outf_4567), vmagic_int_less_zp);
+    int32x4_t out_89AB =
+        vsubq_s32(vreinterpretq_s32_f32(outf_89AB), vmagic_int_less_zp);
+    int32x4_t out_CDEF =
+        vsubq_s32(vreinterpretq_s32_f32(outf_CDEF), vmagic_int_less_zp);
+
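+    // Narrow 32 -> 16 bits: vmovnbq fills the even 16-bit lanes and vmovntq
+    // the odd lanes, interleaving the two inputs as the variable names show.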
+    int16x8_t out_04152637;
+    int16x8_t out_8C9DAEBF;
+    out_04152637 = vmovnbq_s32(out_04152637, out_0123);
+    out_04152637 = vmovntq_s32(out_04152637, out_4567);
+    out_8C9DAEBF = vmovnbq_s32(out_8C9DAEBF, out_89AB);
+    out_8C9DAEBF = vmovntq_s32(out_8C9DAEBF, out_CDEF);
+
+    int16x8_t out_04152637_clamped =
+        vminq_s16(vmaxq_s16(out_04152637, vqmin), vqmax);
+    int16x8_t out_8C9DAEBF_clamped =
+        vminq_s16(vmaxq_s16(out_8C9DAEBF, vqmin), vqmax);
+
+    int8x16_t out_084C195D2A6E3B7F;
+    out_084C195D2A6E3B7F =
+        vmovnbq_s16(out_084C195D2A6E3B7F, out_04152637_clamped);
+    out_084C195D2A6E3B7F =
+        vmovntq_s16(out_084C195D2A6E3B7F, out_8C9DAEBF_clamped);
+
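+    // Scatter-store the 16 bytes back to their original element order.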
+    vstrbq_scatter_offset_s8(out_data, voffset, out_084C195D2A6E3B7F);
+    input_data += 16;
+    out_data += 16;
+  }
+#endif // defined(HAS_HELIUM_SIMD)
+
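+  // Scalar path: quantizes whatever the vector loop left over, or the whole
+  // tensor when Helium MVE is not available (i is still 0 in that case).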
+  for (; i < numel; i++) {
+    *out_data =
+        quantize_val<int8_t, float>(inv_scale, zp, *input_data, qmin, qmax);
+    input_data++;
+    out_data++;
   }
-#endif
 
   return out;
 }