@@ -200,6 +200,41 @@ void gru(const data_T &data, res_T &res, const typename CONFIG_T::weight_t &weig
200200 }
201201}
202202
203+ template <class data_T , class h_T , class res_T , typename CONFIG_T>
204+ void gru_init_state (const data_T &data, const h_T &hin, res_T &res, const typename CONFIG_T::weight_t &weights,
205+ const typename CONFIG_T::recurrent_weight_t &recurrent_weights, const typename CONFIG_T::bias_t &bias,
206+ const typename CONFIG_T::recurrent_bias_t &recurrent_bias) {
207+
208+ [[intel::fpga_register]] data_T x;
209+
210+ [[intel::fpga_register]] h_T h = hin;
211+
212+ // Loop depedency - cannot pipeline
213+ [[intel::disable_loop_pipelining]] for (int t = 0 ; t < CONFIG_T::n_timesteps; t++) {
214+ // Get data at current time step
215+ #pragma unroll
216+ for (int j = 0 ; j < CONFIG_T::n_in; j++) {
217+ x[j] = data[j + t * CONFIG_T::n_in];
218+ }
219+
220+ nnet::gru_cell<data_T, h_T, CONFIG_T>(x, h, weights, recurrent_weights, bias, recurrent_bias);
221+
222+ if (CONFIG_T::return_sequences) {
223+ #pragma unroll
224+ for (int i = 0 ; i < CONFIG_T::n_units; i++) {
225+ res[CONFIG_T::n_units * t + i] = h[i];
226+ }
227+ }
228+ }
229+
230+ if (!CONFIG_T::return_sequences) {
231+ #pragma unroll
232+ for (int i = 0 ; i < (CONFIG_T::n_units); i++) {
233+ res[i] = h[i];
234+ }
235+ }
236+ }
237+
203238// ----------------------
204239// SimpleRNN
205240// ----------------------
@@ -561,6 +596,78 @@ void lstm(const data_T &data, res_T &res, const typename CONFIG_T::weight_i_t &W
561596 }
562597}
563598
599+ template <class data_T , class h_T , class hc_T , class res_T , class CONFIG_T >
600+ void lstm_init_state (const data_T &data, const h_T &hidden_state_initial, const hc_T &cell_state_initial, res_T &res,
601+ const typename CONFIG_T::weight_i_t &WI, const typename CONFIG_T::weight_f_t &WF,
602+ const typename CONFIG_T::weight_c_t &WC, const typename CONFIG_T::weight_o_t &WO,
603+ const typename CONFIG_T::recurrent_weight_i_t &RWI, const typename CONFIG_T::recurrent_weight_f_t &RWF,
604+ const typename CONFIG_T::recurrent_weight_c_t &RWC, const typename CONFIG_T::recurrent_weight_o_t &RWO,
605+ const typename CONFIG_T::bias_i_t &BI, const typename CONFIG_T::bias_f_t &BF,
606+ const typename CONFIG_T::bias_c_t &BC, const typename CONFIG_T::bias_o_t &BO) {
607+
608+ // Note: currently this does not support recurrent bias
609+
610+ using in_T = array<typename data_T::value_type, CONFIG_T::n_in>;
611+
612+ [[intel::fpga_register]] h_T hidden_state[CONFIG_T::n_timesteps + 1 ];
613+ [[intel::fpga_register]] h_T hidden_state_temp;
614+ [[intel::fpga_register]] h_T cell_state[CONFIG_T::n_timesteps + 1 ];
615+ [[intel::fpga_register]] h_T cell_state_temp; // should this be updated to a differnt type
616+ [[intel::fpga_register]] h_T h;
617+ [[intel::fpga_register]] h_T c;
618+ [[intel::fpga_register]] in_T in;
619+
620+ // Set initially hidden state (output) to zero
621+ INIT_LOOP:
622+ #pragma unroll
623+ for (int x = 0 ; x < CONFIG_T::n_out; x++) {
624+ hidden_state[0 ][x] = hidden_state_initial[x];
625+ cell_state[0 ][x] = cell_state_initial[x];
626+ }
627+
628+ // Input dimension
629+ [[intel::disable_loop_pipelining]] for (int i = 0 ; i < CONFIG_T::n_timesteps; i++) {
630+ // Data at current time step
631+ for (int x = 0 ; x < CONFIG_T::n_in; x++) {
632+ in[x] = data[x + i * CONFIG_T::n_in];
633+ }
634+
635+ // Hidden state at current time step
636+ #pragma unroll
637+ for (int x = 0 ; x < CONFIG_T::n_out; x++) {
638+ hidden_state_temp[x] = hidden_state[i][x];
639+ cell_state_temp[x] = cell_state[i][x];
640+ }
641+
642+ // Do LSTM
643+ lstm_cell<in_T, h_T, CONFIG_T>(in, hidden_state_temp, h, cell_state_temp, c, WI, WF, WC, WO, RWI, RWF, RWC, RWO, BI,
644+ BF, BC, BO);
645+
646+ // Write result
647+ #pragma unroll
648+ for (int x = 0 ; x < CONFIG_T::n_out; x++) {
649+ hidden_state[i + 1 ][x] = h[x];
650+ cell_state[i + 1 ][x] = c[x];
651+ }
652+ }
653+
654+ if (CONFIG_T::return_sequences == 0 ) {
655+ // Output when return_sequences is false
656+ #pragma unroll
657+ for (int x = 0 ; x < CONFIG_T::n_out; x++) {
658+ res[x] = hidden_state[CONFIG_T::n_timesteps][x];
659+ }
660+ } else {
661+ // Output when return_sequences is true
662+ #pragma unroll
663+ for (int x = 0 ; x < CONFIG_T::n_timesteps; x++) {
664+ for (int h = 0 ; h < CONFIG_T::n_out; h++) {
665+ res[x * CONFIG_T::n_out + h] = hidden_state[x + 1 ][h];
666+ }
667+ }
668+ }
669+ }
670+
564671} // namespace nnet
565672
566673#endif
0 commit comments