Skip to content

Commit f8b2778

Browse files
authored
Merge pull request #351 from foss-for-synopsys-dwc-arc-processors/prelu_opt
Prelu opt
2 parents becb08c + 6285742 commit f8b2778

File tree

3 files changed

+23
-16
lines changed

3 files changed

+23
-16
lines changed

lib/src/kernels/transform/impl/mli_krn_leaky_relu_ref.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,13 @@ static MLI_FORCE_INLINE mli_status leaky_relu_sa8_run(const mli_tensor *in,
312312
s8asym_quant_params identity_params;
313313
/* Define Requantization Params for In/Out scale ratio */
314314
define_requant_params(in, out, &identity_params);
315+
int shift_left = mli_math_max_fx(-identity_params.shift, 0);
316+
identity_params.scale = mli_math_asl_fx(identity_params.scale, shift_left);
317+
identity_params.shift = mli_math_max_fx(identity_params.shift, 0);
315318
s8asym_quant_params alpha_params = leaky_relu_define_requant_params(in, slope_coeff, out, scale, &identity_params);
319+
shift_left = mli_math_max_fx(-alpha_params.shift, 0);
320+
alpha_params.scale = mli_math_asl_fx(alpha_params.scale, shift_left);
321+
alpha_params.shift = mli_math_max_fx(alpha_params.shift, 0);
316322

317323
/* Dummy Load to get num_lanes */
318324
auto input = mli_prv_load_1vec(in_ptr);

lib/src/kernels/transform/impl/mli_krn_leaky_relu_vdsp.h

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,11 @@ static MLI_FORCE_INLINE vNx2short_t calc_leaky_relu(
4343
pvNx2 sel = init_predicate(input > 0);
4444
vNx2short_t neg;
4545
if ( shift > mul_hi_shift) {
46-
neg = mli_math_mul_fx_high(input, scale);
47-
neg = mli_math_asr_rnd_fx(neg, shift - mul_hi_shift);
46+
neg = mli_math_mul_fx_high(input, scale);
47+
neg = mli_math_asr_rnd_fx(neg, shift - mul_hi_shift);
4848
} else {
49-
vNx2accint_t acc = mli_math_mul_fx<vNx2short_t, vNx2accint_t>(input, scale);
50-
neg = mli_math_acc_cast_fx<vNx2short_t, vNx2accint_t>(acc, shift);
49+
vNx2accint_t acc = mli_math_mul_fx<vNx2short_t, vNx2accint_t>(input, scale);
50+
neg = mli_math_acc_cast_fx<vNx2short_t, vNx2accint_t>(acc, shift);
5151
}
5252

5353
return mli_math_select_fx(sel, input, neg);
@@ -110,24 +110,21 @@ static MLI_FORCE_INLINE vNx4char_t calc_leaky_relu(
110110
/* Load Input */
111111
vNx4char_t input = mli_prv_load_1vec(vec_in);
112112
vNx4short_t input_cast = mli_math_cast_fx<vNx4char_t, vNx4short_t>(input);
113-
grp_pvNx2_t select = init_predicate_grp(input_cast >= in_zp);
113+
vNx4short_t cond;
114+
cond.lo = input_cast.lo >= in_zp;
115+
cond.hi = input_cast.hi >= in_zp;
116+
grp_pvNx2_t select = init_predicate_grp(cond);
114117

115118
int identity_shift = identity_params->shift;
116-
int identity_shift_left = mli_math_max_fx(-identity_shift, 0);
117-
int identity_shift_right = mli_math_max_fx(identity_shift, 0);
118-
vNx4int_t input_identity_scale = mli_math_mul_fx<vNx4short_t, vNx4int_t>(identity_params->scale, input_cast);
119-
input_identity_scale = mli_math_asl_fx(input_identity_scale, identity_shift_left);
120-
input_identity_scale = mli_math_asr_rnd_fx(input_identity_scale, identity_shift_right);
119+
vNx4int_t input_identity_scale = mli_math_mul_fx<vNx4short_t, vNx4int_t>(input_cast, identity_params->scale);
120+
input_identity_scale = mli_math_asr_rnd_fx(input_identity_scale, identity_shift);
121121

122122
vNx4short_t output_identity = mli_math_cast_fx<vNx4int_t, vNx4short_t>(input_identity_scale);
123123
output_identity = mli_math_add_fx(output_identity, (vNx4short_t)identity_params->offset);
124124

125125
int alpha_shift = alpha_params->shift;
126-
int alpha_shift_left = mli_math_max_fx(-alpha_shift, 0);
127-
int alpha_shift_right = mli_math_max_fx(alpha_shift, 0);
128126
vNx4int_t input_alpha_scale = mli_math_mul_fx<vNx4short_t, vNx4int_t>(input_cast, alpha_params->scale);
129-
input_alpha_scale = mli_math_asl_fx(input_alpha_scale, alpha_shift_left);
130-
input_alpha_scale = mli_math_asr_rnd_fx(input_alpha_scale, alpha_shift_right);
127+
input_alpha_scale = mli_math_asr_rnd_fx(input_alpha_scale, alpha_shift);
131128

132129
vNx4short_t output_alpha = mli_math_cast_fx<vNx4int_t, vNx4short_t>(input_alpha_scale);
133130
output_alpha = mli_math_add_fx(output_alpha, (vNx4short_t)alpha_params->offset);
@@ -178,6 +175,7 @@ static MLI_FORCE_INLINE void compute_leaky_relu_sa8_inner_loop(
178175
vec_out += remaining_part;
179176
}
180177

178+
#pragma clang loop unroll_count(2)
181179
for (int pos3 = remaining_part; pos3 < count; pos3 += num_lanes) {
182180
compute_leaky_relu(vec_in, vec_out, in_zp, identity_params, alpha_params);
183181
vec_in += num_lanes;

lib/src/kernels/transform/impl/mli_krn_prelu_vdsp.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,10 @@ static MLI_FORCE_INLINE vNx4char_t calc_prelu(
143143
/* Load Input */
144144
vNx4char_t input = mli_prv_load_1vec(vec_in);
145145
vNx4short_t input_cast = mli_math_cast_fx<vNx4char_t, vNx4short_t>(input);
146-
grp_pvNx2_t select = init_predicate_grp(input_cast >= in_zp);
146+
vNx4short_t cond;
147+
cond.lo = input_cast.lo >= in_zp;
148+
cond.hi = input_cast.hi >= in_zp;
149+
grp_pvNx2_t select = init_predicate_grp(cond);
147150

148151
int scale_shift = identity_params->shift;
149152
int scale_shift_left = mli_math_max_fx(-scale_shift, 0);
@@ -196,4 +199,4 @@ static MLI_FORCE_INLINE void compute_prelu(
196199
} // namespace krn
197200
} // namespace mli
198201

199-
#endif // _MLI_KRN_PRELU_VDSP_H_
202+
#endif // _MLI_KRN_PRELU_VDSP_H_

0 commit comments

Comments
 (0)