@@ -117,7 +117,6 @@ static MLI_FORCE_INLINE vNx4short_t activation_lut_vec_elem_interpolate(
117117 int8_t in_frac_bits,
118118 const struct s8asym_quant_params *in_params) {
119119
120- MLI_ASSERT (in_frac_bits >= -1 ); // -1 may be required by softmax
121120 MLI_ASSERT (lut->in_frac_bits >= 0 );
122121 MLI_ASSERT (lut->length >= 0 );
123122
@@ -133,26 +132,36 @@ static MLI_FORCE_INLINE vNx4short_t activation_lut_vec_elem_interpolate(
133132 const MLI_PTR (short ) lut_data = (const MLI_PTR (short ))lut->data .mem .pi16 ;
134133 // if shift amount is too high, preshift argument itself and
135134 // limit shift amount to prevent overflows
135+ constexpr int max_shift = 15 ;
136136 int preshift_in = mli_math_max_fx (shift_in - (int )kMaxFracBitsFx16 , 0 );
137+ preshift_in = mli_math_min_fx (preshift_in, max_shift);
137138 shift_in = mli_math_min_fx (shift_in, (int )kMaxFracBitsFx16 );
138139
139140 // input data is more precise than LUT
140141 int16_t mask = (1 << shift_in) - 1 ;
141142 vNx4short_t x = in;
143+ vNx4int_t lut_idx_int;
144+ vNx4short_t frac;
142145 if (convert) {
143- int shift = ((int32_t ) in_params->shift - in_frac_bits) + preshift_in;
144- x = mli_prv_convert_sa8_fx16<vNx4short_t, vNx4short_t>(x, in_params->offset , in_params->scale , shift);
146+ int shift = (int32_t ) in_params->shift - in_frac_bits;
147+ vNx4int_t x_int = mli_prv_convert_sa8_fx16<vNx4short_t, vNx4int_t>(x, in_params->offset , in_params->scale , shift);
148+ x_int = mli_math_asr_fx (x_int, preshift_in);
149+ frac = mli_math_cast_fx<vNx4int_t, vNx4short_t>(x_int & mask);
150+
151+ /* Calculate lut_idx */
152+ vNx4int_t lut_idx = mli_math_add_fx<vNx4int_t>(mli_math_asr_fx (x_int, shift_in), lut->input_offset );
153+ lut_idx_int = mli_math_bound_range_fx (lut_idx , 0 , lut->length - 2 );
154+
145155 } else {
146- constexpr int max_shift = 15 ;
147- preshift_in = mli_math_min_fx (preshift_in, max_shift);
148156 x = mli_math_asr_fx (x, preshift_in);
157+ frac = x & mask;
158+
159+ /* Calculate lut_idx */
160+ vNx4short_t lut_idx = mli_math_add_fx<vNx4short_t>(mli_math_asr_fx (x, shift_in), lut->input_offset );
161+ lut_idx = mli_math_bound_range_fx (lut_idx , 0 , lut->length - 2 );
162+ lut_idx_int = mli_math_mul_fx<vNx4short_t, vNx4int_t>(lut_idx, 1 );
149163 }
150- vNx4short_t lut_idx = mli_math_add_fx<vNx4short_t>(mli_math_asr_fx (x, shift_in), lut->input_offset );
151- /* Calculate lut_idx */
152- lut_idx = mli_math_bound_range_fx (lut_idx , 0 , lut->length - 2 );
153- vNx4int_t lut_idx_int = mli_math_mul_fx<vNx4short_t, vNx4int_t>(lut_idx, 1 );
154164
155- vNx4short_t frac = x & mask;
156165 /* Load from LUT */
157166 vNx4short_t lut_values = mli_prv_gather_load_nx4_samples (lut_data, lut_idx_int);
158167 vNx4short_t lut_values_next = mli_prv_gather_load_nx4_samples (lut_data, lut_idx_int + 1 );
@@ -162,9 +171,7 @@ static MLI_FORCE_INLINE vNx4short_t activation_lut_vec_elem_interpolate(
162171 mli_math_mul_fx<vNx4short_t, vNx4accint_t>(diffs, frac), shift_in);
163172
164173 /* Calculate O/P */
165- vNx4short_t result = mli_math_sub_fx<vNx4short_t>(lut_values, diffs_mul_frac_cast);
166-
167- return result;
174+ return mli_math_sub_fx<vNx4short_t>(lut_values, diffs_mul_frac_cast);
168175}
169176
170177template <bool convert>
@@ -174,7 +181,6 @@ static MLI_FORCE_INLINE vNx4short_t activation_lut_vec_elem_no_interpolate(
174181 int8_t in_frac_bits,
175182 const struct s8asym_quant_params *in_params) {
176183
177- MLI_ASSERT (in_frac_bits >= -1 ); // -1 may be required by softmax
178184 MLI_ASSERT (lut->in_frac_bits >= 0 );
179185 MLI_ASSERT (lut->length >= 0 );
180186
@@ -213,30 +219,33 @@ static MLI_FORCE_INLINE vNx4short_t activation_lut_vec_elem_no_interpolate(
// load_input_and_get_lut_idx: loads one input vector from in_ptr and derives
// the per-lane LUT index used by the interpolating lookup.
//
// NOTE(review): this function is shown as a diff hunk. Lines marked `-` are the
// previous implementation, lines marked `+` are its replacement; the comments
// added here describe the `+` side.
//
// Outputs (by reference):
//   x           - the loaded vector; in the non-convert path it is additionally
//                 shifted right by preshift_in.
//   x_int       - written only when `convert` is true: the result of
//                 mli_prv_convert_sa8_fx16 (using in_params offset/scale and
//                 shift = in_params->shift - in_frac_bits), shifted right by
//                 preshift_in. Left untouched when `convert` is false.
//   lut_idx_int - (value >> shift_in) + lut->input_offset, bounded to
//                 [0, lut->length - 2].
//
// The upper bound of length - 2 matches the visible callers, which gather from
// both lut_idx_int and lut_idx_int + 1.
// Callers visible in this file take the fractional part from x_int when
// `convert` is true and from x otherwise.
213219template <typename io_T, bool convert>
214220static MLI_FORCE_INLINE void load_input_and_get_lut_idx (
215221 MLI_PTR (io_T) __restrict in_ptr,
216- vNx4short_t &vec ,
217- vNx4short_t &lut_idx ,
222+ vNx4short_t &x ,
223+ vNx4int_t &x_int ,
218224 vNx4int_t &lut_idx_int,
219225 int16_t in_frac_bits,
220226 int preshift_in,
221227 int shift_in,
222228 const mli_lut *lut,
223229 const struct s8asym_quant_params *in_params) {
224230
225- vec = activation_lut_load_input<io_T, vNx4short_t>(in_ptr);
226-
231+ x = activation_lut_load_input<io_T, vNx4short_t>(in_ptr);
// Convert path (`+` side): the conversion now produces a vNx4int_t and
// preshift_in is applied afterwards as a separate shift, instead of being
// folded into the conversion shift as on the `-` side.
227232 if (convert) {
228- int shift = ((int32_t ) in_params->shift - in_frac_bits) + preshift_in;
229- vec = mli_prv_convert_sa8_fx16<vNx4short_t, vNx4short_t>(vec, in_params->offset , in_params->scale , shift);
233+ int shift = (int32_t ) in_params->shift - in_frac_bits;
234+ x_int = mli_prv_convert_sa8_fx16<vNx4short_t, vNx4int_t>(x, in_params->offset , in_params->scale , shift);
235+ x_int = mli_math_asr_fx (x_int, preshift_in);
236+
237+ /* Calculate lut_idx */
238+ vNx4int_t lut_idx = mli_math_add_fx<vNx4int_t>(mli_math_asr_fx (x_int, shift_in), lut->input_offset );
239+ lut_idx_int = mli_math_bound_range_fx (lut_idx , 0 , lut->length - 2 );
240+
230241 } else {
231- constexpr int max_shift = 15 ;
232- preshift_in = mli_math_min_fx (preshift_in, max_shift);
233- vec = mli_math_asr_fx (vec, preshift_in);
234- }
// Non-convert path (`+` side): preshift_in is no longer clamped here; the
// visible caller clamps it to max_shift (15) before calling.
// The index is computed as vNx4short_t and then widened to vNx4int_t via a
// multiply by 1.
242+ x = mli_math_asr_fx (x, preshift_in);
235243
236- /* Calculate lut_idx */
237- lut_idx = mli_math_add_fx<vNx4short_t>(mli_math_asr_fx (vec, shift_in), lut->input_offset );
238- lut_idx = mli_math_bound_range_fx (lut_idx , 0 , lut->length - 2 );
239- lut_idx_int = mli_math_mul_fx<vNx4short_t, vNx4int_t>(lut_idx, 1 );
244+ /* Calculate lut_idx */
245+ vNx4short_t lut_idx = mli_math_add_fx<vNx4short_t>(mli_math_asr_fx (x, shift_in), lut->input_offset );
246+ lut_idx = mli_math_bound_range_fx (lut_idx , 0 , lut->length - 2 );
247+ lut_idx_int = mli_math_mul_fx<vNx4short_t, vNx4int_t>(lut_idx, 1 );
248+ }
249}
241250
242251template <typename io_T, bool convert>
@@ -250,7 +259,6 @@ static MLI_FORCE_INLINE void compute_activation_lut_func(
250259 const struct s8asym_quant_params *in_params,
251260 struct s8asym_quant_params *out_params) {
252261
253- MLI_ASSERT (in_frac_bits >= -1 ); // -1 may be required by softmax
254262 MLI_ASSERT (lut->in_frac_bits >= 0 );
255263 MLI_ASSERT (lut->length >= 0 );
256264 MLI_ASSERT (MLI_MAX_RANK == 4 );
@@ -271,7 +279,9 @@ static MLI_FORCE_INLINE void compute_activation_lut_func(
271279 const MLI_PTR (short ) lut_data = (const MLI_PTR (short ))lut->data .mem .pi16 ;
272280 // if shift amount is too high, preshift argument itself and
273281 // limit shift amount to prevent overflows
282+ constexpr int max_shift = 15 ;
274283 int preshift_in = mli_math_max_fx (shift_in - (int )kMaxFracBitsFx16 , 0 );
284+ preshift_in = mli_math_min_fx (preshift_in, max_shift);
275285 shift_in = mli_math_min_fx (shift_in, (int )kMaxFracBitsFx16 );
276286
277287 int remaining_part = in->shape [3 ] & (_VDSP_NUM_8BIT_LANES - 1 );
@@ -294,49 +304,61 @@ static MLI_FORCE_INLINE void compute_activation_lut_func(
294304 }
295305
296306 /* Manual software pipelining */
297- vNx4short_t x, lut_idx ;
298- vNx4int_t lut_idx_int;
307+ vNx4short_t x, frac ;
308+ vNx4int_t x_int, lut_idx_int;
299309 vNx4short_t _lut_values, _lut_values_next, _frac;
300310
301- load_input_and_get_lut_idx<io_T, convert>(input_ptr, x, lut_idx , lut_idx_int,
311+ load_input_and_get_lut_idx<io_T, convert>(input_ptr, x, x_int , lut_idx_int,
302312 in_frac_bits, preshift_in, shift_in, lut, in_params);
303- vNx4short_t frac = x & mask;
313+ if (convert) {
314+ frac = mli_math_cast_fx<vNx4int_t, vNx4short_t>(x_int & mask);
315+ } else {
316+ frac = x & mask;
317+ }
304318 input_ptr += _VDSP_NUM_8BIT_LANES;
305319
306320 if (in->shape [3 ] >= _VDSP_NUM_8BIT_LANES && !convert) {
307321 /* Load from LUT */
308322 _lut_values = mli_prv_gather_load_nx4_samples (lut_data, lut_idx_int);
309323 _lut_values_next = mli_prv_gather_load_nx4_samples (lut_data, lut_idx_int + 1 );
310324
311- load_input_and_get_lut_idx<io_T, convert>(input_ptr, x, lut_idx , lut_idx_int,
325+ load_input_and_get_lut_idx<io_T, convert>(input_ptr, x, x_int , lut_idx_int,
312326 in_frac_bits, preshift_in, shift_in, lut, in_params);
313327
314328 _frac = frac;
315- frac = x & mask;
329+ if (convert) {
330+ frac = mli_math_cast_fx<vNx4int_t, vNx4short_t>(x_int & mask);
331+ } else {
332+ frac = x & mask;
333+ }
316334 input_ptr += _VDSP_NUM_8BIT_LANES;
317335
318336 for (int pos3 = remaining_part; pos3 < in->shape [3 ] - _VDSP_NUM_8BIT_LANES; pos3 += _VDSP_NUM_8BIT_LANES) {
319337 /* Load from LUT */
320338 vNx4short_t lut_values = mli_prv_gather_load_nx4_samples (lut_data, lut_idx_int);
321339 vNx4short_t lut_values_next = mli_prv_gather_load_nx4_samples (lut_data, lut_idx_int + 1 );
322340
323- load_input_and_get_lut_idx<io_T, convert>(input_ptr, x, lut_idx , lut_idx_int,
341+ load_input_and_get_lut_idx<io_T, convert>(input_ptr, x, x_int , lut_idx_int,
324342 in_frac_bits, preshift_in, shift_in, lut, in_params);
325343
326- /* perform linear interpolation */
327- vNx4short_t diffs = mli_math_sub_fx<vNx4short_t>(_lut_values, _lut_values_next);
328- vNx4short_t diffs_mul_frac_cast = mli_math_acc_cast_fx<vNx4short_t, vNx4accint_t>(
329- mli_math_mul_fx<vNx4short_t, vNx4accint_t>(diffs, _frac), shift_in);
344+ /* perform linear interpolation */
345+ vNx4short_t diffs = mli_math_sub_fx<vNx4short_t>(_lut_values, _lut_values_next);
346+ vNx4short_t diffs_mul_frac_cast = mli_math_acc_cast_fx<vNx4short_t, vNx4accint_t>(
347+ mli_math_mul_fx<vNx4short_t, vNx4accint_t>(diffs, _frac), shift_in);
330348
331- /* Calculate O/P */
332- vNx4short_t res = mli_math_sub_fx<vNx4short_t>(_lut_values, diffs_mul_frac_cast);
349+ /* Calculate O/P */
350+ vNx4short_t res = mli_math_sub_fx<vNx4short_t>(_lut_values, diffs_mul_frac_cast);
333351
334- /* Store O/P */
335- activation_lut_store_output<io_T, convert>(output_ptr, res, lut, out_params);
336- output_ptr += _VDSP_NUM_8BIT_LANES;
352+ /* Store O/P */
353+ activation_lut_store_output<io_T, convert>(output_ptr, res, lut, out_params);
354+ output_ptr += _VDSP_NUM_8BIT_LANES;
337355
338356 _frac = frac;
339- frac = x & mask;
357+ if (convert) {
358+ frac = mli_math_cast_fx<vNx4int_t, vNx4short_t>(x_int & mask);
359+ } else {
360+ frac = x & mask;
361+ }
340362
341363 input_ptr += _VDSP_NUM_8BIT_LANES;
342364
@@ -360,15 +382,19 @@ static MLI_FORCE_INLINE void compute_activation_lut_func(
360382 vNx4short_t lut_values = mli_prv_gather_load_nx4_samples (lut_data, lut_idx_int);
361383 vNx4short_t lut_values_next = mli_prv_gather_load_nx4_samples (lut_data, lut_idx_int + 1 );
362384
363- load_input_and_get_lut_idx<io_T, convert>(input_ptr, x, lut_idx , lut_idx_int,
385+ load_input_and_get_lut_idx<io_T, convert>(input_ptr, x, x_int , lut_idx_int,
364386 in_frac_bits, preshift_in, shift_in, lut, in_params);
365387
366388 /* perform linear interpolation */
367389 vNx4short_t diffs = mli_math_sub_fx<vNx4short_t>(lut_values, lut_values_next);
368390 vNx4short_t diffs_mul_frac_cast = mli_math_acc_cast_fx<vNx4short_t, vNx4accint_t>(
369391 mli_math_mul_fx<vNx4short_t, vNx4accint_t>(diffs, frac), shift_in);
370392
371- frac = x & mask;
393+ if (convert) {
394+ frac = mli_math_cast_fx<vNx4int_t, vNx4short_t>(x_int & mask);
395+ } else {
396+ frac = x & mask;
397+ }
372398
373399 /* Calculate O/P */
374400 vNx4short_t res = mli_math_sub_fx<vNx4short_t>(lut_values, diffs_mul_frac_cast);
0 commit comments