@@ -206,6 +206,47 @@ void gru(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_outputs * CONFIG_T::
206206 }
207207}
208208
209+ template <class data_T , class data2_T , class res_T , typename CONFIG_T>
210+ void gru (data_T data[CONFIG_T::n_in], data2_T h[CONFIG_T::n_units], res_T res[CONFIG_T::n_outputs * CONFIG_T::n_units],
211+ const typename CONFIG_T::weight_t weights[3 * CONFIG_T::n_units * CONFIG_T::n_in],
212+ const typename CONFIG_T::weight_t recurrent_weights[3 * CONFIG_T::n_units * CONFIG_T::n_units],
213+ const typename CONFIG_T::bias_t bias[3 * CONFIG_T::n_units],
214+ const typename CONFIG_T::bias_t recurrent_bias[3 * CONFIG_T::n_units]) {
215+
216+ hls_register data_T x[CONFIG_T::n_in];
217+ // hls_register res_T h[CONFIG_T::n_units];
218+
219+ // #pragma unroll
220+ // for (int i = 0; i < CONFIG_T::n_units; i++) {
221+ // h[i] = 0;
222+ // }
223+
224+ // Loop depedency - cannot pipeline
225+ #pragma disable_loop_pipelining
226+ for (int t = 0 ; t < CONFIG_T::n_timesteps; t++) {
227+ // Get data at current time step
228+ #pragma unroll
229+ for (int j = 0 ; j < CONFIG_T::n_in; j++) {
230+ x[j] = data[j + t * CONFIG_T::n_in];
231+ }
232+
233+ nnet::gru_cell<data_T, res_T, CONFIG_T>(x, h, weights, recurrent_weights, bias, recurrent_bias);
234+
235+ if (CONFIG_T::return_sequences) {
236+ #pragma unroll
237+ for (int i = 0 ; i < CONFIG_T::n_units; i++) {
238+ res[CONFIG_T::n_units * t + i] = h[i];
239+ }
240+ }
241+ }
242+
243+ if (!CONFIG_T::return_sequences) {
244+ #pragma unroll
245+ for (int i = 0 ; i < (CONFIG_T::n_units); i++) {
246+ res[i] = h[i];
247+ }
248+ }
249+ }
209250// ----------------------
210251// SimpleRNN
211252// ----------------------
@@ -711,6 +752,79 @@ void lstm(data_T data[CONFIG_T::n_timesteps * CONFIG_T::n_in], res_T res[CONFIG_
711752 }
712753}
713754
755+ template <class data_T , class data2_T , class data3_T , class res_T , class CONFIG_T >
756+ void lstm (data_T data[CONFIG_T::n_timesteps * CONFIG_T::n_in], data2_T hidden_state_initial[CONFIG_T::n_out],
757+ data3_T cell_state_initial[CONFIG_T::n_out], res_T res[CONFIG_T::n_outputs * CONFIG_T::n_out],
758+ const typename CONFIG_T::weight_t WI[CONFIG_T::n_in * CONFIG_T::n_out],
759+ const typename CONFIG_T::weight_t WF[CONFIG_T::n_in * CONFIG_T::n_out],
760+ const typename CONFIG_T::weight_t WC[CONFIG_T::n_in * CONFIG_T::n_out],
761+ const typename CONFIG_T::weight_t WO[CONFIG_T::n_in * CONFIG_T::n_out],
762+ const typename CONFIG_T::weight_t RWI[CONFIG_T::n_out * CONFIG_T::n_out],
763+ const typename CONFIG_T::weight_t RWF[CONFIG_T::n_out * CONFIG_T::n_out],
764+ const typename CONFIG_T::weight_t RWC[CONFIG_T::n_out * CONFIG_T::n_out],
765+ const typename CONFIG_T::weight_t RWO[CONFIG_T::n_out * CONFIG_T::n_out],
766+ const typename CONFIG_T::bias_t BI[CONFIG_T::n_out], const typename CONFIG_T::bias_t BF[CONFIG_T::n_out],
767+ const typename CONFIG_T::bias_t BC[CONFIG_T::n_out], const typename CONFIG_T::bias_t BO[CONFIG_T::n_out]) {
768+ res_T hidden_state[CONFIG_T::n_out][CONFIG_T::n_timesteps + 1 ] hls_register;
769+ res_T hidden_state_temp[CONFIG_T::n_out] hls_register;
770+ res_T cell_state[CONFIG_T::n_out][CONFIG_T::n_timesteps + 1 ] hls_register;
771+ res_T cell_state_temp[CONFIG_T::n_out] hls_register;
772+ res_T h[CONFIG_T::n_out] hls_register;
773+ res_T c[CONFIG_T::n_out] hls_register;
774+ data_T in[CONFIG_T::n_in] hls_register;
775+
776+ // Set initially hidden state (output) to zero
777+ INIT_LOOP:
778+ #pragma unroll
779+ for (int x = 0 ; x < CONFIG_T::n_out; x++) {
780+ hidden_state[x][0 ] = hidden_state_initial[x];
781+ cell_state[x][0 ] = cell_state_initial[x];
782+ }
783+
784+ // Input dimension
785+ #pragma disable_loop_pipelining
786+ for (int i = 0 ; i < CONFIG_T::n_timesteps; i++) {
787+ // Data at current time step
788+ for (int x = 0 ; x < CONFIG_T::n_in; x++) {
789+ in[x] = data[x + i * CONFIG_T::n_in];
790+ }
791+
792+ // Hidden state at current time step
793+ #pragma unroll
794+ for (int x = 0 ; x < CONFIG_T::n_out; x++) {
795+ hidden_state_temp[x] = hidden_state[x][i];
796+ cell_state_temp[x] = cell_state[x][i];
797+ }
798+
799+ // Do LSTM
800+ lstm_cell<data_T, res_T, CONFIG_T>(in, hidden_state_temp, h, cell_state_temp, c, WI, WF, WC, WO, RWI, RWF, RWC, RWO,
801+ BI, BF, BC, BO);
802+
803+ // Write result
804+ #pragma unroll
805+ for (int x = 0 ; x < CONFIG_T::n_out; x++) {
806+ hidden_state[x][i + 1 ] = h[x];
807+ cell_state[x][i + 1 ] = c[x];
808+ }
809+ }
810+
811+ if (CONFIG_T::return_sequences == 0 ) {
812+ // Output when return_sequences is false
813+ #pragma unroll
814+ for (int x = 0 ; x < CONFIG_T::n_out; x++) {
815+ res[x] = hidden_state[x][CONFIG_T::n_timesteps];
816+ }
817+ } else {
818+ // Output when return_sequences is true
819+ #pragma unroll
820+ for (int x = 0 ; x < CONFIG_T::n_timesteps; x++) {
821+ for (int h = 0 ; h < CONFIG_T::n_out; h++) {
822+ res[x * CONFIG_T::n_out + h] = hidden_state[h][x + 1 ];
823+ }
824+ }
825+ }
826+ }
827+
714828} // namespace nnet
715829
716830#endif
0 commit comments