Skip to content

Commit 8edb47f

Browse files
AhmedHussein535, dzakhar
authored and committed
Increase LSTM accuracy
1 parent 5b65732 commit 8edb47f

File tree

6 files changed

+125
-157
lines changed

6 files changed

+125
-157
lines changed

lib/src/kernels/common/impl/mli_krn_lstm_cell_ref.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -84,8 +84,8 @@ MLI_FORCE_INLINE void lstm_cell_prepare_and_run(
8484
ir_asym_params.sa.scale_frac_bits.capacity = 0;
8585
ir_tensor.el_params = ir_asym_params;
8686
} else {
87-
// 1sign and 3 integer bits for TANH/SIGM input is enough
88-
ir_tensor.el_params.fx.frac_bits = (sizeof(io_T) * 8) - 1 - 3;
87+
// [-32, 32] is enough for TANH/SIGM input
88+
ir_tensor.el_params.fx.frac_bits = 10;
8989
ir_tensor.el_params.fx.frac_bits = MIN(ir_tensor.el_params.fx.frac_bits, in->el_params.fx.frac_bits + weights_in->el_params.fx.frac_bits);
9090
}
9191

lib/src/kernels/common/impl/mli_krn_lstm_cell_vdsp.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -86,8 +86,8 @@ MLI_FORCE_INLINE void lstm_cell_prepare_and_run(
8686
ir_asym_params.sa.scale_frac_bits.capacity = 0;
8787
ir_tensor.el_params = ir_asym_params;
8888
} else {
89-
// 1sign and 3 integer bits for TANH/SIGM input is enough
90-
ir_tensor.el_params.fx.frac_bits = (sizeof(io_T) * 8) - 1 - 3;
89+
// [-32, 32] is enough for TANH/SIGM input
90+
ir_tensor.el_params.fx.frac_bits = 10;
9191
ir_tensor.el_params.fx.frac_bits = MIN(ir_tensor.el_params.fx.frac_bits, in->el_params.fx.frac_bits + weights_in->el_params.fx.frac_bits);
9292
}
9393

lib/src/pal/dsp/mli_math.h

Lines changed: 12 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -186,7 +186,7 @@ template <> MLI_FORCE_INLINE v2q15_t mli_math_sub_fx(v2q15_t L, v2q15_t R) {
186186

187187
// Maximum of two fx operands
188188
//========================================================================
189-
template < typename io_T >
189+
template < typename io_T >
190190
MLI_FORCE_INLINE io_T mli_math_max_fx(io_T L, io_T R) {
191191
return MAX(L, R);
192192
}
@@ -196,19 +196,19 @@ MLI_FORCE_INLINE l_T mli_math_max_fx(l_T L, r_T R) {
196196
return MAX(L, R);
197197
}
198198

199-
template <>
199+
template <>
200200
MLI_FORCE_INLINE v2q15_t mli_math_max_fx(v2q15_t L, v2q15_t R) {
201201
return fx_max_v2q15(L, R);
202202
}
203203

204-
template <typename l_T, typename r_T>
204+
template <typename l_T, typename r_T>
205205
MLI_FORCE_INLINE v2q15_t mli_math_max_fx(v2q15_t L, r_T R) {
206206
return fx_max_v2q15(L, fx_replic_v2q15(R));
207207
}
208208

209209
// Minimum of two fx operands
210210
//========================================================================
211-
template < typename io_T >
211+
template < typename io_T >
212212
MLI_FORCE_INLINE io_T mli_math_min_fx(io_T L, io_T R) {
213213
return MIN(L, R);
214214
}
@@ -218,12 +218,12 @@ MLI_FORCE_INLINE l_T mli_math_min_fx(l_T L, r_T R) {
218218
return (L < R) ? L : R;
219219
}
220220

221-
template <>
221+
template <>
222222
MLI_FORCE_INLINE v2q15_t mli_math_min_fx(v2q15_t L, v2q15_t R) {
223223
return fx_min_v2q15(L, R);
224224
}
225225

226-
template <typename l_T, typename r_T>
226+
template <typename l_T, typename r_T>
227227
MLI_FORCE_INLINE v2q15_t mli_math_min_fx(v2q15_t L, r_T R) {
228228
return fx_min_v2q15(L, fx_replic_v2q15(R));
229229
}
@@ -327,6 +327,11 @@ template <> MLI_FORCE_INLINE mli_acc32_t mli_math_acc_ashift_fx(mli_acc32_t acc,
327327
}
328328

329329
template <> MLI_FORCE_INLINE mli_acc40_t mli_math_acc_ashift_fx(mli_acc40_t acc, int shift_right) {
330+
if (shift_right > 0) {
331+
mli_acc40_t rnd = {((1ll << shift_right) >> 1)};
332+
acc = fx_add_a40(acc, rnd);
333+
}
334+
330335
return fx_asr_a40(acc, shift_right);
331336
}
332337

@@ -386,7 +391,7 @@ template < typename in_T > MLI_FORCE_INLINE void *mli_math_cast_scalar_to_ptr_fx
386391

387392
// Comparators
388393
//========================================================================
389-
template < typename io_T >
394+
template < typename io_T >
390395
static MLI_FORCE_INLINE bool mli_prv_less_than_1(io_T value, uint8_t frac_bits) {
391396
if (frac_bits >= sizeof(io_T) * 8 - 1)
392397
return true;

user_tests/tests/mli_krn_gru_cell/tests_mli_krn_gru_cell.cc

Lines changed: 0 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -68,17 +68,6 @@ const crc32_calc test_1_chksum_fx16{ 0x93713917 }, test_1_chksum_fx16_fx8_fx8{ 0
6868
test_8_chksum_fx16{ 0xDBDD80AD }, test_8_chksum_fx16_fx8_fx8{ 0x5D935ADE }, test_8_chksum_sa8{ 0x71E73A61 };
6969

7070
#elif defined(CRC_RM_CONVERGENT)
71-
// TODO: remove after fixing mli_math_acc_ashift_fx() and supporting acc40 shift with round
72-
#if defined(__FXAPI__)
73-
const crc32_calc test_1_chksum_fx16{ 0x898EF9AC }, test_1_chksum_fx16_fx8_fx8{ 0xF3E45489 }, test_1_chksum_sa8{ 0x605D7927 },
74-
test_2_chksum_fx16{ 0x898EF9AC }, test_2_chksum_fx16_fx8_fx8{ 0xF3E45489 }, test_2_chksum_sa8{ 0x605D7927 },
75-
test_3_chksum_fx16{ 0xE14A4F30 }, test_3_chksum_fx16_fx8_fx8{ 0x0D9F97BB }, test_3_chksum_sa8{ 0x6A03698A },
76-
test_4_chksum_fx16{ 0xEBCB8726 }, test_4_chksum_fx16_fx8_fx8{ 0xBA61FDE2 }, test_4_chksum_sa8{ 0x3F8041AD },
77-
test_5_chksum_fx16{ 0x4E35CC3A }, test_5_chksum_fx16_fx8_fx8{ 0x63209ADF }, test_5_chksum_sa8{ 0xE36EB137 },
78-
test_6_chksum_fx16{ 0x44B4042C }, test_6_chksum_fx16_fx8_fx8{ 0xD4DEF086 }, test_6_chksum_sa8{ 0xB6ED9910 },
79-
test_7_chksum_fx16{ 0x697A5BA9 }, test_7_chksum_fx16_fx8_fx8{ 0x60AEE5D5 }, test_7_chksum_sa8{ 0xD54F47E2 },
80-
test_8_chksum_fx16{ 0xBDDA2972 }, test_8_chksum_fx16_fx8_fx8{ 0x1B477499 }, test_8_chksum_sa8{ 0x427B2A3F };
81-
#else
8271
const crc32_calc test_1_chksum_fx16{ 0x93713917 }, test_1_chksum_fx16_fx8_fx8{ 0xF3E45489 }, test_1_chksum_sa8{ 0x605D7927 },
8372
test_2_chksum_fx16{ 0x93713917 }, test_2_chksum_fx16_fx8_fx8{ 0xF3E45489 }, test_2_chksum_sa8{ 0x605D7927 },
8473
test_3_chksum_fx16{ 0xEA93E0FF }, test_3_chksum_fx16_fx8_fx8{ 0x0D9F97BB }, test_3_chksum_sa8{ 0x6A03698A },
@@ -87,7 +76,6 @@ const crc32_calc test_1_chksum_fx16{ 0x93713917 }, test_1_chksum_fx16_fx8_fx8{ 0
8776
test_6_chksum_fx16{ 0x5F51D618 }, test_6_chksum_fx16_fx8_fx8{ 0xD4DEF086 }, test_6_chksum_sa8{ 0xB6ED9910 },
8877
test_7_chksum_fx16{ 0xF35521BE }, test_7_chksum_fx16_fx8_fx8{ 0x60AEE5D5 }, test_7_chksum_sa8{ 0xD54F47E2 },
8978
test_8_chksum_fx16{ 0xDBDD80AD }, test_8_chksum_fx16_fx8_fx8{ 0x1B477499 }, test_8_chksum_sa8{ 0x427B2A3F };
90-
#endif
9179
#else // Not defined CRC_*
9280
const crc32_calc test_1_chksum_fx16, test_1_chksum_fx16_fx8_fx8, test_1_chksum_sa8,
9381
test_2_chksum_fx16, test_2_chksum_fx16_fx8_fx8, test_2_chksum_sa8,

0 commit comments

Comments
 (0)