Skip to content

Commit a353acf

Browse files
Hakim7267JaccovG
authored andcommitted
optimize max/min sa8
1 parent 6b8f858 commit a353acf

File tree

3 files changed

+28
-7
lines changed

3 files changed

+28
-7
lines changed

lib/src/kernels/eltwise/impl/mli_krn_eltwise_vdsp.h

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -421,13 +421,19 @@ MLI_FORCE_INLINE vNx4char_t eltwise_perform_operation<vNx4char_t, vNx4char_t, EL
421421
int shift = post_op_shift - mul_hi_shift;
422422
int shift_left = mli_math_max_fx(1 - shift, 0);
423423
int shift_right = mli_math_max_fx(shift, 1);
424+
// As shift is limited by 23 the shift_right is limited by 7 so we can pre_shift left the out_offset
425+
int16_t offset = out_offset << shift_right;
426+
#ifdef ROUND_UP
427+
offset += ((1 << shift_right) >> 1);
428+
#else
429+
#error Rounding mode not supported
430+
#endif
424431
vNx4short_t max = to_vNx4short_t(mli_math_max_fx(op1, op2));
425432
max = mli_math_sub_fx(max, (vNx4short_t)in_offset1);
426433
max = mli_math_asl_fx(max, shift_left);
427434
vNx4short_t max_scaled = mli_math_mul_fx_high(max, scale_factor1);
428-
max_scaled = mli_math_asr_rnd_fx(max_scaled, shift_right);
429-
max_scaled = mli_math_add_fx(max_scaled, (vNx4short_t) out_offset);
430-
res = mli_math_cast_fx<vNx4short_t, vNx4char_t>(max_scaled);
435+
max_scaled = mli_math_add_fx(max_scaled, (vNx4short_t) offset);
436+
res = mli_math_cast_fx<vNx4short_t, vNx4char_t, false>(max_scaled, shift_right);
431437
return res;
432438
}
433439

@@ -496,13 +502,19 @@ MLI_FORCE_INLINE vNx4char_t eltwise_perform_operation<vNx4char_t, vNx4char_t, EL
496502
int shift = post_op_shift - mul_hi_shift;
497503
int shift_left = mli_math_max_fx(1 - shift, 0);
498504
int shift_right = mli_math_max_fx(shift, 1);
505+
// As shift is limited by 23 the shift_right is limited by 7 so we can pre_shift left the out_offset
506+
int16_t offset = out_offset << shift_right;
507+
#ifdef ROUND_UP
508+
offset += ((1 << shift_right) >> 1);
509+
#else
510+
#error Rounding mode not supported
511+
#endif
499512
vNx4short_t max = to_vNx4short_t(mli_math_min_fx(op1, op2));
500513
max = mli_math_sub_fx(max, (vNx4short_t)in_offset1);
501514
max = mli_math_asl_fx(max, shift_left);
502515
vNx4short_t max_scaled = mli_math_mul_fx_high(max, scale_factor1);
503-
max_scaled = mli_math_asr_rnd_fx(max_scaled, shift_right);
504-
max_scaled = mli_math_add_fx(max_scaled, (vNx4short_t) out_offset);
505-
res = mli_math_cast_fx<vNx4short_t, vNx4char_t>(max_scaled);
516+
max_scaled = mli_math_add_fx(max_scaled, (vNx4short_t) offset);
517+
res = mli_math_cast_fx<vNx4short_t, vNx4char_t, false>(max_scaled, shift_right);
506518
return res;
507519

508520
}

lib/src/pal/mli_math.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ template <typename io_T> MLI_FORCE_INLINE io_T mli_math_ashift_right_fx(io_T in_
3333
template < typename out_T > MLI_FORCE_INLINE out_T mli_math_cast_ptr_to_scalar_fx(void *src);
3434
template < typename in_T > MLI_FORCE_INLINE void *mli_math_cast_scalar_to_ptr_fx(in_T src);
3535

36-
template <typename in_T, typename out_T> MLI_FORCE_INLINE out_T mli_math_cast_fx(in_T in_val, int shift_right);
36+
template <typename in_T, typename out_T, bool round = true > MLI_FORCE_INLINE out_T mli_math_cast_fx(in_T in_val, int shift_right);
3737
template <typename in_T, typename out_T> MLI_FORCE_INLINE out_T mli_math_cast_fx(in_T in_val);
3838

3939
#if defined(__Xvec_width) && !defined(MLI_BUILD_REFERENCE)

lib/src/pal/vdsp/mli_math.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,6 +1035,15 @@ MLI_FORCE_INLINE vNx4short_t mli_math_cast_fx(vNx4short_t in_val, int shift_righ
10351035
return acc;
10361036
}
10371037

1038+
template<>
1039+
MLI_FORCE_INLINE vNx4char_t mli_math_cast_fx<vNx4short_t, vNx4char_t, false >(vNx4short_t in_val, int shift_right) {
1040+
MLI_EXTRA_ASSERT(shift_right >= 0);
1041+
vNx4short_t acc = in_val;
1042+
acc = mli_math_asr_fx(acc, shift_right);
1043+
acc = mli_math_bound_range_fx(acc, INT8_MIN, INT8_MAX);
1044+
return to_vNx4char_t(acc);
1045+
}
1046+
10381047
template<>
10391048
MLI_FORCE_INLINE vNx4char_t mli_math_cast_fx(vNx4short_t in_val, int shift_right) {
10401049
MLI_EXTRA_ASSERT(shift_right >= 0);

0 commit comments

Comments
 (0)