Skip to content

Commit f26b02d

Browse files
AhmedHussein535JaccovG
authored andcommitted
fix negative_shift in eltwise_mul
1 parent cd6da4d commit f26b02d

File tree

8 files changed

+165
-89
lines changed

8 files changed

+165
-89
lines changed

lib/src/kernels/eltwise/impl/mli_krn_eltwise_ref.h

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@
1818
#include "mli_prv_dsp.h"
1919
#include "mli_math.h"
2020

21-
#define INT64_TO_INT16 48
2221
#define INT32_TO_INT16 16
22+
#define IN_SCALE_SHIFT 32
23+
#define MUL_MAX_SHIFT 31
2324

2425
namespace mli {
2526
namespace krn {
@@ -96,6 +97,42 @@ out_T eltwise_perform_operation(
9697
return res;
9798
}
9899

100+
101+
// Explicit specialization of eltwise_perform_operation for int8_t operands and
// result, ELTWISE_MUL, convert = true.
//
// Computes, in the arithmetic of the mli_math_* helpers:
//   res = (((op1 - in_offset1) * (op2 - in_offset2)) * scale_factor1
//            >> post_op_shift) + out_offset
// where the total shift post_op_shift is split into `preshift` (applied to the
// raw product) and `shift` (applied after scaling).
//
// scale_factor2, pre_op_shift1 and pre_op_shift2 are accepted for signature
// compatibility with the generic template but are not read in this body.
template <>
102+
MLI_FORCE_INLINE int8_t eltwise_perform_operation<int8_t, int8_t, ELTWISE_MUL, true>(
103+
const int8_t op1,
104+
const int8_t op2,
105+
const int16_t in_offset1,
106+
const int16_t in_offset2,
107+
const int16_t out_offset,
108+
const int16_t scale_factor1,
109+
const int16_t scale_factor2,
110+
const int pre_op_shift1,
111+
const int pre_op_shift2,
112+
const int post_op_shift) {
113+
int8_t res = 0;
114+
int32_t acc;
115+
int32_t input1, input2;
116+
117+
// Remove the input offsets from both operands.
input1 = mli_math_sub_fx<int16_t> (op1, in_offset1);
118+
input2 = mli_math_sub_fx<int16_t> (op2, in_offset2);
119+
120+
// Product of the two offset-corrected operands.
acc = mli_math_mul_fx<int32_t, int64_t> (input1, input2);
121+
const int headroom = 3;
122+
const int acc_len = 32;
123+
const int out_len = 8;
124+
// 32 - 8 - 3 = 21.
const int target_out_shift = acc_len - out_len - headroom;
125+
// preshift = clamp(post_op_shift - target_out_shift, 0, headroom): only the
// part of post_op_shift above 21 is applied early, and at most 3 bits of it.
const int preshift = mli_math_min_fx(mli_math_max_fx(post_op_shift - target_out_shift, 0), headroom);
126+
// Remainder of the total shift, applied after the scale multiplication.
const int shift = post_op_shift - preshift;
127+
// NOTE(review): mli_math_cast_fx(x, n) presumably shifts right by n and
// saturates to the target type - confirm against mli_math.h.
int16_t acc_result = mli_math_cast_fx<int32_t, int16_t>(acc, preshift);
128+
// Apply the combined scale factor.
int32_t acc_scaled = mli_math_mul_fx<int16_t, int32_t> (acc_result, scale_factor1);
129+
int16_t tmp16 = mli_math_cast_fx<int32_t, int16_t>(acc_scaled, shift);
130+
// Add the output offset, then narrow to int8_t with no further shift.
tmp16 = mli_math_add_fx<int16_t>(tmp16, out_offset);
131+
res = mli_math_cast_fx<int16_t, int8_t>(tmp16, 0);
132+
133+
return res;
134+
}
135+
99136
template <typename io_T, mli_eltwise_type func_type, bool convert>
100137
void eltwise_innerloop(
101138
const MLI_PTR(io_T) __restrict op1_ptr,
@@ -240,18 +277,14 @@ void eltwise_prepare_and_run(
240277
scale16_2 = scale16_1;
241278
post_op_shift -= shift;
242279
} else if (func_type == ELTWISE_MUL) {
243-
in_scale_fx1 = mli_math_asr_rnd_fx<int32_t>(scale_1,
244-
(int32_t) shift1 - frac_bits_fx16);
245-
in_scale_fx2 = mli_math_asr_rnd_fx<int32_t>(scale_2,
246-
(int32_t) shift2 - frac_bits_fx16);
247-
out_scale_fx = mli_math_asr_rnd_fx<int32_t>(scale_out,
248-
(int32_t) shift_out - frac_bits_fx16);
249-
int64_t scale_factor = mli_math_asr_rnd_fx<int64_t>(in_scale_fx1, -INT32_TO_INT16);
250-
scale_factor = (scale_factor / out_scale_fx) * in_scale_fx2;
251-
post_op_shift = INT32_TO_INT16 + frac_bits_fx16;
252-
int norm = (scale_factor != 0) ? mli_math_norm_fx<int64_t, int>(scale_factor) : 0;
253-
int shift = MAX((INT64_TO_INT16 - norm), 0);
254-
scale16_1 = mli_math_cast_fx<int64_t, int16_t>(scale_factor, shift);
280+
int64_t scale_factor = mli_math_asl_fx<int64_t>(scale_1, IN_SCALE_SHIFT);
281+
scale_factor = ((scale_factor * scale_2) / scale_out);
282+
post_op_shift = IN_SCALE_SHIFT + shift1 + shift2 - shift_out;
283+
int shift;
284+
scale16_1 = mli_math_norm_cast_fx<int64_t, int16_t>(scale_factor, &shift);
285+
post_op_shift -= shift;
286+
shift = MAX(post_op_shift - MUL_MAX_SHIFT, 0) + MIN(MUL_MAX_SHIFT + post_op_shift, 0);
287+
scale16_1 = mli_math_asr_rnd_fx<int16_t>(scale16_1, shift);
255288
post_op_shift -= shift;
256289
} else {
257290
in_scale_fx1 = mli_math_asr_rnd_fx<int32_t>(scale_1,

lib/src/kernels/eltwise/impl/mli_krn_eltwise_vdsp.h

Lines changed: 54 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,20 @@
1818
#include "arc_vector.h"
1919

2020
const int unroll_factor[2][5] = {
21-
{
22-
/* ELTWISE_ADD_NO_CONVERT = */ 1,
23-
/* ELTWISE_SUB_NO_CONVERT = */ 1,
24-
/* ELTWISE_MUL_NO_CONVERT = */ 4,
25-
/* ELTWISE_MAX_NO_CONVERT = */ 4,
26-
/* ELTWISE_MIN_NO_CONVERT = */ 4
27-
} ,
28-
{
29-
/* ELTWISE_ADD_CONVERT = */ 1,
30-
/* ELTWISE_SUB_CONVERT = */ 1,
31-
/* ELTWISE_MUL_CONVERT = */ 3,
32-
/* ELTWISE_MAX_CONVERT = */ 3,
33-
/* ELTWISE_MIN_CONVERT = */ 3
34-
}
21+
{
22+
/* ELTWISE_ADD_NO_CONVERT = */ 1,
23+
/* ELTWISE_SUB_NO_CONVERT = */ 1,
24+
/* ELTWISE_MUL_NO_CONVERT = */ 4,
25+
/* ELTWISE_MAX_NO_CONVERT = */ 4,
26+
/* ELTWISE_MIN_NO_CONVERT = */ 4
27+
} ,
28+
{
29+
/* ELTWISE_ADD_CONVERT = */ 1,
30+
/* ELTWISE_SUB_CONVERT = */ 1,
31+
/* ELTWISE_MUL_CONVERT = */ 4,
32+
/* ELTWISE_MAX_CONVERT = */ 3,
33+
/* ELTWISE_MIN_CONVERT = */ 3
34+
}
3535
};
3636

3737
namespace mli {
@@ -296,51 +296,61 @@ MLI_FORCE_INLINE vNx4char_t eltwise_perform_operation<vNx4char_t, vNx4char_t, EL
296296
const int pre_op_shift1,
297297
const int pre_op_shift2,
298298
const int post_op_shift) {
299-
MLI_ASSERT(post_op_shift > 3);
300299
vNx4char_t res;
300+
const int headroom = 3;
301+
const int hi_comp = 16;
302+
const int acc_len = 32;
303+
const int out_len = 8;
304+
const int target_out_shift = acc_len - out_len - headroom;
305+
const int preshift = mli_math_min_fx(mli_math_max_fx(post_op_shift - target_out_shift, 0), headroom);
306+
const int shift = post_op_shift - hi_comp - preshift;
307+
const int shift_left = mli_math_max_fx(1 - shift, 0);
308+
const int shift_right = mli_math_max_fx(shift, 1);
301309

302310
#if defined(__Xvec_guard_bit_option) && __Xvec_guard_bit_option != 0
303311
/*
304312
* res = ((op1 - in_offset1) * (op2 - in_offset2) * scale_factor1 >> post_op_shift) + out_offset
305-
* acc_init = in_offset1 * in_offset2 * scale_factor + out_offset << post_op_shift
306-
* term1 = op1 * op2 * scale_factor1 // 31 bit
307-
* term2 = - op2 * in_offset1 * scale_factor1 // 32 bit
308-
* term3 = - op1 * in_offset2 * scale_factor1 // 32 bit
313+
* acc_init = in_offset1 * in_offset2
314+
* term1 = op1 * op2
315+
* term2 = - op2 * in_offset1
316+
* term3 = - op1 * in_offset2
317+
* acc = ((acc_init + term1 + term2 + term3) * scale_factor1 >> post_op_shift) + out_offset
318+
*
309319
*/
320+
310321
int16_t acc_init = in_offset1 * in_offset2;
311322
vNx4accshort_t acc16 = mli_math_init_accu<int16_t, vNx4accshort_t>(acc_init);
312323
acc16 = mli_math_mac_fx(acc16, op1, op2);
313324
acc16 = mli_math_msub_fx(acc16, op2, (vNx4char_t)(int8_t)in_offset1);
314325
acc16 = mli_math_msub_fx(acc16, op1, (vNx4char_t)(int8_t)in_offset2);
315-
vNx4short_t vacc16 = mli_math_acc_cast_fx<vNx4short_t, vNx4accshort_t>(acc16);
316-
vNx4int_t acc = mli_math_mul_fx<vNx4short_t, vNx4int_t>(vacc16, scale_factor1);
317-
acc = mli_math_asr_rnd_fx(acc, post_op_shift);
318-
acc = mli_math_add_fx(acc, (vNx4int_t) out_offset);
319-
res = mli_math_cast_fx<vNx4int_t, vNx4char_t>(acc);
320-
#else
326+
321327
/*
322-
* Each operand is 9 bit. The first multiplier output is 18 bit. After scaling with positive 15 bit scale_factor,
323-
* The second multiplier output is 32 bits. A headroom of 3 is sufficient to add the offset, round and compensate.
324-
*
325-
* Note: Minimum shift value is 15
326-
*/
327-
328-
const int preshift_sf = 3;
329-
const int mask = (1 << preshift_sf) - 1;
328+
* If we preshift we can continue the operations in 16 bits. Only 8 bits are needed from the
329+
* mul_hi output, with a headroom of 3 bits.
330+
*/
331+
332+
vNx4short_t vacc16 = mli_math_acc_cast_fx<vNx4short_t, vNx4accshort_t>(acc16, preshift);
333+
334+
335+
#else
336+
330337
vNx4short_t op1_offset = to_vNx4short_t(op1) - in_offset1;
331338
vNx4short_t op2_offset = to_vNx4short_t(op2) - in_offset2;
332-
vNx4int_t temp1 = mli_math_mul_fx<vNx4short_t, vNx4int_t>(op1_offset, op2_offset);
333-
vNx4int_t temp2 = (scale_factor1 & mask);
334-
vNx4int_t offset = out_offset;
335-
vNx4accint_t acc = mli_math_mul_fx_low(temp1, temp2);
336-
acc = mli_math_asr_fx(acc, preshift_sf);
337-
temp2 = (scale_factor1 >> preshift_sf);
338-
acc = mli_math_mac_fx_low(acc, temp1, temp2);
339-
acc = mli_math_asr_rnd_fx(acc, post_op_shift - preshift_sf);
340-
acc = mli_math_add(acc, offset);
341-
res = mli_math_acc_cast_fx<vNx4char_t, vNx4accint_t>(acc);
339+
vNx4int_t acc32 = mli_math_mul_fx<vNx4short_t, vNx4int_t>(op1_offset, op2_offset);
340+
341+
/*
342+
* If we preshift we can continue the operations in 16 bits. Only 8 bits are needed from the
343+
* mul_hi output, with a headroom of 3 bits.
344+
*/
345+
346+
vNx4short_t vacc16 = mli_math_cast_fx<vNx4int_t, vNx4short_t>(acc32, preshift);
342347
#endif
343348

349+
vacc16 = mli_math_asl_fx(vacc16, shift_left);
350+
vNx4short_t accu_scaled = mli_math_mul_fx_high(vacc16, scale_factor1);
351+
accu_scaled = mli_math_asr_rnd_fx(accu_scaled, shift_right);
352+
accu_scaled = mli_math_add_fx(accu_scaled, (vNx4short_t) out_offset);
353+
res = mli_math_cast_fx<vNx4short_t, vNx4char_t>(accu_scaled);
344354

345355
return res;
346356
}
@@ -549,6 +559,7 @@ void eltwise_innerloop(
549559
idx_out += num_lanes;
550560
}
551561
}
562+
552563
template<>
553564
MLI_FORCE_INLINE void eltwise_innerloop<int16_t, ELTWISE_MAX, false>(
554565
const MLI_PTR(int16_t) __restrict op1_ptr,

lib/src/kernels/eltwise/mli_krn_eltwise_decl.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ void eltwise_op_basic(
8080
const int out_offset);
8181

8282
template <typename in_T, typename out_T, mli_eltwise_type func_type, bool convert>
83-
MLI_FORCE_INLINE out_T eltwise_perform_operation(
83+
out_T eltwise_perform_operation(
8484
const in_T op1,
8585
const in_T op2,
8686
const int16_t in_offset1,
@@ -92,6 +92,19 @@ MLI_FORCE_INLINE out_T eltwise_perform_operation(
9292
const int pre_op_shift2,
9393
const int post_op_shift);
9494

95+
// Forward declaration of the explicit specialization of
// eltwise_perform_operation for int8_t operands and result, ELTWISE_MUL,
// convert = true. The parameter list mirrors the generic template declared
// above.
template <>
96+
MLI_FORCE_INLINE int8_t eltwise_perform_operation <int8_t, int8_t, ELTWISE_MUL, true>(
97+
const int8_t op1,
98+
const int8_t op2,
99+
const int16_t in_offset1,
100+
const int16_t in_offset2,
101+
const int16_t out_offset,
102+
const int16_t scale_factor1,
103+
const int16_t scale_factor2,
104+
const int pre_op_shift1,
105+
const int pre_op_shift2,
106+
const int post_op_shift);
107+
95108
template <typename io_T, mli_eltwise_type func_type, bool convert>
96109
void eltwise_innerloop(
97110
const MLI_PTR(io_T) __restrict op1_ptr,

lib/src/pal/dsp/mli_math.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ MLI_FORCE_INLINE T mli_math_asl_fx(T x, int nbits);
4343
template <typename T>
4444
MLI_FORCE_INLINE T mli_math_asr_fx(T x, int nbits);
4545

46+
// int64_t specialization of mli_math_asl_fx: forwards x and nbits unchanged to
// fx_asl_q63 and returns its result.
// NOTE(review): fx_asl_q63 is presumably the 64-bit counterpart of the
// fx_asl_q31 intrinsic used by the int32_t specialization - confirm its
// saturation and negative-nbits behavior in the FXAPI reference.
template <>
47+
MLI_FORCE_INLINE int64_t mli_math_asl_fx(int64_t x, int nbits) {
48+
return fx_asl_q63(x, nbits);
49+
}
50+
4651
template <>
4752
MLI_FORCE_INLINE int32_t mli_math_asl_fx(int32_t x, int nbits) {
4853
return fx_asl_q31(x, nbits);

lib/src/pal/vdsp/mli_math.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1351,6 +1351,20 @@ MLI_FORCE_INLINE vNx4char_t mli_math_acc_cast_fx(vNx4accint_t acc, int shift_rig
13511351
return to_vNx4char_t(accu_result);
13521352
}
13531353

1354+
// Specialization of mli_math_acc_cast_fx (round = false) that converts a
// vNx4accshort_t accumulator into a vNx4short_t vector.
//
// shift_right must be non-negative (checked by MLI_EXTRA_ASSERT). The low and
// high halves of the accumulator are converted separately with vvconvert using
// the same control word.
template<>
1355+
MLI_FORCE_INLINE vNx4short_t mli_math_acc_cast_fx<vNx4short_t, vNx4accshort_t,/*round = */ false>(
1356+
vNx4accshort_t acc, int shift_right) {
1357+
MLI_EXTRA_ASSERT(shift_right >= 0);
1358+
1359+
// Control word combines SAT, SIGNED, TARGET_SZ_16 and SHIFT(shift_right).
// NOTE(review): presumably selects a saturating signed conversion to 16-bit
// lanes with a right shift and no rounding - confirm against the vvconvert
// documentation.
int ctrlword = SAT|SIGNED|TARGET_SZ_16|SHIFT(shift_right);
1360+
vNx4short_t accu_result;
1361+
accu_result.lo = to_vNx2short_t(vvconvert(__vacc_lo(acc), ctrlword));
1362+
accu_result.hi = to_vNx2short_t(vvconvert(__vacc_hi(acc), ctrlword));
1363+
1364+
return accu_result;
1365+
}
1366+
1367+
13541368
template<>
13551369
MLI_FORCE_INLINE vNx4char_t mli_math_acc_cast_fx<vNx4char_t, vNx4accint_t,/*round = */ false>(
13561370
vNx4accint_t acc, int shift_right) {

user_tests/tests/mli_krn_eltwise/tests_mli_krn_eltwise.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ const crc32_calc test_1_chksum_sa8{ 0xd48163e
5656
test_4_chksum_sa8{ 0xF22D7321 },
5757
test_5_chksum_sa8{ 0x9A14384C },
5858
test_6_chksum_fx16{ 0xfc026def }, test_6_chksum_sa8{ 0x3a54561 },
59-
test_7_chksum_fx16{ 0x488ed527 }, test_7_chksum_sa8{ 0xDA50B98A },
59+
test_7_chksum_fx16{ 0x488ed527 }, test_7_chksum_sa8{ 0xD4B7515B },
6060
test_8_chksum_fx16{ 0x68889D84 }, test_8_chksum_sa8{ 0x168B3B32 },
6161
test_9_chksum_fx16{ 0x9417F3D7 }, test_9_chksum_sa8{ 0xA83B910E },
6262
test_10_chksum_fx16{ 0xD728E430 }, test_10_chksum_sa8{ 0xE34DA6B0 },

user_tests/tests/mli_krn_gru_cell/tests_mli_krn_gru_cell.cc

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -60,29 +60,28 @@ struct gru_cell_test_operands {
6060
#if defined(CRC_RM_UP)
6161
const crc32_calc test_1_chksum_fx16{ 0xCA3B3621 }, test_1_chksum_fx16_fx8_fx8{ 0x7C81E8FA }, test_1_chksum_sa8{ 0xBA369AB3 },
6262
test_2_chksum_fx16{ 0xCA3B3621 }, test_2_chksum_fx16_fx8_fx8{ 0x7C81E8FA }, test_2_chksum_sa8{ 0xBA369AB3 },
63-
test_3_chksum_fx16{ 0x0575B7B5 }, test_3_chksum_fx16_fx8_fx8{ 0x3105731C }, test_3_chksum_sa8{ 0xBC580566 },
64-
test_4_chksum_fx16{ 0xA957E40B }, test_4_chksum_fx16_fx8_fx8{ 0x44D14AA8 }, test_4_chksum_sa8{ 0x2F5C16B5 },
65-
test_5_chksum_fx16{ 0xA9D9FC7B }, test_5_chksum_fx16_fx8_fx8{ 0xB08CE82C }, test_5_chksum_sa8{ 0x08D240F4 },
66-
test_6_chksum_fx16{ 0x82B87A3D }, test_6_chksum_fx16_fx8_fx8{ 0x1D12879D }, test_6_chksum_sa8{ 0x921BE561 };
63+
test_3_chksum_fx16{ 0x0575B7B5 }, test_3_chksum_fx16_fx8_fx8{ 0x3105731C }, test_3_chksum_sa8{ 0xD7D30910 },
64+
test_4_chksum_fx16{ 0xA957E40B }, test_4_chksum_fx16_fx8_fx8{ 0x44D14AA8 }, test_4_chksum_sa8{ 0x551350E9 },
65+
test_5_chksum_fx16{ 0xA9D9FC7B }, test_5_chksum_fx16_fx8_fx8{ 0xB08CE82C }, test_5_chksum_sa8{ 0x482C5F79 },
66+
test_6_chksum_fx16{ 0x82B87A3D }, test_6_chksum_fx16_fx8_fx8{ 0x1D12879D }, test_6_chksum_sa8{ 0xBC364FC3 };
6767

6868
#elif defined(CRC_RM_CONVERGENT)
6969
// TODO: remove after fixing mli_math_acc_ashift_fx() and supporting acc40 shift with round
7070
#if defined(__FXAPI__)
7171
const crc32_calc test_1_chksum_fx16{ 0xE5852A3E }, test_1_chksum_fx16_fx8_fx8{ 0xF979CA35 }, test_1_chksum_sa8{ 0xBA369AB3 },
7272
test_2_chksum_fx16{ 0xE5852A3E }, test_2_chksum_fx16_fx8_fx8{ 0xF979CA35 }, test_2_chksum_sa8{ 0xBA369AB3 },
73-
test_3_chksum_fx16{ 0x6F7E4D9B }, test_3_chksum_fx16_fx8_fx8{ 0xE47B56B4 }, test_3_chksum_sa8{ 0xBC580566 },
74-
test_4_chksum_fx16{ 0x3A84CF63 }, test_4_chksum_fx16_fx8_fx8{ 0x202E9565 }, test_4_chksum_sa8{ 0x2F5C16B5 },
75-
test_5_chksum_fx16{ 0xD81EFB70 }, test_5_chksum_fx16_fx8_fx8{ 0x7C0CE29B }, test_5_chksum_sa8{ 0x08D240F4 },
76-
test_6_chksum_fx16{ 0x31D77812 }, test_6_chksum_fx16_fx8_fx8{ 0x1D12879D }, test_6_chksum_sa8{ 0x921BE561 };
73+
test_3_chksum_fx16{ 0x6F7E4D9B }, test_3_chksum_fx16_fx8_fx8{ 0xE47B56B4 }, test_3_chksum_sa8{ 0xB0B3B302 },
74+
test_4_chksum_fx16{ 0x3A84CF63 }, test_4_chksum_fx16_fx8_fx8{ 0x202E9565 }, test_4_chksum_sa8{ 0xE0C80764 },
75+
test_5_chksum_fx16{ 0xD81EFB70 }, test_5_chksum_fx16_fx8_fx8{ 0x7C0CE29B }, test_5_chksum_sa8{ 0xB5805E4A },
76+
test_6_chksum_fx16{ 0x31D77812 }, test_6_chksum_fx16_fx8_fx8{ 0x1D12879D }, test_6_chksum_sa8{ 0xBC364FC3 };
7777
#else
7878
const crc32_calc test_1_chksum_fx16{ 0xCA3B3621 }, test_1_chksum_fx16_fx8_fx8{ 0xF979CA35 }, test_1_chksum_sa8{ 0xBA369AB3 },
7979
test_2_chksum_fx16{ 0xCA3B3621 }, test_2_chksum_fx16_fx8_fx8{ 0xF979CA35 }, test_2_chksum_sa8{ 0xBA369AB3 },
80-
test_3_chksum_fx16{ 0x0575B7B5 }, test_3_chksum_fx16_fx8_fx8{ 0xE47B56B4 }, test_3_chksum_sa8{ 0xBC580566 },
81-
test_4_chksum_fx16{ 0x4DEDC850 }, test_4_chksum_fx16_fx8_fx8{ 0x202E9565 }, test_4_chksum_sa8{ 0x2F5C16B5 },
82-
test_5_chksum_fx16{ 0xA9D9FC7B }, test_5_chksum_fx16_fx8_fx8{ 0x7C0CE29B }, test_5_chksum_sa8{ 0x08D240F4 },
83-
test_6_chksum_fx16{ 0x82B87A3D }, test_6_chksum_fx16_fx8_fx8{ 0x1D12879D }, test_6_chksum_sa8{ 0x921BE561 };
80+
test_3_chksum_fx16{ 0x0575B7B5 }, test_3_chksum_fx16_fx8_fx8{ 0xE47B56B4 }, test_3_chksum_sa8{ 0xB0B3B302 },
81+
test_4_chksum_fx16{ 0x4DEDC850 }, test_4_chksum_fx16_fx8_fx8{ 0x202E9565 }, test_4_chksum_sa8{ 0xE0C80764 },
82+
test_5_chksum_fx16{ 0xA9D9FC7B }, test_5_chksum_fx16_fx8_fx8{ 0x7C0CE29B }, test_5_chksum_sa8{ 0xB5805E4A },
83+
test_6_chksum_fx16{ 0x82B87A3D }, test_6_chksum_fx16_fx8_fx8{ 0x1D12879D }, test_6_chksum_sa8{ 0xBC364FC3 };
8484
#endif
85-
8685
#else // Not defined CRC_*
8786
const crc32_calc test_1_chksum_fx16, test_1_chksum_fx16_fx8_fx8, test_1_chksum_sa8,
8887
test_2_chksum_fx16, test_2_chksum_fx16_fx8_fx8, test_2_chksum_sa8,

0 commit comments

Comments
 (0)