@@ -89,6 +89,7 @@ struct gru_config {
     // Resource reuse info
     static const unsigned io_type = io_parallel;
     static const unsigned reuse_factor = 1;
+    static const bool pytorch_order = false;
     static const bool store_weights_in_bram = false;
 
     // Activation
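
The new pytorch_order flag records which gate-ordering convention the fused GRU weights follow: Keras concatenates the gates as (update z, reset r, candidate h), while PyTorch stores (reset r, update z, candidate n). A minimal sketch of the offsets this flag selects inside the fused pre-activation vector z_r_act (these helpers are hypothetical, for illustration only, not part of the PR):

// Sketch only: offsets into z_r_act (length 2 * n_units), assuming the
// gate orders above. reset_gate_offset / update_gate_offset are hypothetical names.
template <typename CONFIG_T> constexpr unsigned reset_gate_offset() {
    return CONFIG_T::pytorch_order ? 0 : CONFIG_T::n_units; // PyTorch stores r first
}
template <typename CONFIG_T> constexpr unsigned update_gate_offset() {
    return CONFIG_T::pytorch_order ? CONFIG_T::n_units : 0; // Keras stores z first
}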
@@ -137,7 +138,10 @@ void gru_cell(const data_T &x, h_T &h, const typename CONFIG_T::weight_t &weight
     [[intel::fpga_register]] h_activ_array_T hadamard_r_h;
 #pragma unroll recurrent_unroll_factor
     for (int i = 0; i < (CONFIG_T::n_units); i++) {
-        hadamard_r_h[i] = z_r_act[i + CONFIG_T::n_units] * mat_mul_h_wr[i + 2 * CONFIG_T::n_units];
+        if (CONFIG_T::pytorch_order)
+            hadamard_r_h[i] = z_r_act[i] * mat_mul_h_wr[i + 2 * CONFIG_T::n_units];
+        else
+            hadamard_r_h[i] = z_r_act[i + CONFIG_T::n_units] * mat_mul_h_wr[i + 2 * CONFIG_T::n_units];
     }
 
     // The candidate state; X * W_{hx} + hadamard(r(t), h_(t-1)) * W_{hh} + b_{h}
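
The hunk above picks the reset gate r(t) out of z_r_act before forming the Hadamard product, and the trailing comment describes the standard GRU candidate state, which in LaTeX notation (with \phi the candidate activation and \odot the elementwise product) reads:

\tilde{h}_t = \phi\left( W_{hx}\, x_t + (r_t \odot h_{t-1})\, W_{hh} + b_h \right)

The formula itself is framework independent; only the offset of r_t inside z_r_act (0 in PyTorch order, n_units in Keras order) changes with pytorch_order.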
@@ -156,7 +160,11 @@ void gru_cell(const data_T &x, h_T &h, const typename CONFIG_T::weight_t &weight
     // Update state
 #pragma unroll recurrent_unroll_factor
     for (int i = 0; i < (CONFIG_T::n_units); i++) {
-        h[i] = static_cast<typename h_T::value_type>(h_cand_act[i] * (1 - z_r_act[i]) + h[i] * z_r_act[i]);
+        if (CONFIG_T::pytorch_order)
+            h[i] = static_cast<typename h_T::value_type>(h_cand_act[i] * (1 - z_r_act[i + CONFIG_T::n_units]) +
+                                                         h[i] * z_r_act[i + CONFIG_T::n_units]);
+        else
+            h[i] = static_cast<typename h_T::value_type>(h_cand_act[i] * (1 - z_r_act[i]) + h[i] * z_r_act[i]);
     }
 }
 
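Both branches of this update implement the same GRU state update,

h_t = (1 - z_t) \odot \tilde{h}_t + z_t \odot h_{t-1}

and differ only in where the update gate z_t sits inside z_r_act: offset 0 in Keras order, offset n_units in PyTorch order.
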
@@ -328,7 +336,7 @@ void simple_rnn(const data_T &data, res_T &res, const typename CONFIG_T::weight_
         // Write result
 #pragma unroll
         for (int x = 0; x < CONFIG_T::n_out; x++) {
-            hidden_state[i + 1][x] = h[x];
+            hidden_state[x][i + 1] = h[x];
         }
     }
 
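This change makes the write unit-major: the first index now selects the hidden unit and the second the time step, matching the reads in the new simple_rnn_pytorch below. For that indexing to be well formed, the buffer must be declared with its dimensions in the same order, e.g. (an assumed declaration mirroring the Quartus backend, not a line from this hunk):

// Assumed unit-major layout: hidden_state[x][t] is unit x at time step t
[[intel::fpga_register]] typename res_T::value_type hidden_state[CONFIG_T::n_out][CONFIG_T::n_timesteps + 1];
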
@@ -350,6 +358,130 @@ void simple_rnn(const data_T &data, res_T &res, const typename CONFIG_T::weight_
     }
 }
 
+// ----------------------
+// SimpleRNN with PyTorch biases
+// ----------------------
+
+struct simpleRNN_pytorch_config {
+    // Internal data type definitions
+    typedef float weight_t;
+    typedef float bias_t;
+    typedef float accum_t;
+
+    // Layer Sizes
+    static const unsigned n_in = 1;
+    static const unsigned n_out = 1;
+    static const unsigned n_outputs = 1;
+    static const unsigned n_timesteps = 1;
+    static const bool return_sequences = false;
+
+    // Resource reuse info
+    static const unsigned io_type = io_parallel;
+    static const unsigned reuse_factor = 1;
+    static const bool store_weights_in_bram = false;
+
+    // Activation
+    template <class x_T, class y_T, class config_T> using activation_recr = nnet::activation::relu<x_T, y_T, config_T>;
+
+    template <class x_T, class y_T, class config_T> using activation = nnet::activation::relu<x_T, y_T, config_T>;
+};
+
+template <class in_T, class h_T, typename CONFIG_T>
+void simple_rnn_pytorch_cell(const in_T &inputs, h_T &hidden_state, h_T &hidden_state_o,
+                             const typename CONFIG_T::weight_t &kernel,
+                             const typename CONFIG_T::recurrent_weight_t &rec_kernel, const typename CONFIG_T::bias_t &bias,
+                             const typename CONFIG_T::recurrent_bias_t &rec_bias) {
+
+    using accum_array_T = array<typename CONFIG_T::accum_t, CONFIG_T::n_out>;
+
+    // Weight multiplication
+    [[intel::fpga_register]] accum_array_T afterW;
+    multiply_W<in_T, accum_array_T, typename CONFIG_T::weight_t, CONFIG_T::n_in, CONFIG_T::n_out>(inputs, afterW, kernel);
+
+    // Bias addition
+    [[intel::fpga_register]] accum_array_T afterBias;
+    add_bias<accum_array_T, accum_array_T, typename CONFIG_T::bias_t, CONFIG_T::n_out>(afterW, afterBias, bias);
+
+    // Hidden state
+    [[intel::fpga_register]] accum_array_T hiddenCand;
+    multiply_U<h_T, accum_array_T, typename CONFIG_T::recurrent_weight_t, CONFIG_T::n_out>(hidden_state, hiddenCand,
+                                                                                           rec_kernel);
+
+    // Hidden state bias addition
+    [[intel::fpga_register]] accum_array_T hiddenBias;
+    add_bias<accum_array_T, accum_array_T, typename CONFIG_T::recurrent_bias_t, CONFIG_T::n_out>(hiddenCand, hiddenBias,
+                                                                                                 rec_bias);
+
+    // Vector addition
+    [[intel::fpga_register]] accum_array_T afterAdd;
+    add_vectors<accum_array_T, accum_array_T, accum_array_T, CONFIG_T::n_out>(afterBias, hiddenBias, afterAdd);
+
+    // Activation
+    CONFIG_T::template activation<accum_array_T, h_T, typename CONFIG_T::ACT_CONFIG_T>::activation(afterAdd, hidden_state_o);
+}
+
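Unlike the Keras-style simple_rnn_cell, this cell takes two bias vectors, because PyTorch's nn.RNN keeps a separate input-side and recurrent-side bias (b_ih and b_hh); each is added right after its own matrix product and the two paths are then summed before the activation:

h_t = \mathrm{act}\left( W\, x_t + b_{ih} + U\, h_{t-1} + b_{hh} \right)

with W the kernel, U the recurrent kernel, and act the configured activation.
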
+template <class data_T, class res_T, typename CONFIG_T>
+void simple_rnn_pytorch(const data_T &data, res_T &res, const typename CONFIG_T::weight_t &kernel,
+                        const typename CONFIG_T::recurrent_weight_t &rec_kernel, const typename CONFIG_T::bias_t &bias,
+                        const typename CONFIG_T::recurrent_bias_t &rec_bias) {
+
+    using in_T = array<typename data_T::value_type, CONFIG_T::n_in>;
+    using h_T = array<typename res_T::value_type, CONFIG_T::n_out>;
+
+    // Hidden state buffer, stored unit-major: hidden_state[x][t] is unit x at time step t
+    [[intel::fpga_register]] typename res_T::value_type hidden_state[CONFIG_T::n_out][CONFIG_T::n_timesteps + 1];
+    [[intel::fpga_register]] h_T hidden_state_temp;
+    [[intel::fpga_register]] h_T h;
+    [[intel::fpga_register]] in_T in;
+
+    // Set the initial hidden state (output) to zero
+INIT_LOOP:
+#pragma unroll
+    for (int x = 0; x < CONFIG_T::n_out; x++) {
+        hidden_state[x][0] = 0;
+    }
+
+    [[intel::disable_loop_pipelining]] for (int i = 0; i < CONFIG_T::n_timesteps; i++) {
+
+        // Data at current time step
+#pragma unroll
+        for (int x = 0; x < CONFIG_T::n_in; x++) {
+            in[x] = data[x + i * CONFIG_T::n_in];
+        }
+
+        // Hidden state at current time step
+#pragma unroll
+        for (int x = 0; x < CONFIG_T::n_out; x++) {
+            hidden_state_temp[x] = hidden_state[x][i];
+        }
+
+        // Do SimpleRNN
+        simple_rnn_pytorch_cell<in_T, h_T, CONFIG_T>(in, hidden_state_temp, h, kernel, rec_kernel, bias, rec_bias);
+
+        // Write result
+#pragma unroll
+        for (int x = 0; x < CONFIG_T::n_out; x++) {
+            hidden_state[x][i + 1] = h[x];
+        }
+    }
+
+    if (CONFIG_T::return_sequences == 0) {
+        // Output when return_sequences is false
+#pragma unroll
+        for (int x = 0; x < CONFIG_T::n_out; x++) {
+            res[x] = hidden_state[x][CONFIG_T::n_timesteps];
+        }
+    } else {
+        // Output when return_sequences is true
+#pragma unroll
+        for (int x = 0; x < CONFIG_T::n_timesteps; x++) {
+#pragma unroll
+            for (int h = 0; h < CONFIG_T::n_out; h++) {
+                res[x * CONFIG_T::n_out + h] = hidden_state[h][x + 1];
+            }
+        }
+    }
+}
+
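A hypothetical call-site sketch, for orientation only. In practice the config struct, weight arrays, and data types are emitted by the hls4ml converter; every name below (rnn_config, input_t, result_t, w_rnn, wr_rnn, b_rnn, br_rnn) is a placeholder, not an API introduced by this PR:

// Sketch: invoking the PyTorch-style SimpleRNN (all names hypothetical).
// rnn_config is assumed to define the simpleRNN_pytorch_config fields plus the
// weight_t / recurrent_weight_t / bias_t / recurrent_bias_t array types and ACT_CONFIG_T.
using input_t = nnet::array<ac_fixed<16, 6, true>, rnn_config::n_timesteps * rnn_config::n_in>;
using result_t = nnet::array<ac_fixed<16, 6, true>, rnn_config::n_out>; // return_sequences == false
[[intel::fpga_register]] input_t rnn_in;   // filled upstream
[[intel::fpga_register]] result_t rnn_out;
nnet::simple_rnn_pytorch<input_t, result_t, rnn_config>(rnn_in, rnn_out, w_rnn, wr_rnn, b_rnn, br_rnn);
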
 // ----------------------
 // LSTM
 // ----------------------