@@ -94,6 +94,31 @@ static inline void rnn_dense_op_stacked(
9494 dense_out_ptr -= gates_num * out_elements;
9595}
9696
97+ MLI_FORCE_INLINE vNx4int_t mli_math_add_accus (vNx4int_t L, vNx4int_t R) {
98+ return mli_math_add_fx (L, R);
99+ }
100+
101+ MLI_FORCE_INLINE vNx2accint_t mli_math_add_accus (vNx2accint_t L, vNx2accint_t R) {
102+ return mli_math_add (L, R);
103+ }
104+
105+ MLI_FORCE_INLINE vNx4accint_t mli_math_add_accus (vNx4accint_t L, vNx4accint_t R) {
106+ return mli_math_add (L, R);
107+ }
108+
109+ MLI_FORCE_INLINE vNx4accshort_t mli_math_add_accus (vNx4accshort_t L, vNx4accshort_t R) {
110+ #if (__Xvec_guard_bit_option == 0)
111+ vNx4short_t L_short = mli_math_acc_cast_fx<vNx4short_t, vNx4accshort_t>(L);
112+ vNx4short_t R_short = mli_math_acc_cast_fx<vNx4short_t, vNx4accshort_t>(R);
113+
114+ vNx4short_t res = mli_math_add_fx<vNx4short_t>(L_short, R_short);
115+
116+ return mli_math_init_accu_add<vNx4short_t, vNx4accshort_t>(res, (vNx4short_t)0 );
117+ #else
118+ return mli_math_add (L, R);
119+ #endif
120+ }
121+
97122template <typename io_T, typename w_T, typename b_T, typename acc_T, typename quant_T>
98123static inline void rnn_dense_op (
99124 const MLI_PTR (io_T) __restrict * inputs,
@@ -107,14 +132,16 @@ static inline void rnn_dense_op(
107132 quant_T * in_to_out_quant_params,
108133 const io_T val_min_limit,
109134 const io_T val_max_limit) {
110-
135+ typedef typename std::conditional<std::is_same<acc_T, vNx4accshort_t>::value, vNx4int_t, acc_T>::type ir_T;
111136 int num_lanes = get_number_lanes<acc_T>();
137+
112138 for (int o_idx = 0 ; o_idx < out_elements; o_idx += num_lanes) {
113139 int remaining_ch = out_elements - o_idx;
114140 int current_chs = MIN (remaining_ch, num_lanes); // number of channels computed in this loop iteration
115141
116- acc_T accu = mli_math_mul_fx<io_T, acc_T>(0 , 0 );
117- acc_T prev_step = mli_math_mul_fx<io_T, acc_T>(0 , 0 );
142+ acc_T accu = mli_prv_init_accu<acc_T>();
143+ ir_T acc_ir = mli_prv_init_accu<ir_T>();
144+ ir_T acc_res_ir = mli_prv_init_accu<ir_T>();
118145
119146 auto output_params = adjust_quant_params_v (&in_to_out_quant_params[0 ], 0 );
120147 accu = mli::krn::bias_additive (&bias[o_idx], accu, &output_params, /* add_preshift_rnd */ false );
@@ -124,20 +151,18 @@ static inline void rnn_dense_op(
124151 output_params = adjust_quant_params_v (&in_to_out_quant_params[idx], 0 );
125152 accu = dotprod_inputzp_1D_v (inputs[idx], &weights[idx][o_idx], accu, in_elements[idx],
126153 1 , w_ch_out_mem_strides[idx], &in_to_out_quant_params[idx]);
127- accu = mli_math_add (accu, prev_step);
128-
129- if (inputs_num - idx != 1 ) {
130- mli::krn::ref::adjust_quant_params (&in_to_out_quant_params[idx], o_idx);
131- prev_step = mli::krn::ir_rnn_result_requantize (accu, &in_to_out_quant_params[idx],
132- &in_to_out_quant_params[idx + 1 ], /* krn_idx= */ 0 );
133- accu = mli_math_mul_fx<io_T, acc_T>(0 , 0 );
134- } else {
135- // Cast result to output type with scaling
136- mli::krn::result_cast_relu_store_v (&out[o_idx], accu, &output_params,
137- val_min_limit, val_max_limit, current_chs, /* add_preshift_rnd */ true );
138- }
154+
155+ /* TODO: can be optimized using adjust_quant_params_v, and also optimize ir_rnn_result_requantize function */
156+ mli::krn::ref::adjust_quant_params (&in_to_out_quant_params[idx], o_idx);
157+ acc_ir = mli::krn::ir_rnn_result_requantize<acc_T, ir_T>(accu, &in_to_out_quant_params[idx]);
158+
159+ acc_res_ir = mli_math_add_accus (acc_res_ir, acc_ir);
160+ accu = mli_prv_init_accu<acc_T>();
139161 }
140162
163+ // Cast result to output type with scaling
164+ mli::krn::ir_result_cast_relu_store_v (&out[o_idx], acc_res_ir, &output_params,
165+ val_min_limit, val_max_limit, current_chs);
141166 }
142167}
143168
0 commit comments