Skip to content

Commit 899bcbf

Browse files
committed
[prelu_opt]: Optimizing leaky_relu SA8
1 parent becb08c commit 899bcbf

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

lib/src/kernels/transform/impl/mli_krn_leaky_relu_vdsp.h

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,11 @@ static MLI_FORCE_INLINE vNx2short_t calc_leaky_relu(
4343
pvNx2 sel = init_predicate(input > 0);
4444
vNx2short_t neg;
4545
if ( shift > mul_hi_shift) {
46-
neg = mli_math_mul_fx_high(input, scale);
47-
neg = mli_math_asr_rnd_fx(neg, shift - mul_hi_shift);
46+
neg = mli_math_mul_fx_high(input, scale);
47+
neg = mli_math_asr_rnd_fx(neg, shift - mul_hi_shift);
4848
} else {
49-
vNx2accint_t acc = mli_math_mul_fx<vNx2short_t, vNx2accint_t>(input, scale);
50-
neg = mli_math_acc_cast_fx<vNx2short_t, vNx2accint_t>(acc, shift);
49+
vNx2accint_t acc = mli_math_mul_fx<vNx2short_t, vNx2accint_t>(input, scale);
50+
neg = mli_math_acc_cast_fx<vNx2short_t, vNx2accint_t>(acc, shift);
5151
}
5252

5353
return mli_math_select_fx(sel, input, neg);
@@ -110,12 +110,15 @@ static MLI_FORCE_INLINE vNx4char_t calc_leaky_relu(
110110
/* Load Input */
111111
vNx4char_t input = mli_prv_load_1vec(vec_in);
112112
vNx4short_t input_cast = mli_math_cast_fx<vNx4char_t, vNx4short_t>(input);
113-
grp_pvNx2_t select = init_predicate_grp(input_cast >= in_zp);
113+
vNx4short_t cond;
114+
cond.lo = input_cast.lo >= in_zp;
115+
cond.hi = input_cast.hi >= in_zp;
116+
grp_pvNx2_t select = init_predicate_grp(cond);
114117

115118
int identity_shift = identity_params->shift;
116119
int identity_shift_left = mli_math_max_fx(-identity_shift, 0);
117120
int identity_shift_right = mli_math_max_fx(identity_shift, 0);
118-
vNx4int_t input_identity_scale = mli_math_mul_fx<vNx4short_t, vNx4int_t>(identity_params->scale, input_cast);
121+
vNx4int_t input_identity_scale = mli_math_mul_fx<vNx4short_t, vNx4int_t>(input_cast, identity_params->scale);
119122
input_identity_scale = mli_math_asl_fx(input_identity_scale, identity_shift_left);
120123
input_identity_scale = mli_math_asr_rnd_fx(input_identity_scale, identity_shift_right);
121124

@@ -178,6 +181,7 @@ static MLI_FORCE_INLINE void compute_leaky_relu_sa8_inner_loop(
178181
vec_out += remaining_part;
179182
}
180183

184+
#pragma clang loop unroll_count(2)
181185
for (int pos3 = remaining_part; pos3 < count; pos3 += num_lanes) {
182186
compute_leaky_relu(vec_in, vec_out, in_zp, identity_params, alpha_params);
183187
vec_in += num_lanes;

0 commit comments

Comments
 (0)