Skip to content

Commit 791c39c

Browse files
AhmedHussein535JaccovG
authored andcommitted
optimize eltwise mul preshift
1 parent ad1df2e commit 791c39c

File tree

5 files changed

+20
-16
lines changed

5 files changed

+20
-16
lines changed

lib/src/kernels/eltwise/impl/mli_krn_eltwise_ref.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -277,11 +277,12 @@ void eltwise_prepare_and_run(
277277
scale16_2 = scale16_1;
278278
post_op_shift -= shift;
279279
} else if (func_type == ELTWISE_MUL) {
280-
int64_t scale_factor = mli_math_asl_fx<int64_t>(scale_1, IN_SCALE_SHIFT);
281-
scale_factor = ((scale_factor * scale_2) / scale_out);
282-
post_op_shift = IN_SCALE_SHIFT + shift1 + shift2 - shift_out;
283280
int shift;
284-
scale16_1 = mli_math_norm_cast_fx<int64_t, int16_t>(scale_factor, &shift);
281+
scale_factor1 = scale_1 * scale_2;
282+
scale_factor1 = mli_math_norm_cast_fx<int32_t, int32_t>(scale_factor1, &shift);
283+
scale_factor1 = (scale_factor1 / scale_out);
284+
post_op_shift = shift1 + shift2 - shift_out - shift;
285+
scale16_1 = mli_math_norm_cast_fx<int32_t, int16_t>(scale_factor1, &shift);
285286
post_op_shift -= shift;
286287
shift = MAX(post_op_shift - MUL_MAX_SHIFT, 0) + MIN(MUL_MAX_SHIFT + post_op_shift, 0);
287288
scale16_1 = mli_math_asr_rnd_fx<int16_t>(scale16_1, shift);

lib/src/kernels/eltwise/impl/mli_krn_eltwise_vdsp.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,11 @@ MLI_FORCE_INLINE vNx4char_t eltwise_perform_operation<vNx4char_t, vNx4char_t, EL
319319
*/
320320

321321
int16_t acc_init = in_offset1 * in_offset2;
322+
#ifdef ROUND_UP
323+
acc_init += ((1 << preshift) >> 1); /* rounding half up */
324+
#else
325+
#error Rounding mode not supported
326+
#endif
322327
vNx4accshort_t acc16 = mli_math_init_accu<int16_t, vNx4accshort_t>(acc_init);
323328
acc16 = mli_math_mac_fx(acc16, op1, op2);
324329
acc16 = mli_math_msub_fx(acc16, op2, (vNx4char_t)(int8_t)in_offset1);
@@ -329,9 +334,7 @@ MLI_FORCE_INLINE vNx4char_t eltwise_perform_operation<vNx4char_t, vNx4char_t, EL
329334
* mul_hi output. with headroom of 3 bits.
330335
*/
331336

332-
vNx4short_t vacc16 = mli_math_acc_cast_fx<vNx4short_t, vNx4accshort_t>(acc16, preshift);
333-
334-
337+
vNx4short_t vacc16 = mli_math_acc_cast_fx<vNx4short_t, vNx4accshort_t, false>(acc16, preshift);
335338
#else
336339

337340
vNx4short_t op1_offset = to_vNx4short_t(op1) - in_offset1;

lib/src/pal/dsp/mli_math.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,9 @@ MLI_FORCE_INLINE int mli_math_norm_fx(mli_acc40_t acc) {
126126
}
127127

128128
template<typename in_T, typename out_T>
129-
MLI_FORCE_INLINE out_T mli_math_norm_cast_fx(in_T val , int *norm_shift) {
130-
int cast_shift = (sizeof(in_T) - sizeof(out_T)) * 8;
131-
int norm = mli_math_norm_fx<in_T, in_T>(val);
129+
MLI_FORCE_INLINE out_T mli_math_norm_cast_fx(in_T val , int32_t *norm_shift) {
130+
int32_t cast_shift = (sizeof(in_T) - sizeof(out_T)) * 8;
131+
int32_t norm = mli_math_norm_fx<in_T, int32_t>(val);
132132
*norm_shift = cast_shift - norm;
133133
return mli_math_cast_fx<in_T, out_T>(val, *norm_shift);
134134
}

lib/src/pal/ref/mli_math.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,9 @@ MLI_FORCE_INLINE o_T mli_math_norm_fx(T x)
141141
}
142142

143143
template<typename in_T, typename out_T>
144-
MLI_FORCE_INLINE out_T mli_math_norm_cast_fx(in_T val , int *norm_shift) {
145-
int cast_shift = (sizeof(in_T) - sizeof(out_T)) * 8;
146-
int norm = mli_math_norm_fx<in_T, int>(val);
144+
MLI_FORCE_INLINE out_T mli_math_norm_cast_fx(in_T val , int32_t *norm_shift) {
145+
int32_t cast_shift = (sizeof(in_T) - sizeof(out_T)) * 8;
146+
int32_t norm = mli_math_norm_fx<in_T, int32_t>(val);
147147
*norm_shift = cast_shift - norm;
148148
return mli_math_cast_fx<in_T, out_T>(val, *norm_shift);
149149
}

lib/src/pal/vdsp/mli_math.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1508,9 +1508,9 @@ MLI_FORCE_INLINE vNx4int_t mli_math_norm_fx(vNx4accint_t x) {
15081508
}
15091509

15101510
template<typename in_T, typename out_T>
1511-
MLI_FORCE_INLINE out_T mli_math_norm_cast_fx(in_T val , int *norm_shift) {
1512-
int cast_shift = (sizeof(in_T) - sizeof(out_T)) * 8;
1513-
int norm = mli_math_norm_fx<in_T, in_T>(val);
1511+
MLI_FORCE_INLINE out_T mli_math_norm_cast_fx(in_T val , int32_t *norm_shift) {
1512+
int32_t cast_shift = (sizeof(in_T) - sizeof(out_T)) * 8;
1513+
int32_t norm = mli_math_norm_fx<in_T, int32_t>(val);
15141514
*norm_shift = cast_shift - norm;
15151515
return mli_math_cast_fx<in_T, out_T>(val, *norm_shift);
15161516
}

0 commit comments

Comments
 (0)