Skip to content

Commit 5e60c47

Browse files
committed
[rnn_dense] fix rnn_dense accuracy issues
1 parent 06ecf15 commit 5e60c47

File tree

2 files changed

+34
-5
lines changed

2 files changed

+34
-5
lines changed

lib/src/bricks/impl/mli_krn_rnn_dense_op_vdsp.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,12 +109,13 @@ static inline void rnn_dense_op(
109109
const io_T val_max_limit) {
110110

111111
int num_lanes = get_number_lanes<acc_T>();
112+
112113
for (int o_idx = 0; o_idx < out_elements; o_idx += num_lanes) {
113114
int remaining_ch = out_elements - o_idx;
114115
int current_chs = MIN(remaining_ch, num_lanes); // number of channels computed in this loop iteration
115116

116-
acc_T accu = mli_math_mul_fx<io_T, acc_T>(0, 0);
117-
acc_T prev_step = mli_math_mul_fx<io_T, acc_T>(0, 0);
117+
acc_T accu = mli_prv_init_accu<acc_T>();
118+
acc_T prev_step = mli_prv_init_accu<acc_T>();
118119

119120
auto output_params = adjust_quant_params_v(&in_to_out_quant_params[0], 0);
120121
accu = mli::krn::bias_additive(&bias[o_idx], accu, &output_params, /* add_preshift_rnd */ false);
@@ -130,7 +131,7 @@ static inline void rnn_dense_op(
130131
mli::krn::ref::adjust_quant_params(&in_to_out_quant_params[idx], o_idx);
131132
prev_step = mli::krn::ir_rnn_result_requantize(accu, &in_to_out_quant_params[idx],
132133
&in_to_out_quant_params[idx + 1], /* krn_idx= */ 0);
133-
accu = mli_math_mul_fx<io_T, acc_T>(0, 0);
134+
accu = mli_prv_init_accu<acc_T>();
134135
} else {
135136
// Cast result to output type with scaling
136137
mli::krn::result_cast_relu_store_v(&out[o_idx], accu, &output_params,

lib/src/bricks/impl/mli_prv_quant_vdsp.h

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -512,14 +512,42 @@ MLI_FORCE_INLINE vNx4accshort_t ir_rnn_result_requantize(
512512
vNx4int_t acc_norm = mli_math_norm_fx<vNx4int_t, vNx4int_t>(acc_int);
513513
acc_int = mli_math_asl_fx<vNx4int_t, vNx4int_t>(acc_int, acc_norm);
514514

515-
vNx4int_t total_shift = mli_math_add_fx<vNx4int_t>(acc_norm, (mul_norm + shift));
516515
vNx4int_t acc_scaled = mli_math_mul_fx_high(acc_int, mul_shifted);
517-
vNx4int_t acc_shifted = mli_math_asr_rnd_fx(acc_scaled, total_shift);
518516

517+
constexpr int mul_high_shift = 32;
518+
constexpr int max_int_shift = 30;
519+
vNx4int_t total_shift = mli_math_add_fx<vNx4int_t>(acc_norm, (mul_norm - mul_high_shift + shift));
520+
vNx4int_t shift_left = mli_math_max_fx(-total_shift, 0);
521+
vNx4int_t shift_right = mli_math_max_fx(total_shift, 0);
522+
523+
vNx4int_t preshift = mli_math_max_fx(shift_right - max_int_shift, 0);
524+
shift_right = shift_right - preshift;
525+
526+
vNx4int_t acc_shifted = mli_math_asr_fx(acc_scaled, preshift);
527+
acc_shifted = mli_math_asr_rnd_fx(acc_shifted, shift_right);
528+
acc_shifted = mli_math_asl_fx(acc_shifted, shift_left);
529+
530+
#if (__Xvec_guard_bit_option == 0)
531+
vNx4short_t acc_short = mli_math_cast_fx<vNx4int_t, vNx4short_t>(acc_shifted);
532+
vNx4accshort_t res = mli_math_init_accu_add<vNx4short_t, vNx4accshort_t>(acc_short, (vNx4short_t)0);
533+
#else
519534
vNx4int_t norm;
520535
vNx4short_t acc_short = mli_math_norm_cast_fx</*left_shift*/ false>(acc_shifted , &norm);
536+
537+
constexpr int guard_bits = 8;
538+
vNx4int_t mask = (1 << norm) - 1;
539+
vNx4int_t acc_shifted_low = acc_shifted & mask;
540+
// If the norm is more than the number of guardbits,
541+
// so the masked_acc has to be shifted, since the result is shifted with max shift equals to number of guardbits.
542+
vNx4int_t mask_shift = mli_math_max_fx(norm - guard_bits, 0);
543+
acc_shifted_low = mli_math_asr_fx(acc_shifted_low, mask_shift);
544+
545+
norm = mli_math_min_fx(norm, guard_bits);
521546
vNx4accshort_t res = mli_math_init_accu_add<vNx4short_t, vNx4accshort_t>(acc_short, (vNx4short_t)0);
522547
res = mli_math_asl_fx<vNx4accshort_t, vNx4short_t>(res, to_vNx4short_t(norm));
548+
res = mli_math_add(res, to_vNx4short_t(acc_shifted_low));
549+
#endif
550+
523551
return res;
524552
}
525553

0 commit comments

Comments
 (0)