@@ -156,18 +156,20 @@ static MLI_FORCE_INLINE vNx4char_t compute_normalize(
156156 int16_t in_zp,
157157 int shift) {
158158
159- int shift_right = MAX (shift, 0 );
160- int shift_left = MAX (-shift, 0 );
159+ constexpr int mul_hi_shift = 16 ;
160+ shift -= mul_hi_shift;
161+ int shift_right = mli_math_max_fx (shift, 1 );
162+ int shift_left = mli_math_max_fx (1 - shift, 0 );
161163 vNx4short_t input_cast = mli_math_cast_fx<vNx4char_t, vNx4short_t>(input);
162164
163165 if (convert) {
164166 input_cast = mli_math_sub_fx<vNx4short_t>(input_cast, in_zp);
165167 }
166168
167- vNx4accint_t res = mli_math_mul_fx<vNx4short_t, vNx4accint_t> (input_cast, scale );
168- res = mli_math_asl_fx (res, shift_left );
169+ input_cast = mli_math_asl_fx (input_cast, shift_left );
170+ vNx4short_t res = mli_math_mul_fx_high (input_cast, scale );
169171
170- return mli_math_acc_cast_fx<vNx4char_t, vNx4accint_t >(res, shift_right);
172+ return mli_math_cast_fx<vNx4short_t, vNx4char_t >(res, shift_right);
171173}
172174
173175template <bool convert>
@@ -177,8 +179,8 @@ static MLI_FORCE_INLINE vNx2short_t compute_normalize(
177179 int16_t in_zp,
178180 int shift) {
179181
180- int shift_right = MAX (shift, 0 );
181- int shift_left = MAX (-shift, 0 );
182+ int shift_right = mli_math_max_fx (shift, 0 );
183+ int shift_left = mli_math_max_fx (-shift, 0 );
182184
183185 if (convert) {
184186 input = mli_math_sub_fx<vNx2short_t>(input, in_zp);
0 commit comments