@@ -512,14 +512,42 @@ MLI_FORCE_INLINE vNx4accshort_t ir_rnn_result_requantize(
512512 vNx4int_t acc_norm = mli_math_norm_fx<vNx4int_t, vNx4int_t>(acc_int);
513513 acc_int = mli_math_asl_fx<vNx4int_t, vNx4int_t>(acc_int, acc_norm);
514514
515- vNx4int_t total_shift = mli_math_add_fx<vNx4int_t>(acc_norm, (mul_norm + shift));
516515 vNx4int_t acc_scaled = mli_math_mul_fx_high (acc_int, mul_shifted);
517- vNx4int_t acc_shifted = mli_math_asr_rnd_fx (acc_scaled, total_shift);
518516
517+ constexpr int mul_high_shift = 32 ;
518+ constexpr int max_int_shift = 30 ;
519+ vNx4int_t total_shift = mli_math_add_fx<vNx4int_t>(acc_norm, (mul_norm - mul_high_shift + shift));
520+ vNx4int_t shift_left = mli_math_max_fx (-total_shift, 0 );
521+ vNx4int_t shift_right = mli_math_max_fx (total_shift, 0 );
522+
523+ vNx4int_t preshift = mli_math_max_fx (shift_right - max_int_shift, 0 );
524+ shift_right = shift_right - preshift;
525+
526+ vNx4int_t acc_shifted = mli_math_asr_fx (acc_scaled, preshift);
527+ acc_shifted = mli_math_asr_rnd_fx (acc_shifted, shift_right);
528+ acc_shifted = mli_math_asl_fx (acc_shifted, shift_left);
529+
530+ #if (__Xvec_guard_bit_option == 0)
531+ vNx4short_t acc_short = mli_math_cast_fx<vNx4int_t, vNx4short_t>(acc_shifted);
532+ vNx4accshort_t res = mli_math_init_accu_add<vNx4short_t, vNx4accshort_t>(acc_short, (vNx4short_t)0 );
533+ #else
519534 vNx4int_t norm;
520535 vNx4short_t acc_short = mli_math_norm_cast_fx</* left_shift*/ false >(acc_shifted , &norm);
536+
537+ constexpr int guard_bits = 8 ;
538+ vNx4int_t mask = (1 << norm) - 1 ;
539+ vNx4int_t acc_shifted_low = acc_shifted & mask;
540+ // If norm exceeds the number of guard bits, the masked low bits (acc_shifted_low)
541+ // have to be shifted right by the difference, since the result is shifted left by at most guard_bits.
542+ vNx4int_t mask_shift = mli_math_max_fx (norm - guard_bits, 0 );
543+ acc_shifted_low = mli_math_asr_fx (acc_shifted_low, mask_shift);
544+
545+ norm = mli_math_min_fx (norm, guard_bits);
521546 vNx4accshort_t res = mli_math_init_accu_add<vNx4short_t, vNx4accshort_t>(acc_short, (vNx4short_t)0 );
522547 res = mli_math_asl_fx<vNx4accshort_t, vNx4short_t>(res, to_vNx4short_t (norm));
548+ res = mli_math_add (res, to_vNx4short_t (acc_shifted_low));
549+ #endif
550+
523551 return res;
524552}
525553
0 commit comments