Skip to content

Commit d908de4

Browse files
Dmitry Naumkin and JaccovG authored and committed
[gru] Reference version of GRU cell
1 parent ed323b6 commit d908de4

18 files changed

+2341
-21
lines changed

include/api/mli_helpers_api.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ extern "C" {
5454
#define KRNL_DW_D_DIM_HW1N 2 // Depthwise convolution hwc kernel depth (must be == 1)
5555
#define KRNL_DW_N_DIM_HW1N 3 // Depthwise convolution hwc output channels
5656

57+
// for Recurrent kernels
58+
#define KRNL_RNN_W_IN_ELEMS_DIM 1 // Input elements dimension of RNN weights
59+
#define KRNL_RNN_W_OUT_ELEMS_DIM 2 // Output elements dimension of RNN weights
60+
5761
/**
5862
* @brief Count Number of Elements in Tensor
5963
*

include/api/mli_kernels_api.h

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -600,14 +600,15 @@ mli_status mli_krn_fully_connected_sa8_sa8_sa32_ext_bias(
600600
/**
601601
* @brief Long Short Term Memory (LSTM) Cell
602602
*
603-
* @detail This kernel implements the default non-peephole implementation of long short term memory (LSTM) cell
603+
* @detail This kernel implements the default non-peephole implementation of long short term memory (LSTM) cell
604+
* with input (i), gate (g), forget (f) and out (o) gates
604605
*
605606
* This kernel implies sequential processing of the set of inputs vectors which is passed by input tensor of shape
606607
* (batch_size, N) where N is the length of the single frame. Both directions of processing (forward and backward)
607-
* are supported and defined by cfg structure. Kernel can output the bunch of results for according to each step of
608-
* processing, or only the last one in the sequence. Dense part of calculations uses scratch data from configuration
609-
* structure for results, and consequently output and previous output tensors might use the same memory if it is
610-
* acceptable to rewrite previous output data.
608+
* are supported and defined by cfg structure. Kernel can output the intermediate results of each step, or only the result
609+
* of the last step. Dense part of calculations uses scratch data from configuration structure for results,
610+
* and consequently output and previous output tensors might use the same memory if it is acceptable to rewrite
611+
* previous output data.
611612
*
612613
* For more info on primitive see MLI Documentation.
613614
*
@@ -658,6 +659,64 @@ mli_status mli_krn_lstm_cell_sa8_sa8_sa32(
658659
mli_tensor * cell,
659660
mli_tensor * out);
660661

662+
/**
663+
* @brief Gated Recurrent Unit (GRU) Cell
664+
*
665+
* @detail This kernel implements the Gated Recurrent Unit (GRU) cell with update (z), reset (r) and new (n) gates
666+
* in version where a reset gate is applied on the hidden state before matrix multiplication
667+
*
668+
* This kernel implies sequential processing of the set of inputs vectors which is passed by input tensor
669+
* of shape (batch_size, N) where N is the length of the single frame. Both directions of processing (forward and backward)
670+
* are supported and defined by cfg structure. Kernel can output the intermediate results of each step, or only the result
671+
* of the last step.
672+
*
673+
* For more info on primitive see MLI Documentation.
674+
*
675+
* @param in [I] Input feature tensor. Must be a tensor of shape (batch_size, input_elements).
676+
* @param prev_out [I] Previous output feature tensor. Must be a one-dimensional tensor of shape (out_elements).
677+
* @param weights_in [I] Input Weights tensor (set of 3 matrixes in the [z,r,n] order: 3-dimensional tensor)
678+
* @param weights_out [I] Hidden Weights tensor (set of 3 matrixes in the [z,r,n] order: 3-dimensional tensor)
679+
* @param bias [I] Biases tensor (set of 3 vectors in the [z,r,n] order: 2-dimensional tensor)
680+
* @param tanh_lut [I] LUT table structure prepared for the hyperbolic tangent activation
681+
* @param sigm_lut [I] LUT table structure prepared for sigmoid activation
682+
* @param cfg [I] RNN Configuration structure (for more info see @ref mli_rnn_cell_cfg)
683+
* @param out [O] Output feature tensor. Result will be stored here (single output or batch of outputs depending on mode)
684+
*
685+
* @return MLI status code
686+
*/
687+
mli_status mli_krn_gru_cell_fx16(
688+
const mli_tensor * in,
689+
const mli_tensor * prev_out,
690+
const mli_tensor * weights_in,
691+
const mli_tensor * weights_out,
692+
const mli_tensor * bias,
693+
const mli_lut * tanh_lut,
694+
const mli_lut * sigm_lut,
695+
const mli_rnn_cell_cfg * cfg,
696+
mli_tensor * out);
697+
698+
mli_status mli_krn_gru_cell_fx16_fx8_fx8(
699+
const mli_tensor * in,
700+
const mli_tensor * prev_out,
701+
const mli_tensor * weights_in,
702+
const mli_tensor * weights_out,
703+
const mli_tensor * bias,
704+
const mli_lut * tanh_lut,
705+
const mli_lut * sigm_lut,
706+
const mli_rnn_cell_cfg * cfg,
707+
mli_tensor * out);
708+
709+
mli_status mli_krn_gru_cell_sa8_sa8_sa32(
710+
const mli_tensor * in,
711+
const mli_tensor * prev_out,
712+
const mli_tensor * weights_in,
713+
const mli_tensor * weights_out,
714+
const mli_tensor * bias,
715+
const mli_lut * tanh_lut,
716+
const mli_lut * sigm_lut,
717+
const mli_rnn_cell_cfg * cfg,
718+
mli_tensor * out);
719+
661720
/**
662721
* @brief Basic Recurrent Neural Network Cell
663722
*

include/mli_types.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,6 @@ typedef struct {
375375
mli_rnn_results results; /**< Results to preserve.*/
376376
mli_rnn_out_activation act; /**< Output activation type. */
377377
mli_data_container scratch_data; /**< Container to keep intermediate results. */
378-
uint32_t scratch_capacity; /**< Size of a memory pointed by scratch_data field. */
379378
} mli_rnn_cell_cfg;
380379

381380

lib/mli_lib.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ set(MLI_LIB_SOURCE_FILES
4848
${MLI_LIB_CMAKE_DIR}/src/kernels/diverse/mli_krn_argmax.cc
4949
${MLI_LIB_CMAKE_DIR}/src/kernels/diverse/mli_krn_permute_fx.cc
5050
${MLI_LIB_CMAKE_DIR}/src/kernels/common/mli_krn_lstm_cell.cc
51+
${MLI_LIB_CMAKE_DIR}/src/kernels/common/mli_krn_gru_cell.cc
5152
)
5253

5354
set(MLI_LIB_PUBLIC_INCLUDES

lib/src/bricks/impl/mli_krn_rnn_dense_op_ref.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ static inline void rnn_dense_op_stacked(
7171
MLI_CONV_OUT_PTR (io_T) dense_out_ptr = mli_prv_tensor_data_ptr<MLI_CONV_OUT_PTR (io_T)>(out);
7272

7373
for (int gate = 0; gate < gates_num; ++gate) {
74-
mli::krn::rnn_dense_op<io_T, w_T, b_T, acc_T, quant_T>(
74+
mli::krn::ref::rnn_dense_op<io_T, w_T, b_T, acc_T, quant_T>(
7575
inputs_ptr, weights_ptr, bias_ptr, dense_out_ptr, inputs_num, inputs_elements,
7676
out_elements, w_ch_out_mem_strides, in_to_out_quant_params,
7777
(io_T)val_limit.min, (io_T)val_limit.max);
@@ -128,23 +128,23 @@ static inline void rnn_dense_op(
128128
accu = mli::krn::bias_additive(&bias[o_idx], accu, &in_to_out_quant_params[0]);
129129

130130
for(int idx = 0; idx < inputs_num; idx++) {
131-
mli::krn::adjust_quant_params(&in_to_out_quant_params[idx], /* krn_idx= */ 0);
131+
mli::krn::ref::adjust_quant_params(&in_to_out_quant_params[idx], /* krn_idx= */ 0);
132132

133133
accu = dotprod1D(inputs[idx], &weights[idx][o_idx], accu, in_elements[idx],
134134
1, w_ch_out_mem_strides[idx]);
135135

136-
accu = mli::krn::weights_additive(&weights[idx][o_idx], accu, &in_to_out_quant_params[idx],
136+
accu = mli::krn::ref::weights_additive(&weights[idx][o_idx], accu, &in_to_out_quant_params[idx],
137137
in_elements[idx], /* height= */ 1, /* ch= */ 1, w_ch_out_mem_strides[idx],
138138
/* row_step= */ 1, /* ch_step= */ 1);
139139
accu = mli_math_add_fx(accu, other_additives[idx]);
140140
accu = mli_math_add_fx(accu, prev_step);
141141

142142
if(inputs_num - idx != 1) {
143-
prev_step = mli::krn::ir_rnn_result_requantize(accu, &in_to_out_quant_params[idx],
143+
prev_step = mli::krn::ref::ir_rnn_result_requantize(accu, &in_to_out_quant_params[idx],
144144
&in_to_out_quant_params[idx+1], /* krn_idx= */ 0);
145145
accu = mli_math_mul_fx<io_T, acc_T>(0, 0);
146146
} else {
147-
out_val = mli::krn::result_cast<io_T, acc_T, quant_T>(accu, &in_to_out_quant_params[idx]);
147+
out_val = mli::krn::ref::result_cast<io_T, acc_T, quant_T>(accu, &in_to_out_quant_params[idx]);
148148
}
149149
}
150150

lib/src/bricks/impl/mli_krn_rnn_dense_op_vdsp.h

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,79 @@ namespace mli {
2121
namespace krn {
2222
namespace vdsp {
2323

24+
// FX quantization carries no per-dimension weights metadata, so there is
// nothing to adjust for the stacked dense pass.
static inline void adjust_weights_dim_for_rnn_dense(fx_quant_specific_params* params) {
    (void)params;  // intentionally a no-op
}
27+
28+
// SA8 asymmetric quantization: clear the weights dimension before the stacked
// dense pass. NOTE(review): -1 presumably means "no per-channel axis" (scales
// treated as per-tensor) — confirm against the quant-params definition.
static inline void adjust_weights_dim_for_rnn_dense(s8asym_quant_specific_params* params) {
    params->weight_dim = -1;
}
31+
32+
template <typename io_T, typename w_T, typename b_T, typename acc_T, typename quant_T>
33+
static inline void rnn_dense_op_stacked(
34+
const MLI_PTR (io_T) * inputs_ptr,
35+
const mli_tensor ** weights,
36+
const mli_tensor * bias,
37+
const int gates_num,
38+
const int inputs_num,
39+
const int * inputs_elements,
40+
quant_T * in_to_out_quant_params,
41+
const int * w_ch_out_mem_strides,
42+
mli_tensor * out) {
43+
44+
constexpr bool asym = std::is_same<quant_T, s8asym_quant_specific_params>::value;
45+
46+
mli_relu_cfg relu_none = {MLI_RELU_NONE};
47+
mli_minmax_t val_limit = mli_prv_get_relu_limits<io_T, asym>(&relu_none, out);
48+
49+
const MLI_PTR (w_T) weights_ptr[MLI_RNN_MAX_INPUT];
50+
uint32_t weights_shift[MLI_RNN_MAX_INPUT];
51+
52+
const int16_t * weights_scales[MLI_RNN_MAX_INPUT];
53+
const int8_t * weights_scale_frac_bits[MLI_RNN_MAX_INPUT];
54+
55+
int out_elements = mli_prv_count_elem_num_part(bias, 1);
56+
57+
for(int idx = 0; idx < inputs_num; ++idx) {
58+
weights_ptr[idx] = mli_prv_tensor_data_ptr<MLI_PTR (w_T)>(weights[idx]);
59+
weights_shift[idx] = mli_prv_count_elem_num_part(weights[idx], 1);
60+
61+
weights_scales[idx] = weights[idx]->el_params.sa.scale.mem.pi16;
62+
weights_scale_frac_bits[idx] = weights[idx]->el_params.sa.scale_frac_bits.mem.pi8;
63+
64+
adjust_weights_dim_for_rnn_dense(&in_to_out_quant_params[idx]);
65+
}
66+
67+
const MLI_PTR (b_T) bias_ptr = mli_prv_tensor_data_ptr<MLI_PTR (b_T)>(bias);
68+
MLI_CONV_OUT_PTR (io_T) dense_out_ptr = mli_prv_tensor_data_ptr<MLI_CONV_OUT_PTR (io_T)>(out);
69+
70+
for (int gate = 0; gate < gates_num; ++gate) {
71+
rnn_dense_op<io_T, w_T, b_T, acc_T, quant_T>(
72+
inputs_ptr, weights_ptr, bias_ptr, dense_out_ptr, inputs_num, inputs_elements,
73+
out_elements, w_ch_out_mem_strides, in_to_out_quant_params,
74+
(io_T)val_limit.min, (io_T)val_limit.max);
75+
76+
for (int weight_idx = 0; weight_idx < inputs_num; ++weight_idx)
77+
weights_ptr[weight_idx] += weights_shift[weight_idx];
78+
79+
bias_ptr += out_elements;
80+
dense_out_ptr += out_elements;
81+
82+
if (asym) {
83+
for (int weight_idx = 0; weight_idx < inputs_num; ++weight_idx) {
84+
weights_scales[weight_idx]++;
85+
weights_scale_frac_bits[weight_idx]++;
86+
}
87+
}
88+
}
89+
90+
for (int weight_idx = 0; weight_idx < inputs_num; ++weight_idx)
91+
weights_ptr[weight_idx] -= gates_num * weights_shift[weight_idx];
92+
93+
bias_ptr -= gates_num * out_elements;
94+
dense_out_ptr -= gates_num * out_elements;
95+
}
96+
2497
template <typename io_T, typename w_T, typename b_T, typename acc_T, typename quant_T>
2598
static inline void rnn_dense_op(
2699
const MLI_PTR(io_T) __restrict * inputs,

lib/src/bricks/mli_krn_rnn_dense_op.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ namespace mli {
2424
namespace krn {
2525
#if !defined(MLI_BUILD_REFERENCE) && defined(__Xvec_width)
2626
using mli::krn::vdsp::rnn_dense_op;
27-
using mli::krn::ref::rnn_dense_op_stacked;
27+
using mli::krn::vdsp::rnn_dense_op_stacked;
2828

2929
#elif !defined(MLI_BUILD_REFERENCE) && defined(__FXAPI__)
3030
using mli::krn::ref::rnn_dense_op;

lib/src/bricks/mli_krn_rnn_dense_op_decl.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,18 @@ static MLI_FORCE_INLINE void rnn_dense_op(
9797
const io_T val_min_limit,
9898
const io_T val_max_limit);
9999

100+
template <typename io_T, typename w_T, typename b_T, typename acc_T, typename quant_T>
101+
static MLI_FORCE_INLINE void rnn_dense_op_stacked(
102+
const MLI_PTR (io_T) * inputs_ptr,
103+
const mli_tensor ** weights,
104+
const mli_tensor * bias,
105+
const int gates_num,
106+
const int inputs_num,
107+
const int * inputs_elements,
108+
quant_T * in_to_out_quant_params,
109+
const int * w_ch_out_mem_strides,
110+
mli_tensor * out);
111+
100112
} // namespace vdsp
101113

102114
} // namespace krn
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/*
2+
* Copyright 2021, Synopsys, Inc.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-3-Clause license found in
6+
* the LICENSE file in the root directory of this source tree.
7+
*
8+
*/
9+
10+
#include "mli_krn_gru_cell.h"
11+
12+
#include "mli_check.h"
13+
#include "mli_config.h"
14+
#include "mli_debug.h"
15+
#include "mli_helpers_api.h"
16+
#include "mli_prv_activation_lut.h"
17+
#include "mli_types.h"
18+
19+
#ifdef __cplusplus
20+
extern "C" {
21+
#endif
22+
23+
typedef mli_acc32_t mli_sa8_sa8_sa32_accu_t;
24+
typedef mli_acc40_t mli_fx16_accu_t;
25+
typedef mli_acc32_t mli_fx16_fx8_fx8_accu_t;
26+
27+
#pragma MLI_CODE_SECTION_START(".mli_lib")
28+
29+
// GRU cell, fx16 data with fx16 weights.
// Validates all arguments via mli_chk_gru_cell_fx16 and, on success, dispatches
// to the shared templated implementation with a 40-bit accumulator.
// Returns the status produced by the argument check.
mli_status mli_krn_gru_cell_fx16 (
    const mli_tensor * in,
    const mli_tensor * prev_out,
    const mli_tensor * weights_in,
    const mli_tensor * weights_out,
    const mli_tensor * bias,
    const mli_lut * tanh_lut,
    const mli_lut * sigm_lut,
    const mli_rnn_cell_cfg * cfg,
    mli_tensor * out) {
    const mli_status status = MLI_CHECK_STATUS(mli_chk_gru_cell_fx16
        (in, prev_out, weights_in, weights_out, bias, tanh_lut, sigm_lut, cfg, out), __func__);

    if (status == MLI_STATUS_OK) {
        mli::krn::gru_cell_prepare_and_run<int16_t, int16_t, int16_t, mli_fx16_accu_t,
                mli::krn::fx_quant_specific_params>
                (in, prev_out, weights_in, weights_out, bias, tanh_lut, sigm_lut, cfg, out);
    }

    return status;
}
49+
50+
// GRU cell, fx16 data with fx8 weights and fx8 biases.
// Validates all arguments via mli_chk_gru_cell_fx16_fx8_fx8 and, on success,
// dispatches to the shared templated implementation with a 32-bit accumulator.
// Returns the status produced by the argument check.
mli_status mli_krn_gru_cell_fx16_fx8_fx8 (
    const mli_tensor * in,
    const mli_tensor * prev_out,
    const mli_tensor * weights_in,
    const mli_tensor * weights_out,
    const mli_tensor * bias,
    const mli_lut * tanh_lut,
    const mli_lut * sigm_lut,
    const mli_rnn_cell_cfg * cfg,
    mli_tensor * out) {
    const mli_status status = MLI_CHECK_STATUS(mli_chk_gru_cell_fx16_fx8_fx8
        (in, prev_out, weights_in, weights_out, bias, tanh_lut, sigm_lut, cfg, out), __func__);

    if (status == MLI_STATUS_OK) {
        mli::krn::gru_cell_prepare_and_run<int16_t, int8_t, int8_t, mli_fx16_fx8_fx8_accu_t,
                mli::krn::fx_quant_specific_params>
                (in, prev_out, weights_in, weights_out, bias, tanh_lut, sigm_lut, cfg, out);
    }

    return status;
}
70+
71+
// GRU cell, sa8 asymmetric data/weights with sa32 biases.
// Validates all arguments via mli_chk_gru_cell_sa8_sa8_sa32 and, on success,
// dispatches to the shared templated implementation with a 32-bit accumulator
// and the s8asym quantization parameters.
// Returns the status produced by the argument check.
mli_status mli_krn_gru_cell_sa8_sa8_sa32 (
    const mli_tensor * in,
    const mli_tensor * prev_out,
    const mli_tensor * weights_in,
    const mli_tensor * weights_out,
    const mli_tensor * bias,
    const mli_lut * tanh_lut,
    const mli_lut * sigm_lut,
    const mli_rnn_cell_cfg * cfg,
    mli_tensor * out) {
    const mli_status status = MLI_CHECK_STATUS(mli_chk_gru_cell_sa8_sa8_sa32
        (in, prev_out, weights_in, weights_out, bias, tanh_lut, sigm_lut, cfg, out), __func__);

    if (status == MLI_STATUS_OK) {
        mli::krn::gru_cell_prepare_and_run<int8_t, int8_t, int32_t, mli_sa8_sa8_sa32_accu_t,
                mli::krn::s8asym_quant_specific_params>
                (in, prev_out, weights_in, weights_out, bias, tanh_lut, sigm_lut, cfg, out);
    }

    return status;
}
91+
92+
#pragma MLI_CODE_SECTION_END()
93+
94+
#ifdef __cplusplus
95+
}
96+
#endif

0 commit comments

Comments (0)