@@ -405,7 +405,7 @@ void simple_rnn_pytorch_cell(const in_T &inputs, h_T &hidden_state, h_T &hidden_
405405 // Hidden state
406406 [[intel::fpga_register]] accum_array_T hiddenCand;
407407 multiply_U<h_T, accum_array_T, typename CONFIG_T::recurrent_weight_t , CONFIG_T::n_out>(hidden_state, hiddenCand,
408- rec_kernel);
408+ rec_kernel);
409409
410410 // Hidden state bias addition
411411 [[intel::fpga_register]] accum_array_T hiddenBias;
@@ -437,7 +437,7 @@ void simple_rnn_pytorch(const data_T &data, res_T &res, const typename CONFIG_T:
437437INIT_LOOP:
438438 #pragma unroll
439439 for (int x = 0 ; x < CONFIG_T::n_out; x++) {
440- hidden_state[x][ 0 ] = 0 ;
440+ hidden_state[0 ][x ] = 0 ;
441441 }
442442
443443 [[intel::disable_loop_pipelining]] for (int i = 0 ; i < CONFIG_T::n_timesteps; i++) {
@@ -451,7 +451,7 @@ void simple_rnn_pytorch(const data_T &data, res_T &res, const typename CONFIG_T:
451451 // Hidden state at current time step
452452 #pragma unroll
453453 for (int x = 0 ; x < CONFIG_T::n_out; x++) {
454- hidden_state_temp[x] = hidden_state[x][i ];
454+ hidden_state_temp[x] = hidden_state[i][x ];
455455 }
456456
457457 // Do SimpleRNN
@@ -468,24 +468,25 @@ void simple_rnn_pytorch(const data_T &data, res_T &res, const typename CONFIG_T:
468468 // Output when return_sequences is false
469469 #pragma unroll
470470 for (int x = 0 ; x < CONFIG_T::n_out; x++) {
471- res[x] = hidden_state[x][ CONFIG_T::n_timesteps];
471+ res[x] = hidden_state[CONFIG_T::n_timesteps][x ];
472472 }
473473 } else {
474474 // Output when return_sequences is true
475475 #pragma unroll
476476 for (int x = 0 ; x < CONFIG_T::n_timesteps; x++) {
477477 #pragma unroll
478478 for (int h = 0 ; h < CONFIG_T::n_out; h++) {
479- res[x * CONFIG_T::n_out + h] = hidden_state[h][ x + 1 ];
479+ res[x * CONFIG_T::n_out + h] = hidden_state[x + 1 ][h ];
480480 }
481481 }
482482 }
483483}
484484
485485template <class data_T , class h_T , class res_T , typename CONFIG_T>
486- void simple_rnn_pytorch_init_state (const data_T &data, const h_T& hin, res_T &res, const typename CONFIG_T::weight_t &kernel,
487- const typename CONFIG_T::recurrent_weight_t &rec_kernel, const typename CONFIG_T::bias_t &bias,
488- const typename CONFIG_T::recurrent_bias_t &rec_bias) {
486+ void simple_rnn_pytorch_init_state (const data_T &data, const h_T &hin, res_T &res, const typename CONFIG_T::weight_t &kernel,
487+ const typename CONFIG_T::recurrent_weight_t &rec_kernel,
488+ const typename CONFIG_T::bias_t &bias,
489+ const typename CONFIG_T::recurrent_bias_t &rec_bias) {
489490
490491 using in_T = array<typename data_T::value_type, CONFIG_T::n_in>;
491492
@@ -498,7 +499,7 @@ void simple_rnn_pytorch_init_state(const data_T &data, const h_T& hin, res_T &re
498499INIT_LOOP:
499500 #pragma unroll
500501 for (int x = 0 ; x < CONFIG_T::n_out; x++) {
501- hidden_state[x][ 0 ] = hin[x];
502+ hidden_state[0 ][x ] = hin[x];
502503 }
503504
504505 [[intel::disable_loop_pipelining]] for (int i = 0 ; i < CONFIG_T::n_timesteps; i++) {
@@ -512,7 +513,7 @@ void simple_rnn_pytorch_init_state(const data_T &data, const h_T& hin, res_T &re
512513 // Hidden state at current time step
513514 #pragma unroll
514515 for (int x = 0 ; x < CONFIG_T::n_out; x++) {
515- hidden_state_temp[x] = hidden_state[x][i ];
516+ hidden_state_temp[x] = hidden_state[i][x ];
516517 }
517518
518519 // Do SimpleRNN
@@ -529,15 +530,15 @@ void simple_rnn_pytorch_init_state(const data_T &data, const h_T& hin, res_T &re
529530 // Output when return_sequences is false
530531 #pragma unroll
531532 for (int x = 0 ; x < CONFIG_T::n_out; x++) {
532- res[x] = hidden_state[x][ CONFIG_T::n_timesteps];
533+ res[x] = hidden_state[CONFIG_T::n_timesteps][x ];
533534 }
534535 } else {
535536 // Output when return_sequences is true
536537 #pragma unroll
537538 for (int x = 0 ; x < CONFIG_T::n_timesteps; x++) {
538539 #pragma unroll
539540 for (int h = 0 ; h < CONFIG_T::n_out; h++) {
540- res[x * CONFIG_T::n_out + h] = hidden_state[h][ x + 1 ];
541+ res[x * CONFIG_T::n_out + h] = hidden_state[x + 1 ][h ];
541542 }
542543 }
543544 }
0 commit comments