Skip to content

Commit 2af54d0

Browse files
committed
[rnn_dense] resturcture rnn logic to have common intermediate requantization
1 parent 0c09acd commit 2af54d0

File tree

16 files changed

+316
-130
lines changed

16 files changed

+316
-130
lines changed

lib/src/bricks/impl/mli_krn_rnn_dense_op_ref.h

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,11 @@ static inline void rnn_dense_op(
122122
}
123123

124124
for (int o_idx = 0; o_idx < out_elements; o_idx++) {
125-
io_T out_val = 0;
125+
126126
acc_T accu = mli_math_mul_fx<io_T, acc_T>(0, 0);
127-
acc_T prev_step = mli_math_mul_fx<io_T, acc_T>(0, 0);
127+
acc_T acc_ir = mli_math_mul_fx<io_T, acc_T>(0, 0);
128+
acc_T acc_res_ir = mli_math_mul_fx<io_T, acc_T>(0, 0);
129+
128130
accu = mli::krn::bias_additive(&bias[o_idx], accu, &in_to_out_quant_params[0]);
129131

130132
for(int idx = 0; idx < inputs_num; idx++) {
@@ -137,20 +139,14 @@ static inline void rnn_dense_op(
137139
in_elements[idx], /* height= */ 1, /* ch= */ 1, w_ch_out_mem_strides[idx],
138140
/* row_step= */ 1, /* ch_step= */ 1);
139141
accu = mli_math_add_fx(accu, other_additives[idx]);
140-
accu = mli_math_add_fx(accu, prev_step);
141-
142-
if(inputs_num - idx != 1) {
143-
prev_step = mli::krn::ref::ir_rnn_result_requantize(accu, &in_to_out_quant_params[idx],
144-
&in_to_out_quant_params[idx+1], /* krn_idx= */ 0);
145-
accu = mli_math_mul_fx<io_T, acc_T>(0, 0);
146-
} else {
147-
out_val = mli::krn::ref::result_cast<io_T, acc_T, quant_T>(accu, &in_to_out_quant_params[idx]);
148-
}
142+
143+
acc_ir = mli::krn::ir_rnn_result_requantize(accu, &in_to_out_quant_params[idx]);
144+
acc_res_ir = mli_math_add_fx(acc_res_ir, acc_ir);
145+
accu = mli_math_mul_fx<io_T, acc_T>(0, 0);
149146
}
150147

151-
out_val = MIN(out_val, val_max_limit);
152-
out_val = MAX(out_val, val_min_limit);
153-
out[o_idx] = out_val;
148+
out[o_idx] = mli::krn::ir_result_cast_relu_store<io_T, acc_T, quant_T>(acc_res_ir,
149+
&in_to_out_quant_params[inputs_num - 1], val_min_limit, val_max_limit);
154150
}
155151
}
156152

lib/src/bricks/impl/mli_krn_rnn_dense_op_vdsp.h

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,8 @@ static inline void rnn_dense_op(
136136
int current_chs = MIN(remaining_ch, num_lanes); // number of channels computed in this loop iteration
137137

138138
acc_T accu = mli_prv_init_accu<acc_T>();
139-
acc_T prev_step = mli_prv_init_accu<acc_T>();
139+
acc_T acc_ir = mli_prv_init_accu<acc_T>();
140+
acc_T acc_res_ir = mli_prv_init_accu<acc_T>();
140141

141142
auto output_params = adjust_quant_params_v(&in_to_out_quant_params[0], 0);
142143
accu = mli::krn::bias_additive(&bias[o_idx], accu, &output_params, /* add_preshift_rnd */ false);
@@ -147,20 +148,17 @@ static inline void rnn_dense_op(
147148
accu = dotprod_inputzp_1D_v(inputs[idx], &weights[idx][o_idx], accu, in_elements[idx],
148149
1, w_ch_out_mem_strides[idx], &in_to_out_quant_params[idx]);
149150

150-
accu = mli_math_add_accus(accu, prev_step);
151-
152-
if(inputs_num - idx != 1) {
153-
mli::krn::ref::adjust_quant_params(&in_to_out_quant_params[idx], o_idx);
154-
prev_step = mli::krn::ir_rnn_result_requantize(accu, &in_to_out_quant_params[idx],
155-
&in_to_out_quant_params[idx + 1], /* krn_idx= */ 0);
156-
accu = mli_prv_init_accu<acc_T>();
157-
} else {
158-
// Cast result to output type with scaling
159-
mli::krn::result_cast_relu_store_v(&out[o_idx], accu, &output_params,
160-
val_min_limit, val_max_limit, current_chs, /* add_preshift_rnd */ true);
161-
}
151+
/* TODO: can be optimized using adjust_quant_params_v, and also optimize ir_rnn_result_requantize function */
152+
mli::krn::ref::adjust_quant_params(&in_to_out_quant_params[idx], o_idx);
153+
acc_ir = mli::krn::ir_rnn_result_requantize(accu, &in_to_out_quant_params[idx]);
154+
155+
acc_res_ir = mli_math_add_accus(acc_res_ir, acc_ir);
156+
accu = mli_prv_init_accu<acc_T>();
162157
}
163158

159+
// Cast result to output type with scaling
160+
mli::krn::ir_result_cast_relu_store_v(&out[o_idx], acc_res_ir, &output_params,
161+
val_min_limit, val_max_limit, current_chs);
164162
}
165163
}
166164

lib/src/bricks/impl/mli_prv_quant_ref.h

Lines changed: 54 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -399,14 +399,12 @@ static MLI_FORCE_INLINE void result_cast_relu_store(
399399
const int16_t val_max_limit) {
400400

401401
o_T out = result_cast<o_T, acc_T, quant_T>(acc, quant_params);
402-
out = MIN(out, val_max_limit);
403-
out = MAX(out, val_min_limit);
402+
out = mli_math_min_fx(out, val_max_limit);
403+
out = mli_math_max_fx(out, val_min_limit);
404404

405405
*o_ptr = (o_T) out;
406406
}
407407

408-
409-
410408
template <typename io_T, typename acc_T, typename b_T, mli_math_type math_type>
411409
MLI_FORCE_INLINE io_T result_cast(const acc_T acc, const b_T bias, const int32_t out_mul,
412410
const conv_math_params* math_params) {
@@ -438,21 +436,64 @@ MLI_FORCE_INLINE int8_t result_cast<int8_t, mli_acc32_t, int32_t, S8ASYM_MATH>(
438436
return out_val;
439437
}
440438

439+
template <>
440+
MLI_FORCE_INLINE int16_t ir_result_cast_relu_store(
441+
const mli_acc40_t acc,
442+
const fx_quant_specific_params* math_params,
443+
const int16_t val_min_limit,
444+
const int16_t val_max_limit) {
445+
int16_t out_val = mli_math_cast_fx<mli_acc40_t, int16_t>(acc);
446+
out_val = mli_math_min_fx(out_val, val_max_limit);
447+
out_val = mli_math_max_fx(out_val, val_min_limit);
448+
return out_val;
449+
}
450+
451+
template <>
452+
MLI_FORCE_INLINE int16_t ir_result_cast_relu_store(
453+
const mli_acc32_t acc,
454+
const fx_quant_specific_params* math_params,
455+
const int16_t val_min_limit,
456+
const int16_t val_max_limit) {
457+
int16_t out_val = mli_math_cast_fx<mli_acc32_t, int16_t>(acc);
458+
out_val = mli_math_min_fx(out_val, val_max_limit);
459+
out_val = mli_math_max_fx(out_val, val_min_limit);
460+
return out_val;
461+
}
462+
463+
template <>
464+
MLI_FORCE_INLINE int8_t ir_result_cast_relu_store(
465+
const mli_acc32_t acc,
466+
const s8asym_quant_specific_params* quant_params,
467+
const int8_t val_min_limit,
468+
const int8_t val_max_limit) {
469+
470+
const int16_t out_no_offset = mli_math_cast_fx<int32_t, int16_t>(acc);
471+
int8_t out_val = mli_math_cast_fx<int16_t, int8_t>(mli_math_add_fx(out_no_offset, quant_params->out_offset), 0);
472+
473+
out_val = mli_math_min_fx(out_val, val_max_limit);
474+
out_val = mli_math_max_fx(out_val, val_min_limit);
475+
476+
return out_val;
477+
}
478+
441479
template <typename acc_T>
442-
MLI_FORCE_INLINE acc_T ir_rnn_result_requantize(const acc_T acc, const fx_quant_specific_params* current_params,
443-
const fx_quant_specific_params* next_params, int krn_idx) {
444-
const int shift = current_params->out_shift - next_params->out_shift;
445-
return mli_math_acc_ashift_fx<acc_T>(acc, shift);
480+
MLI_FORCE_INLINE acc_T ir_rnn_result_requantize(
481+
const acc_T acc,
482+
const fx_quant_specific_params* current_params) {
483+
const int in_to_ir_shift = current_params->out_shift;
484+
return mli_math_acc_ashift_fx<acc_T>(acc, in_to_ir_shift);
446485
}
447486

448487
template <>
449-
MLI_FORCE_INLINE mli_acc32_t ir_rnn_result_requantize(const mli_acc32_t acc, const s8asym_quant_specific_params* current_params,
450-
const s8asym_quant_specific_params* next_params, int krn_idx) {
451-
const int32_t mul = current_params->out_mul / next_params->weight_scales[krn_idx];
452-
const int shift = current_params->out_shift - next_params->weight_shifts[krn_idx];
488+
MLI_FORCE_INLINE mli_acc32_t ir_rnn_result_requantize(
489+
const mli_acc32_t acc,
490+
const s8asym_quant_specific_params* current_params) {
491+
492+
const int32_t mul = current_params->out_mul;
493+
const int in_to_ir_shift = current_params->out_shift;
453494

454495
auto accu_scaled = mli_math_mul_fx<int32_t, int64_t>(acc, mul);
455-
auto out_no_offset = mli_math_cast_fx<int64_t, int32_t>(accu_scaled, shift);
496+
auto out_no_offset = mli_math_cast_fx<int64_t, int32_t>(accu_scaled, in_to_ir_shift);
456497
return out_no_offset;
457498
}
458499

lib/src/bricks/impl/mli_prv_quant_vdsp.h

Lines changed: 73 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -318,8 +318,8 @@ MLI_FORCE_INLINE vNx4short_t mli_prv_convert_sa8_fx16(
318318
const int16_t zero_point,
319319
const int16_t scale,
320320
const int shift) {
321-
int shift_right = MAX(shift, 0);
322-
int shift_left = MAX(-shift, 0);
321+
int shift_right = mli_math_max_fx(shift, 0);
322+
int shift_left = mli_math_max_fx(-shift, 0);
323323
vNx4short_t in_biased_shifted_no_zp = mli_math_sub_fx<vNx4short_t>(in_val, zero_point);
324324
vNx4int_t in_scaled = mli_math_mul_fx<vNx4short_t, vNx4int_t>(in_biased_shifted_no_zp, scale);
325325
vNx4short_t res = mli_math_cast_fx<vNx4int_t, vNx4short_t>(in_scaled, shift_right);
@@ -423,8 +423,8 @@ MLI_FORCE_INLINE void result_cast_relu_store_v(
423423

424424
accu_scaled = accu_scaled + quant_params->out_offset;
425425

426-
accu_scaled = MIN(accu_scaled, val_max_limit);
427-
accu_scaled = MAX(accu_scaled, val_min_limit);
426+
accu_scaled = mli_math_min_fx(accu_scaled, val_max_limit);
427+
accu_scaled = mli_math_max_fx(accu_scaled, val_min_limit);
428428

429429
vNx4char_t out = to_vNx4char_t(accu_scaled);
430430
mli_prv_store_n_samples(o_ptr, out, num);
@@ -442,8 +442,8 @@ MLI_FORCE_INLINE void result_cast_relu_store_v(
442442

443443
vNx4char_t out = mli_math_acc_cast_fx<vNx4char_t, vNx4accshort_t>(acc, quant_params->out_shift);
444444

445-
out = MIN(out, val_max_limit);
446-
out = MAX(out, val_min_limit);
445+
out = mli_math_min_fx(out, val_max_limit);
446+
out = mli_math_max_fx(out, val_min_limit);
447447

448448
mli_prv_store_n_samples(o_ptr, out, num);
449449
}
@@ -460,8 +460,8 @@ MLI_FORCE_INLINE void result_cast_relu_store_v(
460460

461461
vNx2short_t out = mli_math_acc_cast_fx<vNx2short_t, vNx2accint_t>(acc, quant_params->out_shift);
462462

463-
out = MIN(out, val_max_limit);
464-
out = MAX(out, val_min_limit);
463+
out = mli_math_min_fx(out, val_max_limit);
464+
out = mli_math_max_fx(out, val_min_limit);
465465

466466
mli_prv_store_n_samples(o_ptr, out, num);
467467
}
@@ -478,32 +478,83 @@ MLI_FORCE_INLINE void result_cast_relu_store_v(
478478

479479
vNx4short_t out = mli_math_acc_cast_fx<vNx4short_t, vNx4accint_t>(acc, quant_params->out_shift);
480480

481-
out = MIN(out, val_max_limit);
482-
out = MAX(out, val_min_limit);
481+
out = mli_math_min_fx(out, val_max_limit);
482+
out = mli_math_max_fx(out, val_min_limit);
483+
484+
mli_prv_store_n_samples(o_ptr, out, num);
485+
}
486+
487+
template <>
488+
MLI_FORCE_INLINE void ir_result_cast_relu_store_v(
489+
MLI_CONV_OUT_PTR(int8_t) __restrict o_ptr,
490+
vNx4accshort_t acc,
491+
const s8asym_quant_specific_out_params_v* quant_params,
492+
const int16_t val_min_limit,
493+
const int16_t val_max_limit,
494+
int num) {
495+
496+
vNx4short_t accu_scaled = mli_math_acc_cast_fx<vNx4short_t, vNx4accshort_t>(acc);
497+
accu_scaled = mli_math_add_fx<vNx4short_t>(accu_scaled, quant_params->out_offset);
498+
499+
accu_scaled = mli_math_min_fx(accu_scaled, val_max_limit);
500+
accu_scaled = mli_math_max_fx(accu_scaled, val_min_limit);
501+
502+
vNx4char_t out = to_vNx4char_t(accu_scaled);
503+
mli_prv_store_n_samples(o_ptr, out, num);
504+
}
505+
506+
template <>
507+
MLI_FORCE_INLINE void ir_result_cast_relu_store_v(
508+
MLI_CONV_OUT_PTR(int16_t) __restrict o_ptr,
509+
vNx2accint_t acc,
510+
const fx_quant_specific_params* quant_params,
511+
const int16_t val_min_limit,
512+
const int16_t val_max_limit,
513+
int num) {
514+
515+
vNx2short_t out = mli_math_acc_cast_fx<vNx2short_t, vNx2accint_t>(acc);
516+
517+
out = mli_math_min_fx(out, val_max_limit);
518+
out = mli_math_max_fx(out, val_min_limit);
519+
520+
mli_prv_store_n_samples(o_ptr, out, num);
521+
}
522+
523+
template <>
524+
MLI_FORCE_INLINE void ir_result_cast_relu_store_v(
525+
MLI_CONV_OUT_PTR(int16_t) __restrict o_ptr,
526+
vNx4accint_t acc,
527+
const fx_quant_specific_params* quant_params,
528+
const int16_t val_min_limit,
529+
const int16_t val_max_limit,
530+
int num) {
531+
532+
vNx4short_t out = mli_math_acc_cast_fx<vNx4short_t, vNx4accint_t>(acc);
533+
534+
out = mli_math_min_fx(out, val_max_limit);
535+
out = mli_math_max_fx(out, val_min_limit);
483536

484537
mli_prv_store_n_samples(o_ptr, out, num);
485538
}
486539

487540
template <typename acc_T>
488-
MLI_FORCE_INLINE acc_T ir_rnn_result_requantize(const acc_T acc, const fx_quant_specific_params* current_params,
489-
const fx_quant_specific_params* next_params, int krn_idx) {
490-
const int shift = current_params->out_shift - next_params->out_shift;
491-
int shift_right = MAX(shift, 0);
492-
int shift_left = MAX(-shift, 0);
541+
MLI_FORCE_INLINE acc_T ir_rnn_result_requantize(
542+
const acc_T acc,
543+
const fx_quant_specific_params* params) {
544+
const int in_to_ir_shift = params->out_shift;
545+
int shift_right = mli_math_max_fx(in_to_ir_shift, 0);
546+
int shift_left = mli_math_max_fx(-in_to_ir_shift, 0);
493547
acc_T acc_shifted = mli_math_asl_fx(acc, shift_left);
494548
return mli_math_asr_rnd_fx<acc_T, int>(acc_shifted, shift_right);
495549
}
496550

497551
template <>
498552
MLI_FORCE_INLINE vNx4accshort_t ir_rnn_result_requantize(
499553
const vNx4accshort_t acc,
500-
const s8asym_quant_specific_params* current_params,
501-
const s8asym_quant_specific_params* next_params, int krn_idx) {
502-
503-
MLI_ASSERT(krn_idx == 0);
554+
const s8asym_quant_specific_params* params) {
504555

505-
const int32_t mul = current_params->out_mul / next_params->weight_scales[0];
506-
const int shift = current_params->out_shift - next_params->weight_shifts[0];
556+
const int32_t mul = params->out_mul;
557+
const int in_to_ir_shift = params->out_shift;
507558

508559
int mul_norm = mli_math_norm_fx<int32_t, int32_t>(mul);
509560
int32_t mul_shifted = mul << mul_norm;
@@ -516,7 +567,7 @@ MLI_FORCE_INLINE vNx4accshort_t ir_rnn_result_requantize(
516567

517568
constexpr int mul_high_shift = 32;
518569
constexpr int max_int_shift = 30;
519-
vNx4int_t total_shift = mli_math_add_fx<vNx4int_t>(acc_norm, (mul_norm - mul_high_shift + shift));
570+
vNx4int_t total_shift = mli_math_add_fx<vNx4int_t>(acc_norm, (mul_norm - mul_high_shift + in_to_ir_shift));
520571
vNx4int_t shift_left = mli_math_max_fx(-total_shift, 0);
521572
vNx4int_t shift_right = mli_math_max_fx(total_shift, 0);
522573

lib/src/bricks/mli_prv_quant.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@ using mli::krn::ref::in_additive;
3131
using mli::krn::ref::zp_additive;
3232
using mli::krn::vdsp::bias_additive;
3333
using mli::krn::ref::result_cast;
34+
using mli::krn::ref::ir_result_cast_relu_store;
3435
using mli::krn::vdsp::ir_rnn_result_requantize;
3536
using mli::krn::ref::result_cast_relu_store;
37+
using mli::krn::vdsp::ir_result_cast_relu_store_v;
3638
using mli::krn::vdsp::result_cast_relu_store_v;
3739
using mli::krn::vdsp::mli_prv_convert_sa8_fx16;
3840
using mli::krn::vdsp::mli_prv_convert_fx16_sa8;
@@ -49,6 +51,7 @@ using mli::krn::ref::in_additive;
4951
using mli::krn::ref::zp_additive;
5052
using mli::krn::ref::bias_additive;
5153
using mli::krn::ref::result_cast;
54+
using mli::krn::ref::ir_result_cast_relu_store;
5255
using mli::krn::ref::ir_rnn_result_requantize;
5356
using mli::krn::dsp::result_cast_relu_store;
5457
using mli::krn::dsp::result_cast_relu_store_v;
@@ -67,6 +70,7 @@ using mli::krn::ref::in_additive;
6770
using mli::krn::ref::zp_additive;
6871
using mli::krn::ref::bias_additive;
6972
using mli::krn::ref::result_cast;
73+
using mli::krn::ref::ir_result_cast_relu_store;
7074
using mli::krn::ref::ir_rnn_result_requantize;
7175
using mli::krn::ref::result_cast_relu_store;
7276
using mli::krn::ref::mli_prv_convert_sa8_fx16;

0 commit comments

Comments
 (0)