Skip to content

Commit 6285742

Browse files
committed
[prelu_opt: Shifting Scale left for leaky SA8.
1 parent 6259e66 commit 6285742

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

lib/src/kernels/transform/impl/mli_krn_leaky_relu_ref.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,13 @@ static MLI_FORCE_INLINE mli_status leaky_relu_sa8_run(const mli_tensor *in,
312312
s8asym_quant_params identity_params;
313313
/* Define Requantization Params for In/Out scale ratio */
314314
define_requant_params(in, out, &identity_params);
315+
int shift_left = mli_math_max_fx(-identity_params.shift, 0);
316+
identity_params.scale = mli_math_asl_fx(identity_params.scale, shift_left);
317+
identity_params.shift = mli_math_max_fx(identity_params.shift, 0);
315318
s8asym_quant_params alpha_params = leaky_relu_define_requant_params(in, slope_coeff, out, scale, &identity_params);
319+
shift_left = mli_math_max_fx(-alpha_params.shift, 0);
320+
alpha_params.scale = mli_math_asl_fx(alpha_params.scale, shift_left);
321+
alpha_params.shift = mli_math_max_fx(alpha_params.shift, 0);
316322

317323
/* Dummy Load to get num_lanes */
318324
auto input = mli_prv_load_1vec(in_ptr);

lib/src/kernels/transform/impl/mli_krn_leaky_relu_vdsp.h

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -116,21 +116,15 @@ static MLI_FORCE_INLINE vNx4char_t calc_leaky_relu(
116116
grp_pvNx2_t select = init_predicate_grp(cond);
117117

118118
int identity_shift = identity_params->shift;
119-
int identity_shift_left = mli_math_max_fx(-identity_shift, 0);
120-
int identity_shift_right = mli_math_max_fx(identity_shift, 0);
121119
vNx4int_t input_identity_scale = mli_math_mul_fx<vNx4short_t, vNx4int_t>(input_cast, identity_params->scale);
122-
input_identity_scale = mli_math_asl_fx(input_identity_scale, identity_shift_left);
123-
input_identity_scale = mli_math_asr_rnd_fx(input_identity_scale, identity_shift_right);
120+
input_identity_scale = mli_math_asr_rnd_fx(input_identity_scale, identity_shift);
124121

125122
vNx4short_t output_identity = mli_math_cast_fx<vNx4int_t, vNx4short_t>(input_identity_scale);
126123
output_identity = mli_math_add_fx(output_identity, (vNx4short_t)identity_params->offset);
127124

128125
int alpha_shift = alpha_params->shift;
129-
int alpha_shift_left = mli_math_max_fx(-alpha_shift, 0);
130-
int alpha_shift_right = mli_math_max_fx(alpha_shift, 0);
131126
vNx4int_t input_alpha_scale = mli_math_mul_fx<vNx4short_t, vNx4int_t>(input_cast, alpha_params->scale);
132-
input_alpha_scale = mli_math_asl_fx(input_alpha_scale, alpha_shift_left);
133-
input_alpha_scale = mli_math_asr_rnd_fx(input_alpha_scale, alpha_shift_right);
127+
input_alpha_scale = mli_math_asr_rnd_fx(input_alpha_scale, alpha_shift);
134128

135129
vNx4short_t output_alpha = mli_math_cast_fx<vNx4int_t, vNx4short_t>(input_alpha_scale);
136130
output_alpha = mli_math_add_fx(output_alpha, (vNx4short_t)alpha_params->offset);

0 commit comments

Comments
 (0)