Skip to content

Commit 46a8b1f

Browse files
AhmedHussein535 authored and JaccovG committed
fix lstm requant
1 parent 791c39c commit 46a8b1f

File tree

5 files changed

+45
-38
lines changed

5 files changed

+45
-38
lines changed

lib/src/bricks/impl/mli_krn_rnn_dense_op_ref.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ static inline void rnn_dense_op(
140140
/* row_step= */ 1, /* ch_step= */ 1);
141141
accu = mli_math_add_fx(accu, other_additives[idx]);
142142

143-
acc_ir = mli::krn::ir_rnn_result_requantize(accu, &in_to_out_quant_params[idx]);
143+
acc_ir = mli::krn::ir_rnn_result_requantize<acc_T>(accu, &in_to_out_quant_params[idx]);
144144
acc_res_ir = mli_math_add_fx(acc_res_ir, acc_ir);
145145
accu = mli_math_mul_fx<io_T, acc_T>(0, 0);
146146
}

lib/src/bricks/impl/mli_krn_rnn_dense_op_vdsp.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ static inline void rnn_dense_op_stacked(
9494
dense_out_ptr -= gates_num * out_elements;
9595
}
9696

97+
MLI_FORCE_INLINE vNx4int_t mli_math_add_accus(vNx4int_t L, vNx4int_t R) {
98+
return mli_math_add_fx(L, R);
99+
}
100+
97101
MLI_FORCE_INLINE vNx2accint_t mli_math_add_accus(vNx2accint_t L, vNx2accint_t R) {
98102
return mli_math_add(L, R);
99103
}
@@ -128,16 +132,16 @@ static inline void rnn_dense_op(
128132
quant_T * in_to_out_quant_params,
129133
const io_T val_min_limit,
130134
const io_T val_max_limit) {
131-
135+
typedef typename std::conditional<std::is_same<acc_T, vNx4accshort_t>::value, vNx4int_t, acc_T>::type ir_T;
132136
int num_lanes = get_number_lanes<acc_T>();
133137

134138
for (int o_idx = 0; o_idx < out_elements; o_idx += num_lanes) {
135139
int remaining_ch = out_elements - o_idx;
136140
int current_chs = MIN(remaining_ch, num_lanes); // number of channels computed in this loop iteration
137141

138142
acc_T accu = mli_prv_init_accu<acc_T>();
139-
acc_T acc_ir = mli_prv_init_accu<acc_T>();
140-
acc_T acc_res_ir = mli_prv_init_accu<acc_T>();
143+
ir_T acc_ir = mli_prv_init_accu<ir_T>();
144+
ir_T acc_res_ir = mli_prv_init_accu<ir_T>();
141145

142146
auto output_params = adjust_quant_params_v(&in_to_out_quant_params[0], 0);
143147
accu = mli::krn::bias_additive(&bias[o_idx], accu, &output_params, /* add_preshift_rnd */ false);
@@ -150,7 +154,7 @@ static inline void rnn_dense_op(
150154

151155
/* TODO: can be optimized using adjust_quant_params_v, and also optimize ir_rnn_result_requantize function */
152156
mli::krn::ref::adjust_quant_params(&in_to_out_quant_params[idx], o_idx);
153-
acc_ir = mli::krn::ir_rnn_result_requantize(accu, &in_to_out_quant_params[idx]);
157+
acc_ir = mli::krn::ir_rnn_result_requantize<acc_T, ir_T>(accu, &in_to_out_quant_params[idx]);
154158

155159
acc_res_ir = mli_math_add_accus(acc_res_ir, acc_ir);
156160
accu = mli_prv_init_accu<acc_T>();

lib/src/bricks/impl/mli_prv_quant_vdsp.h

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,25 @@ MLI_FORCE_INLINE void ir_result_cast_relu_store_v(
503503
mli_prv_store_n_samples(o_ptr, out, num);
504504
}
505505

506+
template <>
507+
MLI_FORCE_INLINE void ir_result_cast_relu_store_v(
508+
MLI_CONV_OUT_PTR(int8_t) __restrict o_ptr,
509+
vNx4int_t acc,
510+
const s8asym_quant_specific_out_params_v* quant_params,
511+
const int16_t val_min_limit,
512+
const int16_t val_max_limit,
513+
int num) {
514+
515+
vNx4short_t accu_scaled = mli_math_cast_fx<vNx4int_t, vNx4short_t>(acc);
516+
accu_scaled = mli_math_add_fx<vNx4short_t>(accu_scaled, quant_params->out_offset);
517+
518+
accu_scaled = mli_math_min_fx(accu_scaled, val_max_limit);
519+
accu_scaled = mli_math_max_fx(accu_scaled, val_min_limit);
520+
521+
vNx4char_t out = to_vNx4char_t(accu_scaled);
522+
mli_prv_store_n_samples(o_ptr, out, num);
523+
}
524+
506525
template <>
507526
MLI_FORCE_INLINE void ir_result_cast_relu_store_v(
508527
MLI_CONV_OUT_PTR(int16_t) __restrict o_ptr,
@@ -537,19 +556,19 @@ MLI_FORCE_INLINE void ir_result_cast_relu_store_v(
537556
mli_prv_store_n_samples(o_ptr, out, num);
538557
}
539558

540-
template <typename acc_T>
541-
MLI_FORCE_INLINE acc_T ir_rnn_result_requantize(
559+
template <typename acc_T, typename out_T=acc_T>
560+
MLI_FORCE_INLINE out_T ir_rnn_result_requantize(
542561
const acc_T acc,
543562
const fx_quant_specific_params* params) {
544563
const int in_to_ir_shift = params->out_shift;
545564
int shift_right = mli_math_max_fx(in_to_ir_shift, 0);
546565
int shift_left = mli_math_max_fx(-in_to_ir_shift, 0);
547-
acc_T acc_shifted = mli_math_asl_fx(acc, shift_left);
548-
return mli_math_asr_rnd_fx<acc_T, int>(acc_shifted, shift_right);
566+
out_T acc_shifted = mli_math_asl_fx(acc, shift_left);
567+
return mli_math_asr_rnd_fx<out_T, int>(acc_shifted, shift_right);
549568
}
550569

551570
template <>
552-
MLI_FORCE_INLINE vNx4accshort_t ir_rnn_result_requantize(
571+
MLI_FORCE_INLINE vNx4int_t ir_rnn_result_requantize(
553572
const vNx4accshort_t acc,
554573
const s8asym_quant_specific_params* params) {
555574

@@ -578,28 +597,7 @@ MLI_FORCE_INLINE vNx4accshort_t ir_rnn_result_requantize(
578597
acc_shifted = mli_math_asr_rnd_fx(acc_shifted, shift_right);
579598
acc_shifted = mli_math_asl_fx(acc_shifted, shift_left);
580599

581-
#if (__Xvec_guard_bit_option == 0)
582-
vNx4short_t acc_short = mli_math_cast_fx<vNx4int_t, vNx4short_t>(acc_shifted);
583-
vNx4accshort_t res = mli_math_init_accu_add<vNx4short_t, vNx4accshort_t>(acc_short, (vNx4short_t)0);
584-
#else
585-
vNx4int_t norm;
586-
vNx4short_t acc_short = mli_math_norm_cast_fx</*left_shift*/ false>(acc_shifted , &norm);
587-
588-
constexpr int guard_bits = 8;
589-
vNx4int_t mask = (1 << norm) - 1;
590-
vNx4int_t acc_shifted_low = acc_shifted & mask;
591-
// If the norm is more than the number of guardbits,
592-
// so the masked_acc has to be shifted, since the result is shifted with max shift equals to number of guardbits.
593-
vNx4int_t mask_shift = mli_math_max_fx(norm - guard_bits, 0);
594-
acc_shifted_low = mli_math_asr_fx(acc_shifted_low, mask_shift);
595-
596-
norm = mli_math_min_fx(norm, guard_bits);
597-
vNx4accshort_t res = mli_math_init_accu_add<vNx4short_t, vNx4accshort_t>(acc_short, (vNx4short_t)0);
598-
res = mli_math_asl_fx<vNx4accshort_t, vNx4short_t>(res, to_vNx4short_t(norm));
599-
res = mli_math_add(res, to_vNx4short_t(acc_shifted_low));
600-
#endif
601-
602-
return res;
600+
return acc_shifted;
603601
}
604602

605603
} // namespace vdsp

lib/src/bricks/mli_prv_quant_decl.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -420,16 +420,17 @@ MLI_FORCE_INLINE void ir_result_cast_relu_store_v(
420420
int num);
421421
#endif
422422

423-
template <typename acc_T, typename quant_T>
424-
MLI_FORCE_INLINE acc_T ir_rnn_result_requantize(
423+
template <typename acc_T, typename out_T, typename quant_T>
424+
MLI_FORCE_INLINE out_T ir_rnn_result_requantize(
425425
const acc_T acc, const quant_T* params);
426-
template <typename acc_T>
427-
MLI_FORCE_INLINE acc_T ir_rnn_result_requantize(
426+
427+
template <typename acc_T, typename out_T>
428+
MLI_FORCE_INLINE out_T ir_rnn_result_requantize(
428429
const acc_T acc, const fx_quant_specific_params* params);
429430

430431
#if defined(__Xvec_width)
431432
template <>
432-
MLI_FORCE_INLINE vNx4accshort_t ir_rnn_result_requantize(
433+
MLI_FORCE_INLINE vNx4int_t ir_rnn_result_requantize(
433434
const vNx4accshort_t acc,
434435
const s8asym_quant_specific_params* params);
435436
#endif

lib/src/pal/vdsp/mli_prv_dsp.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,11 @@ MLI_FORCE_INLINE vNx4accshort_t mli_prv_init_accu<vNx4accshort_t>() {
286286
return vvcmpy((vNx4char_t)0, (int8_t)0);
287287
}
288288

289+
template<>
290+
MLI_FORCE_INLINE vNx4int_t mli_prv_init_accu<vNx4int_t>() {
291+
return ((vNx4int_t) (0));
292+
}
293+
289294
MLI_FORCE_INLINE vNx4accshort_t mli_prv_init_accu(vNx4char_t l, int8_t r) {
290295
return vvcmpy(l, r);
291296
}
@@ -299,7 +304,6 @@ MLI_FORCE_INLINE vNx2accint_t mli_prv_init_accu(vNx2short_t l, int16_t r) {
299304
return vvcmpy(l, r);
300305
}
301306

302-
303307
template<>
304308
MLI_FORCE_INLINE vNx4accint_t mli_prv_init_accu<vNx4accint_t>() {
305309
vNx4accint_t r;

0 commit comments

Comments
 (0)