Skip to content

Commit 8a60337

Browse files
mfarag13dzakhar
authored andcommitted
[softmax_acc]: Improve Accuarcy and Fix Issue in LUT.
1 parent 8edb47f commit 8a60337

File tree

8 files changed

+189
-88
lines changed

8 files changed

+189
-88
lines changed

lib/src/bricks/impl/mli_prv_lut_dsp.h

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@ static MLI_FORCE_INLINE v2q15_t activation_lut_two_elem_interpolate(
4747
int shift_in = in_frac_bits - lut->in_frac_bits;
4848
// if shift amount is too high, preshift argument itself and
4949
// limit shift amount to prevent overflows
50+
constexpr int max_shift = 15;
5051
int preshift_in = mli_math_max_fx(shift_in - (int)kMaxFracBitsFx16, 0);
52+
preshift_in = mli_math_min_fx(preshift_in, max_shift);
5153
shift_in = mli_math_min_fx(shift_in, (int)kMaxFracBitsFx16);
5254

5355
v2q15_t offset = mli_prv_init_v<int16_t, v2q15_t>(lut->input_offset);
@@ -58,17 +60,25 @@ static MLI_FORCE_INLINE v2q15_t activation_lut_two_elem_interpolate(
5860

5961
/* Convert Input SA8 to FX */
6062
v2q15_t x = in;
63+
v2q15_t lut_idx;
64+
v2q15_t frac;
6165
if (convert_input) {
62-
int shift = ((int32_t) in_params->shift - in_frac_bits) + preshift_in;
63-
x = mli_prv_convert_sa8_fx16<v2q15_t, v2q15_t>(x, in_params->offset, in_params->scale, shift);
66+
int shift = (int32_t) in_params->shift - in_frac_bits;
67+
v2q31_t x_int = mli_prv_convert_sa8_fx16<v2q15_t, v2q31_t>(x, in_params->offset, in_params->scale, shift);
68+
x_int = mli_math_asr_fx(x_int, preshift_in);
69+
frac[0] = x_int[0] & mask[0];
70+
frac[1] = x_int[1] & mask[1];
71+
x_int = mli_math_asr_fx(x_int, shift_in);
72+
lut_idx[0] = mli_math_bound_range_fx(mli_math_add_fx(x_int[0], (int32_t)offset[0]), lower[0], upper[0]);
73+
lut_idx[1] = mli_math_bound_range_fx(mli_math_add_fx(x_int[1], (int32_t)offset[1]), lower[1], upper[1]);
6474
} else {
6575
x = mli_math_acc_ashift_fx(x, preshift_in);
76+
frac = x & mask;
77+
lut_idx = mli_math_add_fx(mli_math_acc_ashift_fx(x, shift_in), offset);
78+
lut_idx = mli_math_bound_range_fx(lut_idx, lower, upper);
6679
}
6780

68-
v2q15_t lut_idx = mli_math_add_fx(mli_math_acc_ashift_fx(x, shift_in), offset);
69-
lut_idx = mli_math_bound_range_fx(lut_idx, lower, upper);
7081
// perform linear interpolation
71-
v2q15_t frac = x & mask;
7282
v2q15_t res = mli_prv_init_v(lut_data[lut_idx[0]], lut_data[lut_idx[1]]);
7383
v2q15_t next = mli_prv_init_v(lut_data[lut_idx[0] + 1], lut_data[lut_idx[1] + 1]);
7484
v2q15_t diff = mli_math_sub_fx(res, next);
@@ -151,7 +161,6 @@ static MLI_FORCE_INLINE void compute_activation_lut(
151161
const struct s8asym_quant_params *in_params,
152162
struct s8asym_quant_params *out_params) {
153163

154-
MLI_ASSERT(in_frac_bits >= -1); // -1 may be required by softmax
155164
MLI_ASSERT(lut->in_frac_bits >= 0);
156165
MLI_ASSERT(lut->length >= 0);
157166
MLI_ASSERT(MLI_MAX_RANK == 4);

lib/src/bricks/impl/mli_prv_lut_ref.h

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ static MLI_FORCE_INLINE void compute_activation_lut(
2929
const struct s8asym_quant_params *in_params,
3030
struct s8asym_quant_params *out_params) {
3131

32-
MLI_ASSERT(in_frac_bits >= -1); // -1 may be required by softmax
3332
MLI_ASSERT(lut->in_frac_bits >= 0);
3433
MLI_ASSERT(lut->length >= 0);
3534
MLI_ASSERT(MLI_MAX_RANK == 4);
@@ -103,23 +102,23 @@ static MLI_FORCE_INLINE out_T activation_lut_one_elem_interpolate(
103102
int shift_in = in_frac_bits - lut->in_frac_bits;
104103
// if shift amount is too high, preshift argument itself and
105104
// limit shift amount to prevent overflows
105+
constexpr int max_shift = 15;
106106
int preshift_in = mli_math_max_fx(shift_in - (int)kMaxFracBitsFx16, 0);
107+
preshift_in = mli_math_min_fx(preshift_in, max_shift);
107108
shift_in = mli_math_min_fx(shift_in, (int)kMaxFracBitsFx16);
108109

109110
int16_t mask = (1 << shift_in) - 1;
110111

111112
/* Convert Input SA8 to FX */
112-
int16_t input;
113+
int32_t input;
113114
if (convert_input) {
114115
int shift = ((int32_t) in_params->shift - in_frac_bits);
115-
input = mli_prv_convert_sa8_fx16<in_T, int16_t>(in, in_params->offset, in_params->scale, shift);
116+
input = mli_prv_convert_sa8_fx16<in_T, int32_t>(in, in_params->offset, in_params->scale, shift);
116117
} else {
117118
input = in;
118119
}
119-
constexpr int max_shift = 15;
120-
preshift_in = mli_math_min_fx(preshift_in, max_shift);
121-
int16_t x = input >> preshift_in;
122-
int lut_idx = mli_math_add_fx((x >> shift_in), lut->input_offset);
120+
int32_t x = mli_math_asr_fx(input, preshift_in);
121+
int lut_idx = mli_math_add_fx(mli_math_asr_fx(x, shift_in), lut->input_offset);
123122
lut_idx = mli_math_bound_range_fx(lut_idx, 0, lut->length - 2);
124123
// perform linear interpolation
125124
int16_t frac = x & mask;
@@ -174,7 +173,7 @@ static MLI_FORCE_INLINE out_T activation_lut_one_elem_no_interpolate(
174173
input = in;
175174
}
176175
int x = (int)input;
177-
int lut_idx = mli_math_add_fx((x << -shift_in), lut->input_offset);
176+
int lut_idx = mli_math_add_fx(mli_math_asl_fx(x, -shift_in), lut->input_offset);
178177
lut_idx = mli_math_bound_range_fx(lut_idx, 0, lut->length - 1);
179178
// no interpolation
180179
int16_t res = lut_data[lut_idx];

lib/src/bricks/impl/mli_prv_lut_vdsp.h

Lines changed: 74 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,6 @@ static MLI_FORCE_INLINE vNx4short_t activation_lut_vec_elem_interpolate(
117117
int8_t in_frac_bits,
118118
const struct s8asym_quant_params *in_params) {
119119

120-
MLI_ASSERT(in_frac_bits >= -1); // -1 may be required by softmax
121120
MLI_ASSERT(lut->in_frac_bits >= 0);
122121
MLI_ASSERT(lut->length >= 0);
123122

@@ -133,26 +132,36 @@ static MLI_FORCE_INLINE vNx4short_t activation_lut_vec_elem_interpolate(
133132
const MLI_PTR(short) lut_data = (const MLI_PTR(short))lut->data.mem.pi16;
134133
// if shift amount is too high, preshift argument itself and
135134
// limit shift amount to prevent overflows
135+
constexpr int max_shift = 15;
136136
int preshift_in = mli_math_max_fx(shift_in - (int)kMaxFracBitsFx16, 0);
137+
preshift_in = mli_math_min_fx(preshift_in, max_shift);
137138
shift_in = mli_math_min_fx(shift_in, (int)kMaxFracBitsFx16);
138139

139140
// input data is more precise than LUT
140141
int16_t mask = (1 << shift_in) - 1;
141142
vNx4short_t x = in;
143+
vNx4int_t lut_idx_int;
144+
vNx4short_t frac;
142145
if (convert) {
143-
int shift = ((int32_t) in_params->shift - in_frac_bits) + preshift_in;
144-
x = mli_prv_convert_sa8_fx16<vNx4short_t, vNx4short_t>(x, in_params->offset, in_params->scale, shift);
146+
int shift = (int32_t) in_params->shift - in_frac_bits;
147+
vNx4int_t x_int = mli_prv_convert_sa8_fx16<vNx4short_t, vNx4int_t>(x, in_params->offset, in_params->scale, shift);
148+
x_int = mli_math_asr_fx(x_int, preshift_in);
149+
frac = mli_math_cast_fx<vNx4int_t, vNx4short_t>(x_int & mask);
150+
151+
/* Calculate lut_idx */
152+
vNx4int_t lut_idx = mli_math_add_fx<vNx4int_t>(mli_math_asr_fx(x_int, shift_in), lut->input_offset);
153+
lut_idx_int = mli_math_bound_range_fx(lut_idx , 0, lut->length - 2);
154+
145155
} else {
146-
constexpr int max_shift = 15;
147-
preshift_in = mli_math_min_fx(preshift_in, max_shift);
148156
x = mli_math_asr_fx(x, preshift_in);
157+
frac = x & mask;
158+
159+
/* Calculate lut_idx */
160+
vNx4short_t lut_idx = mli_math_add_fx<vNx4short_t>(mli_math_asr_fx(x, shift_in), lut->input_offset);
161+
lut_idx = mli_math_bound_range_fx(lut_idx , 0, lut->length - 2);
162+
lut_idx_int = mli_math_mul_fx<vNx4short_t, vNx4int_t>(lut_idx, 1);
149163
}
150-
vNx4short_t lut_idx = mli_math_add_fx<vNx4short_t>(mli_math_asr_fx(x, shift_in), lut->input_offset);
151-
/* Calculate lut_idx */
152-
lut_idx = mli_math_bound_range_fx(lut_idx , 0, lut->length - 2);
153-
vNx4int_t lut_idx_int = mli_math_mul_fx<vNx4short_t, vNx4int_t>(lut_idx, 1);
154164

155-
vNx4short_t frac = x & mask;
156165
/* Load from LUT */
157166
vNx4short_t lut_values = mli_prv_gather_load_nx4_samples(lut_data, lut_idx_int);
158167
vNx4short_t lut_values_next = mli_prv_gather_load_nx4_samples(lut_data, lut_idx_int + 1);
@@ -162,9 +171,7 @@ static MLI_FORCE_INLINE vNx4short_t activation_lut_vec_elem_interpolate(
162171
mli_math_mul_fx<vNx4short_t, vNx4accint_t>(diffs, frac), shift_in);
163172

164173
/* Calculate O/P */
165-
vNx4short_t result = mli_math_sub_fx<vNx4short_t>(lut_values, diffs_mul_frac_cast);
166-
167-
return result;
174+
return mli_math_sub_fx<vNx4short_t>(lut_values, diffs_mul_frac_cast);
168175
}
169176

170177
template <bool convert>
@@ -174,7 +181,6 @@ static MLI_FORCE_INLINE vNx4short_t activation_lut_vec_elem_no_interpolate(
174181
int8_t in_frac_bits,
175182
const struct s8asym_quant_params *in_params) {
176183

177-
MLI_ASSERT(in_frac_bits >= -1); // -1 may be required by softmax
178184
MLI_ASSERT(lut->in_frac_bits >= 0);
179185
MLI_ASSERT(lut->length >= 0);
180186

@@ -213,30 +219,33 @@ static MLI_FORCE_INLINE vNx4short_t activation_lut_vec_elem_no_interpolate(
213219
template <typename io_T, bool convert>
214220
static MLI_FORCE_INLINE void load_input_and_get_lut_idx(
215221
MLI_PTR(io_T) __restrict in_ptr,
216-
vNx4short_t &vec,
217-
vNx4short_t &lut_idx,
222+
vNx4short_t &x,
223+
vNx4int_t &x_int,
218224
vNx4int_t &lut_idx_int,
219225
int16_t in_frac_bits,
220226
int preshift_in,
221227
int shift_in,
222228
const mli_lut *lut,
223229
const struct s8asym_quant_params *in_params) {
224230

225-
vec = activation_lut_load_input<io_T, vNx4short_t>(in_ptr);
226-
231+
x = activation_lut_load_input<io_T, vNx4short_t>(in_ptr);
227232
if (convert) {
228-
int shift = ((int32_t) in_params->shift - in_frac_bits) + preshift_in;
229-
vec = mli_prv_convert_sa8_fx16<vNx4short_t, vNx4short_t>(vec, in_params->offset, in_params->scale, shift);
233+
int shift = (int32_t) in_params->shift - in_frac_bits;
234+
x_int = mli_prv_convert_sa8_fx16<vNx4short_t, vNx4int_t>(x, in_params->offset, in_params->scale, shift);
235+
x_int = mli_math_asr_fx(x_int, preshift_in);
236+
237+
/* Calculate lut_idx */
238+
vNx4int_t lut_idx = mli_math_add_fx<vNx4int_t>(mli_math_asr_fx(x_int, shift_in), lut->input_offset);
239+
lut_idx_int = mli_math_bound_range_fx(lut_idx , 0, lut->length - 2);
240+
230241
} else {
231-
constexpr int max_shift = 15;
232-
preshift_in = mli_math_min_fx(preshift_in, max_shift);
233-
vec = mli_math_asr_fx(vec, preshift_in);
234-
}
242+
x = mli_math_asr_fx(x, preshift_in);
235243

236-
/* Calculate lut_idx */
237-
lut_idx = mli_math_add_fx<vNx4short_t>(mli_math_asr_fx(vec, shift_in), lut->input_offset);
238-
lut_idx = mli_math_bound_range_fx(lut_idx , 0, lut->length - 2);
239-
lut_idx_int = mli_math_mul_fx<vNx4short_t, vNx4int_t>(lut_idx, 1);
244+
/* Calculate lut_idx */
245+
vNx4short_t lut_idx = mli_math_add_fx<vNx4short_t>(mli_math_asr_fx(x, shift_in), lut->input_offset);
246+
lut_idx = mli_math_bound_range_fx(lut_idx , 0, lut->length - 2);
247+
lut_idx_int = mli_math_mul_fx<vNx4short_t, vNx4int_t>(lut_idx, 1);
248+
}
240249
}
241250

242251
template <typename io_T, bool convert>
@@ -250,7 +259,6 @@ static MLI_FORCE_INLINE void compute_activation_lut_func(
250259
const struct s8asym_quant_params *in_params,
251260
struct s8asym_quant_params *out_params) {
252261

253-
MLI_ASSERT(in_frac_bits >= -1); // -1 may be required by softmax
254262
MLI_ASSERT(lut->in_frac_bits >= 0);
255263
MLI_ASSERT(lut->length >= 0);
256264
MLI_ASSERT(MLI_MAX_RANK == 4);
@@ -271,7 +279,9 @@ static MLI_FORCE_INLINE void compute_activation_lut_func(
271279
const MLI_PTR(short) lut_data = (const MLI_PTR(short))lut->data.mem.pi16;
272280
// if shift amount is too high, preshift argument itself and
273281
// limit shift amount to prevent overflows
282+
constexpr int max_shift = 15;
274283
int preshift_in = mli_math_max_fx(shift_in - (int)kMaxFracBitsFx16, 0);
284+
preshift_in = mli_math_min_fx(preshift_in, max_shift);
275285
shift_in = mli_math_min_fx(shift_in, (int)kMaxFracBitsFx16);
276286

277287
int remaining_part = in->shape[3] & (_VDSP_NUM_8BIT_LANES - 1);
@@ -294,49 +304,61 @@ static MLI_FORCE_INLINE void compute_activation_lut_func(
294304
}
295305

296306
/* Manual software pipelining */
297-
vNx4short_t x, lut_idx;
298-
vNx4int_t lut_idx_int;
307+
vNx4short_t x, frac;
308+
vNx4int_t x_int, lut_idx_int;
299309
vNx4short_t _lut_values, _lut_values_next, _frac;
300310

301-
load_input_and_get_lut_idx<io_T, convert>(input_ptr, x, lut_idx, lut_idx_int,
311+
load_input_and_get_lut_idx<io_T, convert>(input_ptr, x, x_int, lut_idx_int,
302312
in_frac_bits, preshift_in, shift_in, lut, in_params);
303-
vNx4short_t frac = x & mask;
313+
if (convert) {
314+
frac = mli_math_cast_fx<vNx4int_t, vNx4short_t>(x_int & mask);
315+
} else {
316+
frac = x & mask;
317+
}
304318
input_ptr += _VDSP_NUM_8BIT_LANES;
305319

306320
if (in->shape[3] >= _VDSP_NUM_8BIT_LANES && !convert) {
307321
/* Load from LUT */
308322
_lut_values = mli_prv_gather_load_nx4_samples(lut_data, lut_idx_int);
309323
_lut_values_next = mli_prv_gather_load_nx4_samples(lut_data, lut_idx_int + 1);
310324

311-
load_input_and_get_lut_idx<io_T, convert>(input_ptr, x, lut_idx, lut_idx_int,
325+
load_input_and_get_lut_idx<io_T, convert>(input_ptr, x, x_int, lut_idx_int,
312326
in_frac_bits, preshift_in, shift_in, lut, in_params);
313327

314328
_frac = frac;
315-
frac = x & mask;
329+
if (convert) {
330+
frac = mli_math_cast_fx<vNx4int_t, vNx4short_t>(x_int & mask);
331+
} else {
332+
frac = x & mask;
333+
}
316334
input_ptr += _VDSP_NUM_8BIT_LANES;
317335

318336
for (int pos3 = remaining_part; pos3 < in->shape[3] - _VDSP_NUM_8BIT_LANES; pos3 += _VDSP_NUM_8BIT_LANES) {
319337
/* Load from LUT */
320338
vNx4short_t lut_values = mli_prv_gather_load_nx4_samples(lut_data, lut_idx_int);
321339
vNx4short_t lut_values_next = mli_prv_gather_load_nx4_samples(lut_data, lut_idx_int + 1);
322340

323-
load_input_and_get_lut_idx<io_T, convert>(input_ptr, x, lut_idx, lut_idx_int,
341+
load_input_and_get_lut_idx<io_T, convert>(input_ptr, x, x_int, lut_idx_int,
324342
in_frac_bits, preshift_in, shift_in, lut, in_params);
325343

326-
/* perform linear interpolation */
327-
vNx4short_t diffs = mli_math_sub_fx<vNx4short_t>(_lut_values, _lut_values_next);
328-
vNx4short_t diffs_mul_frac_cast = mli_math_acc_cast_fx<vNx4short_t, vNx4accint_t>(
329-
mli_math_mul_fx<vNx4short_t, vNx4accint_t>(diffs, _frac), shift_in);
344+
/* perform linear interpolation */
345+
vNx4short_t diffs = mli_math_sub_fx<vNx4short_t>(_lut_values, _lut_values_next);
346+
vNx4short_t diffs_mul_frac_cast = mli_math_acc_cast_fx<vNx4short_t, vNx4accint_t>(
347+
mli_math_mul_fx<vNx4short_t, vNx4accint_t>(diffs, _frac), shift_in);
330348

331-
/* Calculate O/P */
332-
vNx4short_t res = mli_math_sub_fx<vNx4short_t>(_lut_values, diffs_mul_frac_cast);
349+
/* Calculate O/P */
350+
vNx4short_t res = mli_math_sub_fx<vNx4short_t>(_lut_values, diffs_mul_frac_cast);
333351

334-
/* Store O/P */
335-
activation_lut_store_output<io_T, convert>(output_ptr, res, lut, out_params);
336-
output_ptr += _VDSP_NUM_8BIT_LANES;
352+
/* Store O/P */
353+
activation_lut_store_output<io_T, convert>(output_ptr, res, lut, out_params);
354+
output_ptr += _VDSP_NUM_8BIT_LANES;
337355

338356
_frac = frac;
339-
frac = x & mask;
357+
if (convert) {
358+
frac = mli_math_cast_fx<vNx4int_t, vNx4short_t>(x_int & mask);
359+
} else {
360+
frac = x & mask;
361+
}
340362

341363
input_ptr += _VDSP_NUM_8BIT_LANES;
342364

@@ -360,15 +382,19 @@ static MLI_FORCE_INLINE void compute_activation_lut_func(
360382
vNx4short_t lut_values = mli_prv_gather_load_nx4_samples(lut_data, lut_idx_int);
361383
vNx4short_t lut_values_next = mli_prv_gather_load_nx4_samples(lut_data, lut_idx_int + 1);
362384

363-
load_input_and_get_lut_idx<io_T, convert>(input_ptr, x, lut_idx, lut_idx_int,
385+
load_input_and_get_lut_idx<io_T, convert>(input_ptr, x, x_int, lut_idx_int,
364386
in_frac_bits, preshift_in, shift_in, lut, in_params);
365387

366388
/* perform linear interpolation */
367389
vNx4short_t diffs = mli_math_sub_fx<vNx4short_t>(lut_values, lut_values_next);
368390
vNx4short_t diffs_mul_frac_cast = mli_math_acc_cast_fx<vNx4short_t, vNx4accint_t>(
369391
mli_math_mul_fx<vNx4short_t, vNx4accint_t>(diffs, frac), shift_in);
370392

371-
frac = x & mask;
393+
if (convert) {
394+
frac = mli_math_cast_fx<vNx4int_t, vNx4short_t>(x_int & mask);
395+
} else {
396+
frac = x & mask;
397+
}
372398

373399
/* Calculate O/P */
374400
vNx4short_t res = mli_math_sub_fx<vNx4short_t>(lut_values, diffs_mul_frac_cast);

lib/src/bricks/impl/mli_prv_quant_dsp.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,18 @@ MLI_FORCE_INLINE v2q15_t mli_prv_convert_sa8_fx16(
286286
return mli_math_acc_cast_fx<v2q15_t, v2accum40_t>(in_scaled, shift);
287287
}
288288

289+
template<>
290+
MLI_FORCE_INLINE v2q31_t mli_prv_convert_sa8_fx16(
291+
const v2q15_t in,
292+
const int16_t zero_point,
293+
const int16_t scale,
294+
const int shift) {
295+
v2q31_t out;
296+
out[0] = mli::krn::ref::mli_prv_convert_sa8_fx16<int8_t, int32_t>(in[0], zero_point, scale, shift);
297+
out[1] = mli::krn::ref::mli_prv_convert_sa8_fx16<int8_t, int32_t>(in[1], zero_point, scale, shift);
298+
return out;
299+
}
300+
289301
template<>
290302
MLI_FORCE_INLINE v2q15_t mli_prv_convert_fx16_sa8(
291303
const v2q15_t in,

lib/src/bricks/impl/mli_prv_quant_vdsp.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,21 @@ MLI_FORCE_INLINE vNx4short_t mli_prv_convert_sa8_fx16(
310310
return res;
311311
}
312312

313+
template<>
314+
MLI_FORCE_INLINE vNx4int_t mli_prv_convert_sa8_fx16(
315+
const vNx4short_t in_val,
316+
const int16_t zero_point,
317+
const int16_t scale,
318+
const int shift) {
319+
int shift_right = mli_math_max_fx(shift, 0);
320+
int shift_left = mli_math_max_fx(-shift, 0);
321+
vNx4short_t in_biased_shifted_no_zp = mli_math_sub_fx<vNx4short_t>(in_val, zero_point);
322+
vNx4int_t in_scaled = mli_math_mul_fx<vNx4short_t, vNx4int_t>(in_biased_shifted_no_zp, scale);
323+
vNx4int_t res = mli_math_asr_rnd_fx(in_scaled, shift_right);
324+
res = mli_math_asl_fx(res, shift_left);
325+
return res;
326+
}
327+
313328
MLI_FORCE_INLINE vNx4int_t mli_prv_convert_sa8_fx32(
314329
const vNx4char_t in_val,
315330
const int16_t zero_point,

lib/src/pal/dsp/mli_math.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,13 @@ MLI_FORCE_INLINE int16_t mli_math_asr_fx(int16_t acc, int shift_right) {
6868
return fx_asr_q15(acc, shift_right);
6969
}
7070

71+
template <>
72+
MLI_FORCE_INLINE v2q31_t mli_math_asr_fx(v2q31_t acc, int shift_right) {
73+
acc[0] = fx_asr_q31(acc[0], shift_right);
74+
acc[1] = fx_asr_q31(acc[1], shift_right);
75+
return acc;
76+
}
77+
7178
template <typename T>
7279
MLI_FORCE_INLINE T mli_math_limit_fx(T sign) {
7380
return sign < (T)0 ? std::numeric_limits<T>::lowest() : std::numeric_limits<T>::max();

0 commit comments

Comments
 (0)