@@ -43,11 +43,11 @@ static MLI_FORCE_INLINE vNx2short_t calc_leaky_relu(
     pvNx2 sel = init_predicate(input > 0);
     vNx2short_t neg;
     if (shift > mul_hi_shift) {
-        neg = mli_math_mul_fx_high(input, scale);
-        neg = mli_math_asr_rnd_fx(neg, shift - mul_hi_shift);
+        neg = mli_math_mul_fx_high(input, scale);
+        neg = mli_math_asr_rnd_fx(neg, shift - mul_hi_shift);
     } else {
-        vNx2accint_t acc = mli_math_mul_fx<vNx2short_t, vNx2accint_t>(input, scale);
-        neg = mli_math_acc_cast_fx<vNx2short_t, vNx2accint_t>(acc, shift);
+        vNx2accint_t acc = mli_math_mul_fx<vNx2short_t, vNx2accint_t>(input, scale);
+        neg = mli_math_acc_cast_fx<vNx2short_t, vNx2accint_t>(acc, shift);
     }

     return mli_math_select_fx(sel, input, neg);
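
Note on the fx16 branch above: when shift exceeds what the high-half multiply already discards (mul_hi_shift), only the remainder is shifted with rounding; otherwise the full product is kept in a wide accumulator and shifted once. Below is a minimal scalar sketch of that logic, assuming a 16-bit high half (kMulHiShift is a stand-in for the kernel's mul_hi_shift) and omitting the saturation that mli_math_acc_cast_fx performs.

#include <cstdint>

// Round-to-nearest arithmetic shift right. A 64-bit intermediate avoids
// overflow when the rounding bias is added; the real helper also
// saturates back to 16 bits, omitted here.
static int16_t asr_rnd(int64_t v, int shift) {
    if (shift > 0) v = (v + (int64_t{1} << (shift - 1))) >> shift;
    return static_cast<int16_t>(v);
}

constexpr int kMulHiShift = 16;  // stand-in for the kernel's mul_hi_shift

int16_t leaky_relu_neg_branch(int16_t input, int16_t scale, int shift) {
    int32_t prod = int32_t{input} * scale;
    if (shift > kMulHiShift) {
        // High-half multiply: the product's low 16 bits are discarded up
        // front, so only the remaining shift is applied, with rounding.
        int16_t hi = static_cast<int16_t>(prod >> kMulHiShift);
        return asr_rnd(hi, shift - kMulHiShift);
    } else {
        // Wide-accumulator path: full 32-bit product, one rounding shift.
        return asr_rnd(prod, shift);
    }
}

The scaled result then replaces the input only on lanes where the predicate input > 0 is false, which is what mli_math_select_fx does vector-wide.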
@@ -110,24 +110,21 @@ static MLI_FORCE_INLINE vNx4char_t calc_leaky_relu(
     /* Load Input */
     vNx4char_t input = mli_prv_load_1vec(vec_in);
     vNx4short_t input_cast = mli_math_cast_fx<vNx4char_t, vNx4short_t>(input);
-    grp_pvNx2_t select = init_predicate_grp(input_cast >= in_zp);
+    vNx4short_t cond;
+    cond.lo = input_cast.lo >= in_zp;
+    cond.hi = input_cast.hi >= in_zp;
+    grp_pvNx2_t select = init_predicate_grp(cond);

     int identity_shift = identity_params->shift;
-    int identity_shift_left = mli_math_max_fx(-identity_shift, 0);
-    int identity_shift_right = mli_math_max_fx(identity_shift, 0);
-    vNx4int_t input_identity_scale = mli_math_mul_fx<vNx4short_t, vNx4int_t>(identity_params->scale, input_cast);
-    input_identity_scale = mli_math_asl_fx(input_identity_scale, identity_shift_left);
-    input_identity_scale = mli_math_asr_rnd_fx(input_identity_scale, identity_shift_right);
+    vNx4int_t input_identity_scale = mli_math_mul_fx<vNx4short_t, vNx4int_t>(input_cast, identity_params->scale);
+    input_identity_scale = mli_math_asr_rnd_fx(input_identity_scale, identity_shift);

     vNx4short_t output_identity = mli_math_cast_fx<vNx4int_t, vNx4short_t>(input_identity_scale);
     output_identity = mli_math_add_fx(output_identity, (vNx4short_t)identity_params->offset);

     int alpha_shift = alpha_params->shift;
-    int alpha_shift_left = mli_math_max_fx(-alpha_shift, 0);
-    int alpha_shift_right = mli_math_max_fx(alpha_shift, 0);
     vNx4int_t input_alpha_scale = mli_math_mul_fx<vNx4short_t, vNx4int_t>(input_cast, alpha_params->scale);
-    input_alpha_scale = mli_math_asl_fx(input_alpha_scale, alpha_shift_left);
-    input_alpha_scale = mli_math_asr_rnd_fx(input_alpha_scale, alpha_shift_right);
+    input_alpha_scale = mli_math_asr_rnd_fx(input_alpha_scale, alpha_shift);

     vNx4short_t output_alpha = mli_math_cast_fx<vNx4int_t, vNx4short_t>(input_alpha_scale);
     output_alpha = mli_math_add_fx(output_alpha, (vNx4short_t)alpha_params->offset);
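
In this sa8 hunk the rewrite folds the old left/right shift split into a single rounding right shift, which assumes the quantization shifts are non-negative, and builds the predicate from explicit .lo/.hi half-vector comparisons. Below is a per-lane scalar sketch of the resulting computation, with hypothetical helper names (not the MLI API), shifts assumed in [0, 30], and the kernel's saturating casts omitted.

#include <cstdint>

// One sa8 lane, scalarized: rescale the input once through the identity
// parameters and once through the alpha parameters, then select.
static int16_t rescale(int16_t x, int16_t scale, int shift, int16_t offset) {
    int32_t v = int32_t{x} * scale;                        // mul_fx
    if (shift > 0) v = (v + (1 << (shift - 1))) >> shift;  // asr_rnd_fx
    return static_cast<int16_t>(v + offset);               // cast + offset
}

int8_t leaky_relu_sa8_lane(int8_t in, int16_t in_zp,
                           int16_t id_scale, int id_shift, int16_t id_off,
                           int16_t al_scale, int al_shift, int16_t al_off) {
    int16_t x = in;  // widen, as cast_fx<vNx4char_t, vNx4short_t> does
    int16_t identity = rescale(x, id_scale, id_shift, id_off);
    int16_t alpha    = rescale(x, al_scale, al_shift, al_off);
    int16_t out = (x >= in_zp) ? identity : alpha;  // select on predicate
    return static_cast<int8_t>(out);  // the real kernel saturates to sa8
}

The comparison against the input zero point in_zp picks the identity result for lanes at or above zero and the alpha-scaled result for negative lanes.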
@@ -178,6 +175,7 @@ static MLI_FORCE_INLINE void compute_leaky_relu_sa8_inner_loop(
         vec_out += remaining_part;
     }

+#pragma clang loop unroll_count(2)
     for (int pos3 = remaining_part; pos3 < count; pos3 += num_lanes) {
         compute_leaky_relu(vec_in, vec_out, in_zp, identity_params, alpha_params);
         vec_in += num_lanes;
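
The pragma added above is standard Clang loop metadata requesting a 2x unroll of the steady-state loop; the remainder iterations are peeled before it, so the unrolled loop only ever processes whole vectors. A minimal standalone illustration with placeholder names:

// Illustrative only: the same pragma applied to a trivial fixed-point loop.
void scale_q8(short* buf, int count, short scale) {
#pragma clang loop unroll_count(2)
    for (int i = 0; i < count; ++i) {
        buf[i] = static_cast<short>((buf[i] * scale) >> 8);  // Q8 multiply
    }
}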