Skip to content

Commit 4555ea8

Browse files
mfarag13 authored and JaccovG committed
[l2_norm_opt]: Optimizing l2_norm SA8 for limited zp range.
1 parent 46a8b1f commit 4555ea8

File tree

1 file changed

+16
-20
lines changed

1 file changed

+16
-20
lines changed

lib/src/kernels/transform/impl/mli_krn_l2_normalize_vdsp.h

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,14 @@ static MLI_FORCE_INLINE vNx4short_t convert_input(
2929
int16_t in_zp,
3030
int remaining_part = 0) {
3131

32-
vNx4short_t input_cast = mli_math_cast_fx<vNx4char_t, vNx4short_t>(input);
3332

3433
if (remaining_part) {
35-
/* use in_zp directly when a known issue is solved. */
36-
vNx4short_t zp = mli_math_mul_fx_high(input_cast, 0);
37-
zp = mli_math_add_fx<vNx4short_t>(zp, in_zp);
38-
grp_pvNx2_t predicate = init_predicate_grp(remaining_part);
39-
input_cast = mli_math_select_fx(predicate, input_cast, zp);
34+
pvNx4 predicate = init_predicate(remaining_part, input);
35+
input = mli_math_select_fx(predicate, input, (vNx4char_t)in_zp);
4036
}
4137

38+
vNx4short_t input_cast = mli_math_cast_fx<vNx4char_t, vNx4short_t>(input);
39+
4240
if (convert) {
4341
input_cast = mli_math_sub_fx<vNx4short_t>(input_cast, in_zp);
4442
}
@@ -65,9 +63,7 @@ static MLI_FORCE_INLINE vNx2short_t convert_input(
6563
}
6664

6765
static MLI_FORCE_INLINE vNx4accint_t init_sum_acc(vNx4char_t input) {
68-
// Update Accu initialization when a known issue is solved.
69-
// return mli_prv_init_accu<vNx4accint_t>();
70-
return mli_math_mul_fx<vNx4short_t, vNx4accint_t>(mli_math_cast_fx<vNx4char_t, vNx4short_t>(input), 0);
66+
return mli_prv_init_accu<vNx4accint_t>();
7167
}
7268

7369
static MLI_FORCE_INLINE vNx2accint_t init_sum_acc(vNx2short_t input) {
@@ -135,22 +131,22 @@ static MLI_FORCE_INLINE int16_t compute_normalized_sum_square(
135131
mli_acc32_t acc_hi = mli_math_intra_sum(sum_acc_hi);
136132
mli_acc32_t acc_mid = mli_math_intra_sum(sum_acc_mid);
137133
mli_acc32_t acc_lo = mli_math_intra_sum(sum_acc_lo);
138-
mli_acc40_t acc = mli_prv_init_accu<mli_acc40_t>();
139-
acc = mli_math_add_fx(acc, mli_math_asl_fx((mli_acc40_t)acc_hi, acc_hi_shift));
140-
acc = mli_math_add_fx(acc, mli_math_asl_fx((mli_acc40_t)acc_mid, acc_mid_shift));
141-
acc = mli_math_add_fx(acc, (mli_acc40_t)acc_lo);
142-
143-
int norm_shift_val = mli_math_norm_fx<mli_acc40_t, int>(acc);
144-
/* To Cast mli_acc40_t to int16_t */
145-
norm_shift_val = (sizeof(mli_acc40_t) - sizeof(int16_t)) * 8 - norm_shift_val;
134+
135+
typedef typename std::conditional<convert == false, mli_acc40_t, mli_acc32_t>::type acc_type;
136+
acc_type acc = mli_math_add_fx((acc_type)acc_lo, mli_math_asl_fx((acc_type)acc_hi, acc_hi_shift));
137+
acc = mli_math_add_fx(acc, mli_math_asl_fx((acc_type)acc_mid, acc_mid_shift));
138+
139+
int norm_shift_val = mli_math_norm_fx<acc_type, int>(acc);
140+
/* To cast acc_type (mli_acc40_t or mli_acc32_t, depending on convert) to int16_t */
141+
norm_shift_val = (sizeof(acc_type) - sizeof(int16_t)) * 8 - norm_shift_val;
146142
/* Adjust norm_shift to even number because we are going to divide it by 2 */
147143
if ((norm_shift_val & 0x1) == 0x1) {
148144
norm_shift_val += 1;
149145
}
150146

151147
*norm_shift = norm_shift_val;
152148
/* Cast Sum_acc to Q7.8 to bring it to LUT input range */
153-
return mli_math_cast_fx<mli_acc40_t, int16_t>(acc, norm_shift_val);
149+
return mli_math_cast_fx<acc_type, int16_t>(acc, norm_shift_val);
154150
}
155151

156152
template<bool convert>
@@ -161,7 +157,7 @@ static MLI_FORCE_INLINE vNx4char_t compute_normalize(
161157
int shift) {
162158

163159
int shift_right = MAX(shift, 0);
164-
int shift_left = MAX(-shift, 0);
160+
int shift_left = MAX(-shift, 0);
165161
vNx4short_t input_cast = mli_math_cast_fx<vNx4char_t, vNx4short_t>(input);
166162

167163
if (convert) {
@@ -182,7 +178,7 @@ static MLI_FORCE_INLINE vNx2short_t compute_normalize(
182178
int shift) {
183179

184180
int shift_right = MAX(shift, 0);
185-
int shift_left = MAX(-shift, 0);
181+
int shift_left = MAX(-shift, 0);
186182

187183
if (convert) {
188184
input = mli_math_sub_fx<vNx2short_t>(input, in_zp);

0 commit comments

Comments
 (0)