Skip to content

Commit 925fe64

Browse files
authored
Merge pull request #7 from jmitrevs/initialRecurr-simpleRNN
Fix pytorch simple RNN for oneAPI; add initial state version for Quartus and oneAPI
2 parents cf9c726 + c8029dd commit 925fe64

File tree

5 files changed

+159
-14
lines changed

5 files changed

+159
-14
lines changed

hls4ml/backends/oneapi/passes/recurrent_templates.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,7 @@ def format(self, node):
333333
typedef {weight_t.name} weight_t;
334334
typedef {bias_t.name} bias_t;
335335
typedef {recurrent_weight_t.name} recurrent_weight_t;
336+
typedef {recurrent_bias_t.name} recurrent_bias_t;
336337
337338
typedef {act_t} ACT_CONFIG_T;
338339
template<class x_T, class y_T, class config_T>
@@ -350,6 +351,9 @@ def format(self, node):
350351
simple_rnn_pytorch_function_template = (
351352
'nnet::simple_rnn_pytorch<{input_t}, {output_t}, {config}>({input}, {output}, {weights});'
352353
)
354+
simple_rnn_pytorch_function_initial_state_template = (
355+
'nnet::simple_rnn_pytorch_init_state<{input_t}, {h_t}, {output_t}, {config}>({input}, {init_state}, {output}, {weights});'
356+
)
353357

354358

355359
class SimpleRNNConfigTemplate(LayerConfigTemplate):
@@ -395,10 +399,17 @@ def __init__(self):
395399

396400
def format(self, node):
397401
params = self._default_function_params(node)
402+
if params['pass_initial_states'] == 'true':
403+
params['h_t'] = node.get_input_variable(node.inputs[1]).type.name
404+
params['init_state'] = node.get_input_variable(node.inputs[1]).name
405+
398406
if node.get_attr('pytorch', False):
399-
self.template = simple_rnn_pytorch_function_template
407+
if params['pass_initial_states'] == 'true':
408+
template = simple_rnn_pytorch_function_initial_state_template
409+
else:
410+
template = simple_rnn_pytorch_function_template
400411
params['weights'] = 'w{0}, wr{0}, b{0}, br{0}'.format(str(node.index))
401412
else:
402-
self.template = simple_rnn_function_template
413+
template = simple_rnn_function_template
403414
params['weights'] = 'w{0}, wr{0}, b{0}'.format(str(node.index))
404-
return self.template.format(**params)
415+
return template.format(**params)

hls4ml/backends/quartus/passes/recurrent_templates.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,9 @@ def format(self, node):
285285
simple_rnn_pytorch_function_template = (
286286
'nnet::simple_rnn_pytorch<{input_t}, {output_t}, {config}>({input}, {output}, {weights});'
287287
)
288+
simple_rnn_pytorch_function_initial_state_template = (
289+
'nnet::simple_rnn_pytorch<{input_t}, {input2_t}, {output_t}, {config}>({input}, {input2}, {output}, {weights});'
290+
)
288291

289292

290293
class SimpleRNNConfigTemplate(LayerConfigTemplate):
@@ -326,13 +329,20 @@ def format(self, node):
326329
class SimpleRNNFunctionTemplate(FunctionCallTemplate):
327330
def __init__(self):
328331
super().__init__(SimpleRNN, include_header=recurrent_include_list)
329-
self.template = simple_rnn_function_template
330332

331333
def format(self, node):
332334
params = self._default_function_params(node)
335+
if params['pass_initial_states'] == 'true':
336+
params['input2_t'] = node.get_input_variable(node.inputs[1]).type.name
337+
params['input2'] = node.get_input_variable(node.inputs[1]).name
338+
333339
if node.get_attr('pytorch', False):
334-
self.template = simple_rnn_pytorch_function_template
340+
if params['pass_initial_states'] == 'true':
341+
template = simple_rnn_pytorch_function_initial_state_template
342+
else:
343+
template = simple_rnn_pytorch_function_template
335344
params['weights'] = 'w{0}, wr{0}, b{0}, br{0}'.format(str(node.index))
336345
else:
346+
template = simple_rnn_function_template
337347
params['weights'] = 'w{0}, wr{0}, b{0}'.format(str(node.index))
338-
return self.template.format(**params)
348+
return template.format(**params)

hls4ml/templates/oneapi/firmware/nnet_utils/nnet_recurrent.h

Lines changed: 69 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -404,8 +404,8 @@ void simple_rnn_pytorch_cell(const in_T &inputs, h_T &hidden_state, h_T &hidden_
404404

405405
// Hidden state
406406
[[intel::fpga_register]] accum_array_T hiddenCand;
407-
multiply_U<in_T, accum_array_T, typename CONFIG_T::recurrent_weight_t, CONFIG_T::n_out>(hidden_state, hiddenCand,
408-
rec_kernel);
407+
multiply_U<h_T, accum_array_T, typename CONFIG_T::recurrent_weight_t, CONFIG_T::n_out>(hidden_state, hiddenCand,
408+
rec_kernel);
409409

410410
// Hidden state bias addition
411411
[[intel::fpga_register]] accum_array_T hiddenBias;
@@ -437,7 +437,69 @@ void simple_rnn_pytorch(const data_T &data, res_T &res, const typename CONFIG_T:
437437
INIT_LOOP:
438438
#pragma unroll
439439
for (int x = 0; x < CONFIG_T::n_out; x++) {
440-
hidden_state[x][0] = 0;
440+
hidden_state[0][x] = 0;
441+
}
442+
443+
[[intel::disable_loop_pipelining]] for (int i = 0; i < CONFIG_T::n_timesteps; i++) {
444+
445+
// Data at current time step
446+
#pragma unroll
447+
for (int x = 0; x < CONFIG_T::n_in; x++) {
448+
in[x] = data[x + i * CONFIG_T::n_in];
449+
}
450+
451+
// Hidden state at current time step
452+
#pragma unroll
453+
for (int x = 0; x < CONFIG_T::n_out; x++) {
454+
hidden_state_temp[x] = hidden_state[i][x];
455+
}
456+
457+
// Do SimpleRNN
458+
simple_rnn_pytorch_cell<data_T, res_T, CONFIG_T>(in, hidden_state_temp, h, kernel, rec_kernel, bias, rec_bias);
459+
460+
// Write result
461+
#pragma unroll
462+
for (int x = 0; x < CONFIG_T::n_out; x++) {
463+
hidden_state[i + 1][x] = h[x];
464+
}
465+
}
466+
467+
if (CONFIG_T::return_sequences == 0) {
468+
// Output when return_sequences is false
469+
#pragma unroll
470+
for (int x = 0; x < CONFIG_T::n_out; x++) {
471+
res[x] = hidden_state[CONFIG_T::n_timesteps][x];
472+
}
473+
} else {
474+
// Output when return_sequences is true
475+
#pragma unroll
476+
for (int x = 0; x < CONFIG_T::n_timesteps; x++) {
477+
#pragma unroll
478+
for (int h = 0; h < CONFIG_T::n_out; h++) {
479+
res[x * CONFIG_T::n_out + h] = hidden_state[x + 1][h];
480+
}
481+
}
482+
}
483+
}
484+
485+
template <class data_T, class h_T, class res_T, typename CONFIG_T>
486+
void simple_rnn_pytorch_init_state(const data_T &data, const h_T &hin, res_T &res, const typename CONFIG_T::weight_t &kernel,
487+
const typename CONFIG_T::recurrent_weight_t &rec_kernel,
488+
const typename CONFIG_T::bias_t &bias,
489+
const typename CONFIG_T::recurrent_bias_t &rec_bias) {
490+
491+
using in_T = array<typename data_T::value_type, CONFIG_T::n_in>;
492+
493+
[[intel::fpga_register]] h_T hidden_state[CONFIG_T::n_timesteps + 1];
494+
[[intel::fpga_register]] h_T hidden_state_temp;
495+
[[intel::fpga_register]] h_T h;
496+
[[intel::fpga_register]] in_T in;
497+
498+
// Set initial hidden state (output) from the provided initial state input
499+
INIT_LOOP:
500+
#pragma unroll
501+
for (int x = 0; x < CONFIG_T::n_out; x++) {
502+
hidden_state[0][x] = hin[x];
441503
}
442504

443505
[[intel::disable_loop_pipelining]] for (int i = 0; i < CONFIG_T::n_timesteps; i++) {
@@ -451,7 +513,7 @@ void simple_rnn_pytorch(const data_T &data, res_T &res, const typename CONFIG_T:
451513
// Hidden state at current time step
452514
#pragma unroll
453515
for (int x = 0; x < CONFIG_T::n_out; x++) {
454-
hidden_state_temp[x] = hidden_state[x][i];
516+
hidden_state_temp[x] = hidden_state[i][x];
455517
}
456518

457519
// Do SimpleRNN
@@ -460,23 +522,23 @@ void simple_rnn_pytorch(const data_T &data, res_T &res, const typename CONFIG_T:
460522
// Write result
461523
#pragma unroll
462524
for (int x = 0; x < CONFIG_T::n_out; x++) {
463-
hidden_state[x][i + 1] = h[x];
525+
hidden_state[i + 1][x] = h[x];
464526
}
465527
}
466528

467529
if (CONFIG_T::return_sequences == 0) {
468530
// Output when return_sequences is false
469531
#pragma unroll
470532
for (int x = 0; x < CONFIG_T::n_out; x++) {
471-
res[x] = hidden_state[x][CONFIG_T::n_timesteps];
533+
res[x] = hidden_state[CONFIG_T::n_timesteps][x];
472534
}
473535
} else {
474536
// Output when return_sequences is true
475537
#pragma unroll
476538
for (int x = 0; x < CONFIG_T::n_timesteps; x++) {
477539
#pragma unroll
478540
for (int h = 0; h < CONFIG_T::n_out; h++) {
479-
res[x * CONFIG_T::n_out + h] = hidden_state[h][x + 1];
541+
res[x * CONFIG_T::n_out + h] = hidden_state[x + 1][h];
480542
}
481543
}
482544
}

hls4ml/templates/quartus/firmware/nnet_utils/nnet_recurrent.h

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,68 @@ void simple_rnn_pytorch(data_T data[CONFIG_T::n_timesteps * CONFIG_T::n_in],
490490
}
491491
}
492492

493+
template <class data_T, class data2_T, class res_T, typename CONFIG_T>
494+
void simple_rnn_pytorch(data_T data[CONFIG_T::n_timesteps * CONFIG_T::n_in], data2_T hin[CONFIG_T::n_out],
495+
res_T res[CONFIG_T::n_outputs * CONFIG_T::n_out],
496+
const typename CONFIG_T::weight_t kernel[CONFIG_T::n_in * CONFIG_T::n_out],
497+
const typename CONFIG_T::weight_t rec_kernel[CONFIG_T::n_out * CONFIG_T::n_out],
498+
const typename CONFIG_T::bias_t bias[CONFIG_T::n_out],
499+
const typename CONFIG_T::bias_t rec_bias[CONFIG_T::n_out]) {
500+
data2_T hidden_state[CONFIG_T::n_out][CONFIG_T::n_timesteps + 1] hls_register;
501+
data2_T hidden_state_temp[CONFIG_T::n_out] hls_register;
502+
data2_T h[CONFIG_T::n_out] hls_register;
503+
data_T in[CONFIG_T::n_in] hls_register;
504+
505+
// Set initial hidden state (output) from the provided initial state input
506+
INIT_LOOP:
507+
#pragma unroll
508+
for (int x = 0; x < CONFIG_T::n_out; x++) {
509+
hidden_state[x][0] = hin[x];
510+
}
511+
512+
#pragma disable_loop_pipelining
513+
for (int i = 0; i < CONFIG_T::n_timesteps; i++) {
514+
515+
// Data at current time step
516+
#pragma unroll
517+
for (int x = 0; x < CONFIG_T::n_in; x++) {
518+
in[x] = data[x + i * CONFIG_T::n_in];
519+
}
520+
521+
// Hidden state at current time step
522+
#pragma unroll
523+
for (int x = 0; x < CONFIG_T::n_out; x++) {
524+
hidden_state_temp[x] = hidden_state[x][i];
525+
}
526+
527+
// Do SimpleRNN
528+
simple_rnn_pytorch_cell<data_T, res_T, CONFIG_T>(in, hidden_state_temp, h, kernel, rec_kernel, bias, rec_bias);
529+
530+
// Write result
531+
#pragma unroll
532+
for (int x = 0; x < CONFIG_T::n_out; x++) {
533+
hidden_state[x][i + 1] = h[x];
534+
}
535+
}
536+
537+
if (CONFIG_T::return_sequences == 0) {
538+
// Output when return_sequences is false
539+
#pragma unroll
540+
for (int x = 0; x < CONFIG_T::n_out; x++) {
541+
res[x] = hidden_state[x][CONFIG_T::n_timesteps];
542+
}
543+
} else {
544+
// Output when return_sequences is true
545+
#pragma unroll
546+
for (int x = 0; x < CONFIG_T::n_timesteps; x++) {
547+
#pragma unroll
548+
for (int h = 0; h < CONFIG_T::n_out; h++) {
549+
res[x * CONFIG_T::n_out + h] = hidden_state[h][x + 1];
550+
}
551+
}
552+
}
553+
}
554+
493555
//----------------------
494556
// LSTM
495557
//----------------------

test/pytest/test_recurrent_pytorch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def forward(self, x, h0):
171171
return output
172172

173173

174-
@pytest.mark.parametrize('backend', ['Quartus'])
174+
@pytest.mark.parametrize('backend', ['Quartus', 'oneAPI'])
175175
@pytest.mark.parametrize('io_type', ['io_parallel'])
176176
def test_rnn(backend, io_type):
177177
if not (backend in ('Quartus', 'oneAPI') and io_type == "io_stream"):

0 commit comments

Comments
 (0)