@@ -29,16 +29,14 @@ static MLI_FORCE_INLINE vNx4short_t convert_input(
2929 int16_t in_zp,
3030 int remaining_part = 0 ) {
3131
32- vNx4short_t input_cast = mli_math_cast_fx<vNx4char_t, vNx4short_t>(input);
3332
3433 if (remaining_part) {
35- /* use in_zp directly when a known issue is solved. */
36- vNx4short_t zp = mli_math_mul_fx_high (input_cast, 0 );
37- zp = mli_math_add_fx<vNx4short_t>(zp, in_zp);
38- grp_pvNx2_t predicate = init_predicate_grp (remaining_part);
39- input_cast = mli_math_select_fx (predicate, input_cast, zp);
34+ pvNx4 predicate = init_predicate (remaining_part, input);
35+ input = mli_math_select_fx (predicate, input, (vNx4char_t)in_zp);
4036 }
4137
38+ vNx4short_t input_cast = mli_math_cast_fx<vNx4char_t, vNx4short_t>(input);
39+
4240 if (convert) {
4341 input_cast = mli_math_sub_fx<vNx4short_t>(input_cast, in_zp);
4442 }
@@ -65,9 +63,7 @@ static MLI_FORCE_INLINE vNx2short_t convert_input(
6563}
6664
6765static MLI_FORCE_INLINE vNx4accint_t init_sum_acc (vNx4char_t input) {
68- // Update Accu initialization when a known issue is solved.
69- // return mli_prv_init_accu<vNx4accint_t>();
70- return mli_math_mul_fx<vNx4short_t, vNx4accint_t>(mli_math_cast_fx<vNx4char_t, vNx4short_t>(input), 0 );
66+ return mli_prv_init_accu<vNx4accint_t>();
7167}
7268
7369static MLI_FORCE_INLINE vNx2accint_t init_sum_acc (vNx2short_t input) {
@@ -135,22 +131,22 @@ static MLI_FORCE_INLINE int16_t compute_normalized_sum_square(
135131 mli_acc32_t acc_hi = mli_math_intra_sum (sum_acc_hi);
136132 mli_acc32_t acc_mid = mli_math_intra_sum (sum_acc_mid);
137133 mli_acc32_t acc_lo = mli_math_intra_sum (sum_acc_lo);
138- mli_acc40_t acc = mli_prv_init_accu< mli_acc40_t >();
139- acc = mli_math_add_fx (acc, mli_math_asl_fx (( mli_acc40_t )acc_hi, acc_hi_shift)) ;
140- acc = mli_math_add_fx (acc , mli_math_asl_fx ((mli_acc40_t )acc_mid, acc_mid_shift ));
141- acc = mli_math_add_fx (acc, ( mli_acc40_t )acc_lo );
142-
143- int norm_shift_val = mli_math_norm_fx<mli_acc40_t , int >(acc);
144- /* To Cast mli_acc40_t to int16_t */
145- norm_shift_val = (sizeof (mli_acc40_t ) - sizeof (int16_t )) * 8 - norm_shift_val;
134+
135+ typedef typename std::conditional<convert == false , mli_acc40_t , mli_acc32_t >::type acc_type ;
136+ acc_type acc = mli_math_add_fx ((acc_type)acc_lo , mli_math_asl_fx ((acc_type)acc_hi, acc_hi_shift ));
137+ acc = mli_math_add_fx (acc, mli_math_asl_fx ((acc_type)acc_mid, acc_mid_shift) );
138+
139+ int norm_shift_val = mli_math_norm_fx<acc_type , int >(acc);
140+ /* To Cast mli_acc32_t to int16_t */
141+ norm_shift_val = (sizeof (acc_type ) - sizeof (int16_t )) * 8 - norm_shift_val;
146142 /* Adjust norm_shift to even number because we are going to divide it by 2 */
147143 if ((norm_shift_val & 0x1 ) == 0x1 ) {
148144 norm_shift_val += 1 ;
149145 }
150146
151147 *norm_shift = norm_shift_val;
152148 /* Cast Sum_acc to Q7.8 to bring it to LUT input range */
153- return mli_math_cast_fx<mli_acc40_t , int16_t >(acc, norm_shift_val);
149+ return mli_math_cast_fx<acc_type , int16_t >(acc, norm_shift_val);
154150}
155151
156152template <bool convert>
@@ -161,7 +157,7 @@ static MLI_FORCE_INLINE vNx4char_t compute_normalize(
161157 int shift) {
162158
163159 int shift_right = MAX (shift, 0 );
164- int shift_left = MAX (-shift, 0 );
160+ int shift_left = MAX (-shift, 0 );
165161 vNx4short_t input_cast = mli_math_cast_fx<vNx4char_t, vNx4short_t>(input);
166162
167163 if (convert) {
@@ -182,7 +178,7 @@ static MLI_FORCE_INLINE vNx2short_t compute_normalize(
182178 int shift) {
183179
184180 int shift_right = MAX (shift, 0 );
185- int shift_left = MAX (-shift, 0 );
181+ int shift_left = MAX (-shift, 0 );
186182
187183 if (convert) {
188184 input = mli_math_sub_fx<vNx2short_t>(input, in_zp);
0 commit comments