Skip to content

Commit 7c48e3b

Browse files
authored
Merge pull request #353 from foss-for-synopsys-dwc-arc-processors/rnn_fix
New structural rnn dense implementation
2 parents ae8c6a7 + 2af54d0 commit 7c48e3b

File tree

16 files changed

+372
-134
lines changed

16 files changed

+372
-134
lines changed

lib/src/bricks/impl/mli_krn_rnn_dense_op_ref.h

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,11 @@ static inline void rnn_dense_op(
122122
}
123123

124124
for (int o_idx = 0; o_idx < out_elements; o_idx++) {
125-
io_T out_val = 0;
125+
126126
acc_T accu = mli_math_mul_fx<io_T, acc_T>(0, 0);
127-
acc_T prev_step = mli_math_mul_fx<io_T, acc_T>(0, 0);
127+
acc_T acc_ir = mli_math_mul_fx<io_T, acc_T>(0, 0);
128+
acc_T acc_res_ir = mli_math_mul_fx<io_T, acc_T>(0, 0);
129+
128130
accu = mli::krn::bias_additive(&bias[o_idx], accu, &in_to_out_quant_params[0]);
129131

130132
for(int idx = 0; idx < inputs_num; idx++) {
@@ -137,20 +139,14 @@ static inline void rnn_dense_op(
137139
in_elements[idx], /* height= */ 1, /* ch= */ 1, w_ch_out_mem_strides[idx],
138140
/* row_step= */ 1, /* ch_step= */ 1);
139141
accu = mli_math_add_fx(accu, other_additives[idx]);
140-
accu = mli_math_add_fx(accu, prev_step);
141-
142-
if(inputs_num - idx != 1) {
143-
prev_step = mli::krn::ref::ir_rnn_result_requantize(accu, &in_to_out_quant_params[idx],
144-
&in_to_out_quant_params[idx+1], /* krn_idx= */ 0);
145-
accu = mli_math_mul_fx<io_T, acc_T>(0, 0);
146-
} else {
147-
out_val = mli::krn::ref::result_cast<io_T, acc_T, quant_T>(accu, &in_to_out_quant_params[idx]);
148-
}
142+
143+
acc_ir = mli::krn::ir_rnn_result_requantize(accu, &in_to_out_quant_params[idx]);
144+
acc_res_ir = mli_math_add_fx(acc_res_ir, acc_ir);
145+
accu = mli_math_mul_fx<io_T, acc_T>(0, 0);
149146
}
150147

151-
out_val = MIN(out_val, val_max_limit);
152-
out_val = MAX(out_val, val_min_limit);
153-
out[o_idx] = out_val;
148+
out[o_idx] = mli::krn::ir_result_cast_relu_store<io_T, acc_T, quant_T>(acc_res_ir,
149+
&in_to_out_quant_params[inputs_num - 1], val_min_limit, val_max_limit);
154150
}
155151
}
156152

lib/src/bricks/impl/mli_krn_rnn_dense_op_vdsp.h

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,27 @@ static inline void rnn_dense_op_stacked(
9494
dense_out_ptr -= gates_num * out_elements;
9595
}
9696

97+
MLI_FORCE_INLINE vNx2accint_t mli_math_add_accus(vNx2accint_t L, vNx2accint_t R) {
98+
return mli_math_add(L, R);
99+
}
100+
101+
MLI_FORCE_INLINE vNx4accint_t mli_math_add_accus(vNx4accint_t L, vNx4accint_t R) {
102+
return mli_math_add(L, R);
103+
}
104+
105+
MLI_FORCE_INLINE vNx4accshort_t mli_math_add_accus(vNx4accshort_t L, vNx4accshort_t R) {
106+
#if (__Xvec_guard_bit_option == 0)
107+
vNx4short_t L_short = mli_math_acc_cast_fx<vNx4short_t, vNx4accshort_t>(L);
108+
vNx4short_t R_short = mli_math_acc_cast_fx<vNx4short_t, vNx4accshort_t>(R);
109+
110+
vNx4short_t res = mli_math_add_fx<vNx4short_t>(L_short, R_short);
111+
112+
return mli_math_init_accu_add<vNx4short_t, vNx4accshort_t>(res, (vNx4short_t)0);
113+
#else
114+
return mli_math_add(L, R);
115+
#endif
116+
}
117+
97118
template <typename io_T, typename w_T, typename b_T, typename acc_T, typename quant_T>
98119
static inline void rnn_dense_op(
99120
const MLI_PTR(io_T) __restrict * inputs,
@@ -109,12 +130,14 @@ static inline void rnn_dense_op(
109130
const io_T val_max_limit) {
110131

111132
int num_lanes = get_number_lanes<acc_T>();
133+
112134
for (int o_idx = 0; o_idx < out_elements; o_idx += num_lanes) {
113135
int remaining_ch = out_elements - o_idx;
114136
int current_chs = MIN(remaining_ch, num_lanes); // number of channels computed in this loop iteration
115137

116-
acc_T accu = mli_math_mul_fx<io_T, acc_T>(0, 0);
117-
acc_T prev_step = mli_math_mul_fx<io_T, acc_T>(0, 0);
138+
acc_T accu = mli_prv_init_accu<acc_T>();
139+
acc_T acc_ir = mli_prv_init_accu<acc_T>();
140+
acc_T acc_res_ir = mli_prv_init_accu<acc_T>();
118141

119142
auto output_params = adjust_quant_params_v(&in_to_out_quant_params[0], 0);
120143
accu = mli::krn::bias_additive(&bias[o_idx], accu, &output_params, /* add_preshift_rnd */ false);
@@ -124,20 +147,18 @@ static inline void rnn_dense_op(
124147
output_params = adjust_quant_params_v(&in_to_out_quant_params[idx], 0);
125148
accu = dotprod_inputzp_1D_v(inputs[idx], &weights[idx][o_idx], accu, in_elements[idx],
126149
1, w_ch_out_mem_strides[idx], &in_to_out_quant_params[idx]);
127-
accu = mli_math_add(accu, prev_step);
128-
129-
if(inputs_num - idx != 1) {
130-
mli::krn::ref::adjust_quant_params(&in_to_out_quant_params[idx], o_idx);
131-
prev_step = mli::krn::ir_rnn_result_requantize(accu, &in_to_out_quant_params[idx],
132-
&in_to_out_quant_params[idx + 1], /* krn_idx= */ 0);
133-
accu = mli_math_mul_fx<io_T, acc_T>(0, 0);
134-
} else {
135-
// Cast result to output type with scaling
136-
mli::krn::result_cast_relu_store_v(&out[o_idx], accu, &output_params,
137-
val_min_limit, val_max_limit, current_chs, /* add_preshift_rnd */ true);
138-
}
150+
151+
/* TODO: can be optimized using adjust_quant_params_v, and also optimize ir_rnn_result_requantize function */
152+
mli::krn::ref::adjust_quant_params(&in_to_out_quant_params[idx], o_idx);
153+
acc_ir = mli::krn::ir_rnn_result_requantize(accu, &in_to_out_quant_params[idx]);
154+
155+
acc_res_ir = mli_math_add_accus(acc_res_ir, acc_ir);
156+
accu = mli_prv_init_accu<acc_T>();
139157
}
140158

159+
// Cast result to output type with scaling
160+
mli::krn::ir_result_cast_relu_store_v(&out[o_idx], acc_res_ir, &output_params,
161+
val_min_limit, val_max_limit, current_chs);
141162
}
142163
}
143164

lib/src/bricks/impl/mli_prv_quant_ref.h

Lines changed: 54 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -399,14 +399,12 @@ static MLI_FORCE_INLINE void result_cast_relu_store(
399399
const int16_t val_max_limit) {
400400

401401
o_T out = result_cast<o_T, acc_T, quant_T>(acc, quant_params);
402-
out = MIN(out, val_max_limit);
403-
out = MAX(out, val_min_limit);
402+
out = mli_math_min_fx(out, val_max_limit);
403+
out = mli_math_max_fx(out, val_min_limit);
404404

405405
*o_ptr = (o_T) out;
406406
}
407407

408-
409-
410408
template <typename io_T, typename acc_T, typename b_T, mli_math_type math_type>
411409
MLI_FORCE_INLINE io_T result_cast(const acc_T acc, const b_T bias, const int32_t out_mul,
412410
const conv_math_params* math_params) {
@@ -438,21 +436,64 @@ MLI_FORCE_INLINE int8_t result_cast<int8_t, mli_acc32_t, int32_t, S8ASYM_MATH>(
438436
return out_val;
439437
}
440438

439+
template <>
440+
MLI_FORCE_INLINE int16_t ir_result_cast_relu_store(
441+
const mli_acc40_t acc,
442+
const fx_quant_specific_params* math_params,
443+
const int16_t val_min_limit,
444+
const int16_t val_max_limit) {
445+
int16_t out_val = mli_math_cast_fx<mli_acc40_t, int16_t>(acc);
446+
out_val = mli_math_min_fx(out_val, val_max_limit);
447+
out_val = mli_math_max_fx(out_val, val_min_limit);
448+
return out_val;
449+
}
450+
451+
template <>
452+
MLI_FORCE_INLINE int16_t ir_result_cast_relu_store(
453+
const mli_acc32_t acc,
454+
const fx_quant_specific_params* math_params,
455+
const int16_t val_min_limit,
456+
const int16_t val_max_limit) {
457+
int16_t out_val = mli_math_cast_fx<mli_acc32_t, int16_t>(acc);
458+
out_val = mli_math_min_fx(out_val, val_max_limit);
459+
out_val = mli_math_max_fx(out_val, val_min_limit);
460+
return out_val;
461+
}
462+
463+
template <>
464+
MLI_FORCE_INLINE int8_t ir_result_cast_relu_store(
465+
const mli_acc32_t acc,
466+
const s8asym_quant_specific_params* quant_params,
467+
const int8_t val_min_limit,
468+
const int8_t val_max_limit) {
469+
470+
const int16_t out_no_offset = mli_math_cast_fx<int32_t, int16_t>(acc);
471+
int8_t out_val = mli_math_cast_fx<int16_t, int8_t>(mli_math_add_fx(out_no_offset, quant_params->out_offset), 0);
472+
473+
out_val = mli_math_min_fx(out_val, val_max_limit);
474+
out_val = mli_math_max_fx(out_val, val_min_limit);
475+
476+
return out_val;
477+
}
478+
441479
template <typename acc_T>
442-
MLI_FORCE_INLINE acc_T ir_rnn_result_requantize(const acc_T acc, const fx_quant_specific_params* current_params,
443-
const fx_quant_specific_params* next_params, int krn_idx) {
444-
const int shift = current_params->out_shift - next_params->out_shift;
445-
return mli_math_acc_ashift_fx<acc_T>(acc, shift);
480+
MLI_FORCE_INLINE acc_T ir_rnn_result_requantize(
481+
const acc_T acc,
482+
const fx_quant_specific_params* current_params) {
483+
const int in_to_ir_shift = current_params->out_shift;
484+
return mli_math_acc_ashift_fx<acc_T>(acc, in_to_ir_shift);
446485
}
447486

448487
template <>
449-
MLI_FORCE_INLINE mli_acc32_t ir_rnn_result_requantize(const mli_acc32_t acc, const s8asym_quant_specific_params* current_params,
450-
const s8asym_quant_specific_params* next_params, int krn_idx) {
451-
const int32_t mul = current_params->out_mul / next_params->weight_scales[krn_idx];
452-
const int shift = current_params->out_shift - next_params->weight_shifts[krn_idx];
488+
MLI_FORCE_INLINE mli_acc32_t ir_rnn_result_requantize(
489+
const mli_acc32_t acc,
490+
const s8asym_quant_specific_params* current_params) {
491+
492+
const int32_t mul = current_params->out_mul;
493+
const int in_to_ir_shift = current_params->out_shift;
453494

454495
auto accu_scaled = mli_math_mul_fx<int32_t, int64_t>(acc, mul);
455-
auto out_no_offset = mli_math_cast_fx<int64_t, int32_t>(accu_scaled, shift);
496+
auto out_no_offset = mli_math_cast_fx<int64_t, int32_t>(accu_scaled, in_to_ir_shift);
456497
return out_no_offset;
457498
}
458499

lib/src/bricks/impl/mli_prv_quant_vdsp.h

Lines changed: 102 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -318,8 +318,8 @@ MLI_FORCE_INLINE vNx4short_t mli_prv_convert_sa8_fx16(
318318
const int16_t zero_point,
319319
const int16_t scale,
320320
const int shift) {
321-
int shift_right = MAX(shift, 0);
322-
int shift_left = MAX(-shift, 0);
321+
int shift_right = mli_math_max_fx(shift, 0);
322+
int shift_left = mli_math_max_fx(-shift, 0);
323323
vNx4short_t in_biased_shifted_no_zp = mli_math_sub_fx<vNx4short_t>(in_val, zero_point);
324324
vNx4int_t in_scaled = mli_math_mul_fx<vNx4short_t, vNx4int_t>(in_biased_shifted_no_zp, scale);
325325
vNx4short_t res = mli_math_cast_fx<vNx4int_t, vNx4short_t>(in_scaled, shift_right);
@@ -423,8 +423,8 @@ MLI_FORCE_INLINE void result_cast_relu_store_v(
423423

424424
accu_scaled = accu_scaled + quant_params->out_offset;
425425

426-
accu_scaled = MIN(accu_scaled, val_max_limit);
427-
accu_scaled = MAX(accu_scaled, val_min_limit);
426+
accu_scaled = mli_math_min_fx(accu_scaled, val_max_limit);
427+
accu_scaled = mli_math_max_fx(accu_scaled, val_min_limit);
428428

429429
vNx4char_t out = to_vNx4char_t(accu_scaled);
430430
mli_prv_store_n_samples(o_ptr, out, num);
@@ -442,8 +442,8 @@ MLI_FORCE_INLINE void result_cast_relu_store_v(
442442

443443
vNx4char_t out = mli_math_acc_cast_fx<vNx4char_t, vNx4accshort_t>(acc, quant_params->out_shift);
444444

445-
out = MIN(out, val_max_limit);
446-
out = MAX(out, val_min_limit);
445+
out = mli_math_min_fx(out, val_max_limit);
446+
out = mli_math_max_fx(out, val_min_limit);
447447

448448
mli_prv_store_n_samples(o_ptr, out, num);
449449
}
@@ -460,8 +460,8 @@ MLI_FORCE_INLINE void result_cast_relu_store_v(
460460

461461
vNx2short_t out = mli_math_acc_cast_fx<vNx2short_t, vNx2accint_t>(acc, quant_params->out_shift);
462462

463-
out = MIN(out, val_max_limit);
464-
out = MAX(out, val_min_limit);
463+
out = mli_math_min_fx(out, val_max_limit);
464+
out = mli_math_max_fx(out, val_min_limit);
465465

466466
mli_prv_store_n_samples(o_ptr, out, num);
467467
}
@@ -478,32 +478,83 @@ MLI_FORCE_INLINE void result_cast_relu_store_v(
478478

479479
vNx4short_t out = mli_math_acc_cast_fx<vNx4short_t, vNx4accint_t>(acc, quant_params->out_shift);
480480

481-
out = MIN(out, val_max_limit);
482-
out = MAX(out, val_min_limit);
481+
out = mli_math_min_fx(out, val_max_limit);
482+
out = mli_math_max_fx(out, val_min_limit);
483+
484+
mli_prv_store_n_samples(o_ptr, out, num);
485+
}
486+
487+
template <>
488+
MLI_FORCE_INLINE void ir_result_cast_relu_store_v(
489+
MLI_CONV_OUT_PTR(int8_t) __restrict o_ptr,
490+
vNx4accshort_t acc,
491+
const s8asym_quant_specific_out_params_v* quant_params,
492+
const int16_t val_min_limit,
493+
const int16_t val_max_limit,
494+
int num) {
495+
496+
vNx4short_t accu_scaled = mli_math_acc_cast_fx<vNx4short_t, vNx4accshort_t>(acc);
497+
accu_scaled = mli_math_add_fx<vNx4short_t>(accu_scaled, quant_params->out_offset);
498+
499+
accu_scaled = mli_math_min_fx(accu_scaled, val_max_limit);
500+
accu_scaled = mli_math_max_fx(accu_scaled, val_min_limit);
501+
502+
vNx4char_t out = to_vNx4char_t(accu_scaled);
503+
mli_prv_store_n_samples(o_ptr, out, num);
504+
}
505+
506+
template <>
507+
MLI_FORCE_INLINE void ir_result_cast_relu_store_v(
508+
MLI_CONV_OUT_PTR(int16_t) __restrict o_ptr,
509+
vNx2accint_t acc,
510+
const fx_quant_specific_params* quant_params,
511+
const int16_t val_min_limit,
512+
const int16_t val_max_limit,
513+
int num) {
514+
515+
vNx2short_t out = mli_math_acc_cast_fx<vNx2short_t, vNx2accint_t>(acc);
516+
517+
out = mli_math_min_fx(out, val_max_limit);
518+
out = mli_math_max_fx(out, val_min_limit);
519+
520+
mli_prv_store_n_samples(o_ptr, out, num);
521+
}
522+
523+
template <>
524+
MLI_FORCE_INLINE void ir_result_cast_relu_store_v(
525+
MLI_CONV_OUT_PTR(int16_t) __restrict o_ptr,
526+
vNx4accint_t acc,
527+
const fx_quant_specific_params* quant_params,
528+
const int16_t val_min_limit,
529+
const int16_t val_max_limit,
530+
int num) {
531+
532+
vNx4short_t out = mli_math_acc_cast_fx<vNx4short_t, vNx4accint_t>(acc);
533+
534+
out = mli_math_min_fx(out, val_max_limit);
535+
out = mli_math_max_fx(out, val_min_limit);
483536

484537
mli_prv_store_n_samples(o_ptr, out, num);
485538
}
486539

487540
template <typename acc_T>
488-
MLI_FORCE_INLINE acc_T ir_rnn_result_requantize(const acc_T acc, const fx_quant_specific_params* current_params,
489-
const fx_quant_specific_params* next_params, int krn_idx) {
490-
const int shift = current_params->out_shift - next_params->out_shift;
491-
int shift_right = MAX(shift, 0);
492-
int shift_left = MAX(-shift, 0);
541+
MLI_FORCE_INLINE acc_T ir_rnn_result_requantize(
542+
const acc_T acc,
543+
const fx_quant_specific_params* params) {
544+
const int in_to_ir_shift = params->out_shift;
545+
int shift_right = mli_math_max_fx(in_to_ir_shift, 0);
546+
int shift_left = mli_math_max_fx(-in_to_ir_shift, 0);
493547
acc_T acc_shifted = mli_math_asl_fx(acc, shift_left);
494548
return mli_math_asr_rnd_fx<acc_T, int>(acc_shifted, shift_right);
495549
}
496550

497551
template <>
498552
MLI_FORCE_INLINE vNx4accshort_t ir_rnn_result_requantize(
499553
const vNx4accshort_t acc,
500-
const s8asym_quant_specific_params* current_params,
501-
const s8asym_quant_specific_params* next_params, int krn_idx) {
502-
503-
MLI_ASSERT(krn_idx == 0);
554+
const s8asym_quant_specific_params* params) {
504555

505-
const int32_t mul = current_params->out_mul / next_params->weight_scales[0];
506-
const int shift = current_params->out_shift - next_params->weight_shifts[0];
556+
const int32_t mul = params->out_mul;
557+
const int in_to_ir_shift = params->out_shift;
507558

508559
int mul_norm = mli_math_norm_fx<int32_t, int32_t>(mul);
509560
int32_t mul_shifted = mul << mul_norm;
@@ -512,14 +563,42 @@ MLI_FORCE_INLINE vNx4accshort_t ir_rnn_result_requantize(
512563
vNx4int_t acc_norm = mli_math_norm_fx<vNx4int_t, vNx4int_t>(acc_int);
513564
acc_int = mli_math_asl_fx<vNx4int_t, vNx4int_t>(acc_int, acc_norm);
514565

515-
vNx4int_t total_shift = mli_math_add_fx<vNx4int_t>(acc_norm, (mul_norm + shift));
516566
vNx4int_t acc_scaled = mli_math_mul_fx_high(acc_int, mul_shifted);
517-
vNx4int_t acc_shifted = mli_math_asr_rnd_fx(acc_scaled, total_shift);
518567

568+
constexpr int mul_high_shift = 32;
569+
constexpr int max_int_shift = 30;
570+
vNx4int_t total_shift = mli_math_add_fx<vNx4int_t>(acc_norm, (mul_norm - mul_high_shift + in_to_ir_shift));
571+
vNx4int_t shift_left = mli_math_max_fx(-total_shift, 0);
572+
vNx4int_t shift_right = mli_math_max_fx(total_shift, 0);
573+
574+
vNx4int_t preshift = mli_math_max_fx(shift_right - max_int_shift, 0);
575+
shift_right = shift_right - preshift;
576+
577+
vNx4int_t acc_shifted = mli_math_asr_fx(acc_scaled, preshift);
578+
acc_shifted = mli_math_asr_rnd_fx(acc_shifted, shift_right);
579+
acc_shifted = mli_math_asl_fx(acc_shifted, shift_left);
580+
581+
#if (__Xvec_guard_bit_option == 0)
582+
vNx4short_t acc_short = mli_math_cast_fx<vNx4int_t, vNx4short_t>(acc_shifted);
583+
vNx4accshort_t res = mli_math_init_accu_add<vNx4short_t, vNx4accshort_t>(acc_short, (vNx4short_t)0);
584+
#else
519585
vNx4int_t norm;
520586
vNx4short_t acc_short = mli_math_norm_cast_fx</*left_shift*/ false>(acc_shifted , &norm);
587+
588+
constexpr int guard_bits = 8;
589+
vNx4int_t mask = (1 << norm) - 1;
590+
vNx4int_t acc_shifted_low = acc_shifted & mask;
591+
// If the norm is more than the number of guardbits,
592+
// so the masked_acc has to be shifted, since the result is shifted with max shift equals to number of guardbits.
593+
vNx4int_t mask_shift = mli_math_max_fx(norm - guard_bits, 0);
594+
acc_shifted_low = mli_math_asr_fx(acc_shifted_low, mask_shift);
595+
596+
norm = mli_math_min_fx(norm, guard_bits);
521597
vNx4accshort_t res = mli_math_init_accu_add<vNx4short_t, vNx4accshort_t>(acc_short, (vNx4short_t)0);
522598
res = mli_math_asl_fx<vNx4accshort_t, vNx4short_t>(res, to_vNx4short_t(norm));
599+
res = mli_math_add(res, to_vNx4short_t(acc_shifted_low));
600+
#endif
601+
523602
return res;
524603
}
525604

0 commit comments

Comments
 (0)