@@ -41,13 +41,13 @@ void check_quantize_args(
41
41
" input.scalar_type() %" PRId8 " is not float type" ,
42
42
static_cast <int8_t >(input.scalar_type ()));
43
43
44
- // Check output dtype is int8 (Char)
44
+ // Check output dtype is int8
45
45
ET_CHECK_MSG (
46
46
out.scalar_type () == ScalarType::Char,
47
47
" out.scalar_type() %" PRId8 " is not int8 (Char)" ,
48
48
static_cast <int8_t >(out.scalar_type ()));
49
49
50
- // Check dtype is int8 (Char)
50
+ // Check dtype is int8
51
51
ET_CHECK_MSG (
52
52
dtype == ScalarType::Char,
53
53
" dtype %" PRId8 " is not int8 (Char)" ,
@@ -75,18 +75,18 @@ void check_quantize_args(
75
75
/**
 * Scalar implementation of quantization for a single value.
 *
 * Converts one floating-point `value` into the quantized integer domain:
 * multiply by `inv_scale`, round to nearest using the current FP rounding
 * mode (std::nearbyint; ties-to-even by default), add `zero_point`, then
 * clamp the result into [quant_min, quant_max] before narrowing to Q.
 */
template <typename Q, typename F>
Q quantize_val(
    F inv_scale,
    int32_t zero_point,
    F value,
    int64_t quant_min,
    int64_t quant_max) {
  // Round first, then apply the zero-point offset.
  const int32_t rounded =
      static_cast<int32_t>(std::nearbyint(inv_scale * value));
  int32_t qvalue = zero_point + rounded;

  // Saturate to the representable quantized range.
  const int32_t lo = static_cast<int32_t>(quant_min);
  const int32_t hi = static_cast<int32_t>(quant_max);
  if (qvalue < lo) {
    qvalue = lo;
  } else if (qvalue > hi) {
    qvalue = hi;
  }
  return static_cast<Q>(qvalue);
}
91
91
92
92
} // namespace
@@ -123,16 +123,97 @@ Tensor& quantize_per_tensor_out(
123
123
int8_t * out_data = out.mutable_data_ptr <int8_t >();
124
124
const size_t numel = input.numel ();
125
125
126
+ size_t i = 0 ;
127
+
126
128
#if defined(HAS_HELIUM_SIMD)
127
- // Helium MVE implementation for float32 to int8 quantization
128
- #Error " Implement MVE version!"
129
- #else
130
- // Scalar implementation for float32 to int8 quantization
131
- for (size_t i = 0 ; i < numel; i++) {
132
- out_data[i] =
133
- quantize_val<int8_t , float >(inv_scale, zp, input_data[i], qmin, qmax);
129
+ // Helium MVE implementation for float32 to int8 quantization
130
+ static uint8x16_t voffset{
131
+ 0x0 ,
132
+ 0x8 ,
133
+ 0x4 ,
134
+ 0xC ,
135
+ 0x1 ,
136
+ 0x9 ,
137
+ 0x5 ,
138
+ 0xD ,
139
+ 0x2 ,
140
+ 0xA ,
141
+ 0x6 ,
142
+ 0xE ,
143
+ 0x3 ,
144
+ 0xB ,
145
+ 0x7 ,
146
+ 0xF };
147
+
148
+ float32x4_t inv_scale_vec = vdupq_n_f32 (inv_scale);
149
+
150
+ // Magic number for float to int conversion, round to nearest even integer
151
+ // int magic_round(float f): interpret_as_int32(f + magic_float) - magic_int
152
+ // where,
153
+ // magic_float = 12582912.0f = (2 ** 23 + 2 ** 22) = (1.5 * 2 ** 23)
154
+ // magic_int = 1262485504 = 0x4B400000 = bit_pattern_as_int32(magic_float)
155
+
156
+ float magic_float = 12582912 .0f ;
157
+ int32_t magic_int = 1262485504 ;
158
+
159
+ float32x4_t vmagic_float = vdupq_n_f32 (magic_float);
160
+ int32x4_t vmagic_int_less_zp =
161
+ vdupq_n_s32 (magic_int - static_cast <int32_t >(zp));
162
+
163
+ int16x8_t vqmin = vdupq_n_s16 (qmin);
164
+ int16x8_t vqmax = vdupq_n_s16 (qmax);
165
+
166
+ // TODO: Measure performnce, we are spilling
167
+ for (; i + 15 < numel; i += 16 ) {
168
+ float32x4_t in_0123 = vldrwq_f32 (input_data + 0 );
169
+ float32x4_t in_4567 = vldrwq_f32 (input_data + 4 );
170
+ float32x4_t in_89AB = vldrwq_f32 (input_data + 8 );
171
+ float32x4_t in_CDEF = vldrwq_f32 (input_data + 12 );
172
+
173
+ float32x4_t outf_0123 = vfmaq_f32 (vmagic_float, in_0123, inv_scale_vec);
174
+ float32x4_t outf_4567 = vfmaq_f32 (vmagic_float, in_4567, inv_scale_vec);
175
+ float32x4_t outf_89AB = vfmaq_f32 (vmagic_float, in_89AB, inv_scale_vec);
176
+ float32x4_t outf_CDEF = vfmaq_f32 (vmagic_float, in_CDEF, inv_scale_vec);
177
+
178
+ int32x4_t out_0123 =
179
+ vsubq_s32 (vreinterpretq_s32_f32 (outf_0123), vmagic_int_less_zp);
180
+ int32x4_t out_4567 =
181
+ vsubq_s32 (vreinterpretq_s32_f32 (outf_4567), vmagic_int_less_zp);
182
+ int32x4_t out_89AB =
183
+ vsubq_s32 (vreinterpretq_s32_f32 (outf_89AB), vmagic_int_less_zp);
184
+ int32x4_t out_CDEF =
185
+ vsubq_s32 (vreinterpretq_s32_f32 (outf_CDEF), vmagic_int_less_zp);
186
+
187
+ int16x8_t out_04152637;
188
+ int16x8_t out_8C9DAEBF;
189
+ out_04152637 = vmovnbq_s32 (out_04152637, out_0123);
190
+ out_04152637 = vmovntq_s32 (out_04152637, out_4567);
191
+ out_8C9DAEBF = vmovnbq_s32 (out_8C9DAEBF, out_89AB);
192
+ out_8C9DAEBF = vmovntq_s32 (out_8C9DAEBF, out_CDEF);
193
+
194
+ int16x8_t out_04152637_clamped =
195
+ vminq_s16 (vmaxq_s16 (out_04152637, vqmin), vqmax);
196
+ int16x8_t out_8C9DAEBF_clamped =
197
+ vminq_s16 (vmaxq_s16 (out_8C9DAEBF, vqmin), vqmax);
198
+
199
+ int8x16_t out_084C195D2A6E3B7F;
200
+ out_084C195D2A6E3B7F =
201
+ vmovnbq_s16 (out_084C195D2A6E3B7F, out_04152637_clamped);
202
+ out_084C195D2A6E3B7F =
203
+ vmovntq_s16 (out_084C195D2A6E3B7F, out_8C9DAEBF_clamped);
204
+
205
+ vstrbq_scatter_offset_s8 (out_data, voffset, out_084C195D2A6E3B7F);
206
+ input_data += 16 ;
207
+ out_data += 16 ;
208
+ }
209
+ #endif // defined(HAS_HELIUM_SIMD)
210
+
211
+ for (; i < numel; i++) {
212
+ *out_data =
213
+ quantize_val<int8_t , float >(inv_scale, zp, *input_data, qmin, qmax);
214
+ input_data++;
215
+ out_data++;
134
216
}
135
- #endif
136
217
137
218
return out;
138
219
}
0 commit comments