Skip to content

Commit c8029dd

Browse files
committed
fix order of indices for pytorch simple RNN oneAPI
1 parent 77e78e8 commit c8029dd

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

hls4ml/templates/oneapi/firmware/nnet_utils/nnet_recurrent.h

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ void simple_rnn_pytorch_cell(const in_T &inputs, h_T &hidden_state, h_T &hidden_
405405
// Hidden state
406406
[[intel::fpga_register]] accum_array_T hiddenCand;
407407
multiply_U<h_T, accum_array_T, typename CONFIG_T::recurrent_weight_t, CONFIG_T::n_out>(hidden_state, hiddenCand,
408-
rec_kernel);
408+
rec_kernel);
409409

410410
// Hidden state bias addition
411411
[[intel::fpga_register]] accum_array_T hiddenBias;
@@ -437,7 +437,7 @@ void simple_rnn_pytorch(const data_T &data, res_T &res, const typename CONFIG_T:
437437
INIT_LOOP:
438438
#pragma unroll
439439
for (int x = 0; x < CONFIG_T::n_out; x++) {
440-
hidden_state[x][0] = 0;
440+
hidden_state[0][x] = 0;
441441
}
442442

443443
[[intel::disable_loop_pipelining]] for (int i = 0; i < CONFIG_T::n_timesteps; i++) {
@@ -451,7 +451,7 @@ void simple_rnn_pytorch(const data_T &data, res_T &res, const typename CONFIG_T:
451451
// Hidden state at current time step
452452
#pragma unroll
453453
for (int x = 0; x < CONFIG_T::n_out; x++) {
454-
hidden_state_temp[x] = hidden_state[x][i];
454+
hidden_state_temp[x] = hidden_state[i][x];
455455
}
456456

457457
// Do SimpleRNN
@@ -468,24 +468,25 @@ void simple_rnn_pytorch(const data_T &data, res_T &res, const typename CONFIG_T:
468468
// Output when return_sequences is false
469469
#pragma unroll
470470
for (int x = 0; x < CONFIG_T::n_out; x++) {
471-
res[x] = hidden_state[x][CONFIG_T::n_timesteps];
471+
res[x] = hidden_state[CONFIG_T::n_timesteps][x];
472472
}
473473
} else {
474474
// Output when return_sequences is true
475475
#pragma unroll
476476
for (int x = 0; x < CONFIG_T::n_timesteps; x++) {
477477
#pragma unroll
478478
for (int h = 0; h < CONFIG_T::n_out; h++) {
479-
res[x * CONFIG_T::n_out + h] = hidden_state[h][x + 1];
479+
res[x * CONFIG_T::n_out + h] = hidden_state[x + 1][h];
480480
}
481481
}
482482
}
483483
}
484484

485485
template <class data_T, class h_T, class res_T, typename CONFIG_T>
486-
void simple_rnn_pytorch_init_state(const data_T &data, const h_T& hin, res_T &res, const typename CONFIG_T::weight_t &kernel,
487-
const typename CONFIG_T::recurrent_weight_t &rec_kernel, const typename CONFIG_T::bias_t &bias,
488-
const typename CONFIG_T::recurrent_bias_t &rec_bias) {
486+
void simple_rnn_pytorch_init_state(const data_T &data, const h_T &hin, res_T &res, const typename CONFIG_T::weight_t &kernel,
487+
const typename CONFIG_T::recurrent_weight_t &rec_kernel,
488+
const typename CONFIG_T::bias_t &bias,
489+
const typename CONFIG_T::recurrent_bias_t &rec_bias) {
489490

490491
using in_T = array<typename data_T::value_type, CONFIG_T::n_in>;
491492

@@ -498,7 +499,7 @@ void simple_rnn_pytorch_init_state(const data_T &data, const h_T& hin, res_T &re
498499
INIT_LOOP:
499500
#pragma unroll
500501
for (int x = 0; x < CONFIG_T::n_out; x++) {
501-
hidden_state[x][0] = hin[x];
502+
hidden_state[0][x] = hin[x];
502503
}
503504

504505
[[intel::disable_loop_pipelining]] for (int i = 0; i < CONFIG_T::n_timesteps; i++) {
@@ -512,7 +513,7 @@ void simple_rnn_pytorch_init_state(const data_T &data, const h_T& hin, res_T &re
512513
// Hidden state at current time step
513514
#pragma unroll
514515
for (int x = 0; x < CONFIG_T::n_out; x++) {
515-
hidden_state_temp[x] = hidden_state[x][i];
516+
hidden_state_temp[x] = hidden_state[i][x];
516517
}
517518

518519
// Do SimpleRNN
@@ -529,15 +530,15 @@ void simple_rnn_pytorch_init_state(const data_T &data, const h_T& hin, res_T &re
529530
// Output when return_sequences is false
530531
#pragma unroll
531532
for (int x = 0; x < CONFIG_T::n_out; x++) {
532-
res[x] = hidden_state[x][CONFIG_T::n_timesteps];
533+
res[x] = hidden_state[CONFIG_T::n_timesteps][x];
533534
}
534535
} else {
535536
// Output when return_sequences is true
536537
#pragma unroll
537538
for (int x = 0; x < CONFIG_T::n_timesteps; x++) {
538539
#pragma unroll
539540
for (int h = 0; h < CONFIG_T::n_out; h++) {
540-
res[x * CONFIG_T::n_out + h] = hidden_state[h][x + 1];
541+
res[x * CONFIG_T::n_out + h] = hidden_state[x + 1][h];
541542
}
542543
}
543544
}

0 commit comments

Comments
 (0)