@@ -43,11 +43,11 @@ static MLI_FORCE_INLINE vNx2short_t calc_leaky_relu(
4343 pvNx2 sel = init_predicate (input > 0 );
4444 vNx2short_t neg;
4545 if ( shift > mul_hi_shift) {
46- neg = mli_math_mul_fx_high (input, scale);
47- neg = mli_math_asr_rnd_fx (neg, shift - mul_hi_shift);
46+ neg = mli_math_mul_fx_high (input, scale);
47+ neg = mli_math_asr_rnd_fx (neg, shift - mul_hi_shift);
4848 } else {
49- vNx2accint_t acc = mli_math_mul_fx<vNx2short_t, vNx2accint_t>(input, scale);
50- neg = mli_math_acc_cast_fx<vNx2short_t, vNx2accint_t>(acc, shift);
49+ vNx2accint_t acc = mli_math_mul_fx<vNx2short_t, vNx2accint_t>(input, scale);
50+ neg = mli_math_acc_cast_fx<vNx2short_t, vNx2accint_t>(acc, shift);
5151 }
5252
5353 return mli_math_select_fx (sel, input, neg);
@@ -110,12 +110,15 @@ static MLI_FORCE_INLINE vNx4char_t calc_leaky_relu(
110110 /* Load Input */
111111 vNx4char_t input = mli_prv_load_1vec (vec_in);
112112 vNx4short_t input_cast = mli_math_cast_fx<vNx4char_t, vNx4short_t>(input);
113- grp_pvNx2_t select = init_predicate_grp (input_cast >= in_zp);
113+ vNx4short_t cond;
114+ cond.lo = input_cast.lo >= in_zp;
115+ cond.hi = input_cast.hi >= in_zp;
116+ grp_pvNx2_t select = init_predicate_grp (cond);
114117
115118 int identity_shift = identity_params->shift ;
116119 int identity_shift_left = mli_math_max_fx (-identity_shift, 0 );
117120 int identity_shift_right = mli_math_max_fx (identity_shift, 0 );
118- vNx4int_t input_identity_scale = mli_math_mul_fx<vNx4short_t, vNx4int_t>(identity_params->scale , input_cast );
121+ vNx4int_t input_identity_scale = mli_math_mul_fx<vNx4short_t, vNx4int_t>(input_cast, identity_params->scale );
119122 input_identity_scale = mli_math_asl_fx (input_identity_scale, identity_shift_left);
120123 input_identity_scale = mli_math_asr_rnd_fx (input_identity_scale, identity_shift_right);
121124
@@ -178,6 +181,7 @@ static MLI_FORCE_INLINE void compute_leaky_relu_sa8_inner_loop(
178181 vec_out += remaining_part;
179182 }
180183
184+ #pragma clang loop unroll_count(2)
181185 for (int pos3 = remaining_part; pos3 < count; pos3 += num_lanes) {
182186 compute_leaky_relu (vec_in, vec_out, in_zp, identity_params, alpha_params);
183187 vec_in += num_lanes;
0 commit comments