@@ -41,13 +41,13 @@ void check_quantize_args(
4141 " input.scalar_type() %" PRId8 " is not float type" ,
4242 static_cast <int8_t >(input.scalar_type ()));
4343
44- // Check output dtype is int8 (Char)
44+ // Check output dtype is int8
4545 ET_CHECK_MSG (
4646 out.scalar_type () == ScalarType::Char,
4747 " out.scalar_type() %" PRId8 " is not int8 (Char)" ,
4848 static_cast <int8_t >(out.scalar_type ()));
4949
50- // Check dtype is int8 (Char)
50+ // Check dtype is int8
5151 ET_CHECK_MSG (
5252 dtype == ScalarType::Char,
5353 " dtype %" PRId8 " is not int8 (Char)" ,
@@ -75,18 +75,18 @@ void check_quantize_args(
7575/**
7676 * Scalar implementation of quantization for a single value.
7777 */
78- template <typename T , typename K >
79- T quantize_val (
80- float inv_scale,
78+ template <typename Q , typename F >
79+ Q quantize_val (
80+ F inv_scale,
8181 int32_t zero_point,
82- K value,
82+ F value,
8383 int64_t quant_min,
8484 int64_t quant_max) {
8585 int32_t qvalue =
8686 zero_point + static_cast <int32_t >(std::nearbyint (inv_scale * value));
8787 qvalue = std::max<int32_t >(qvalue, static_cast <int32_t >(quant_min));
8888 qvalue = std::min<int32_t >(qvalue, static_cast <int32_t >(quant_max));
89- return static_cast <T >(qvalue);
89+ return static_cast <Q >(qvalue);
9090}
9191
9292} // namespace
@@ -123,16 +123,97 @@ Tensor& quantize_per_tensor_out(
123123 int8_t * out_data = out.mutable_data_ptr <int8_t >();
124124 const size_t numel = input.numel ();
125125
126+ size_t i = 0 ;
127+
126128#if defined(HAS_HELIUM_SIMD)
127- // Helium MVE implementation for float32 to int8 quantization
128- #Error " Implement MVE version!"
129- #else
130- // Scalar implementation for float32 to int8 quantization
131- for (size_t i = 0 ; i < numel; i++) {
132- out_data[i] =
133- quantize_val<int8_t , float >(inv_scale, zp, input_data[i], qmin, qmax);
129+ // Helium MVE implementation for float32 to int8 quantization
130+ static uint8x16_t voffset{
131+ 0x0 ,
132+ 0x8 ,
133+ 0x4 ,
134+ 0xC ,
135+ 0x1 ,
136+ 0x9 ,
137+ 0x5 ,
138+ 0xD ,
139+ 0x2 ,
140+ 0xA ,
141+ 0x6 ,
142+ 0xE ,
143+ 0x3 ,
144+ 0xB ,
145+ 0x7 ,
146+ 0xF };
147+
148+ float32x4_t inv_scale_vec = vdupq_n_f32 (inv_scale);
149+
150+ // Magic number for float to int conversion, round to nearest even integer
151+ // int magic_round(float f): interpret_as_int32(f + magic_float) - magic_int
152+ // where,
153+ // magic_float = 12582912.0f = (2 ** 23 + 2 ** 22) = (1.5 * 2 ** 23)
154+ // magic_int = 1262485504 = 0x4B400000 = bit_pattern_as_int32(magic_float)
155+
 156+ float magic_float = 12582912.0f ;
157+ int32_t magic_int = 1262485504 ;
158+
159+ float32x4_t vmagic_float = vdupq_n_f32 (magic_float);
160+ int32x4_t vmagic_int_less_zp =
161+ vdupq_n_s32 (magic_int - static_cast <int32_t >(zp));
162+
163+ int16x8_t vqmin = vdupq_n_s16 (qmin);
164+ int16x8_t vqmax = vdupq_n_s16 (qmax);
165+
 166+ // TODO: Measure performance, we are spilling
167+ for (; i + 15 < numel; i += 16 ) {
168+ float32x4_t in_0123 = vldrwq_f32 (input_data + 0 );
169+ float32x4_t in_4567 = vldrwq_f32 (input_data + 4 );
170+ float32x4_t in_89AB = vldrwq_f32 (input_data + 8 );
171+ float32x4_t in_CDEF = vldrwq_f32 (input_data + 12 );
172+
173+ float32x4_t outf_0123 = vfmaq_f32 (vmagic_float, in_0123, inv_scale_vec);
174+ float32x4_t outf_4567 = vfmaq_f32 (vmagic_float, in_4567, inv_scale_vec);
175+ float32x4_t outf_89AB = vfmaq_f32 (vmagic_float, in_89AB, inv_scale_vec);
176+ float32x4_t outf_CDEF = vfmaq_f32 (vmagic_float, in_CDEF, inv_scale_vec);
177+
178+ int32x4_t out_0123 =
179+ vsubq_s32 (vreinterpretq_s32_f32 (outf_0123), vmagic_int_less_zp);
180+ int32x4_t out_4567 =
181+ vsubq_s32 (vreinterpretq_s32_f32 (outf_4567), vmagic_int_less_zp);
182+ int32x4_t out_89AB =
183+ vsubq_s32 (vreinterpretq_s32_f32 (outf_89AB), vmagic_int_less_zp);
184+ int32x4_t out_CDEF =
185+ vsubq_s32 (vreinterpretq_s32_f32 (outf_CDEF), vmagic_int_less_zp);
186+
187+ int16x8_t out_04152637;
188+ int16x8_t out_8C9DAEBF;
189+ out_04152637 = vmovnbq_s32 (out_04152637, out_0123);
190+ out_04152637 = vmovntq_s32 (out_04152637, out_4567);
191+ out_8C9DAEBF = vmovnbq_s32 (out_8C9DAEBF, out_89AB);
192+ out_8C9DAEBF = vmovntq_s32 (out_8C9DAEBF, out_CDEF);
193+
194+ int16x8_t out_04152637_clamped =
195+ vminq_s16 (vmaxq_s16 (out_04152637, vqmin), vqmax);
196+ int16x8_t out_8C9DAEBF_clamped =
197+ vminq_s16 (vmaxq_s16 (out_8C9DAEBF, vqmin), vqmax);
198+
199+ int8x16_t out_084C195D2A6E3B7F;
200+ out_084C195D2A6E3B7F =
201+ vmovnbq_s16 (out_084C195D2A6E3B7F, out_04152637_clamped);
202+ out_084C195D2A6E3B7F =
203+ vmovntq_s16 (out_084C195D2A6E3B7F, out_8C9DAEBF_clamped);
204+
205+ vstrbq_scatter_offset_s8 (out_data, voffset, out_084C195D2A6E3B7F);
206+ input_data += 16 ;
207+ out_data += 16 ;
208+ }
209+ #endif // defined(HAS_HELIUM_SIMD)
210+
211+ for (; i < numel; i++) {
212+ *out_data =
213+ quantize_val<int8_t , float >(inv_scale, zp, *input_data, qmin, qmax);
214+ input_data++;
215+ out_data++;
134216 }
135- #endif
136217
137218 return out;
138219}
0 commit comments