Skip to content

Commit 77e78e8

Browse files
committed
snapshot that compiles but fails pytests
1 parent cf9c726 commit 77e78e8

File tree

5 files changed

+153
-9
lines changed

5 files changed

+153
-9
lines changed

hls4ml/backends/oneapi/passes/recurrent_templates.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,7 @@ def format(self, node):
333333
typedef {weight_t.name} weight_t;
334334
typedef {bias_t.name} bias_t;
335335
typedef {recurrent_weight_t.name} recurrent_weight_t;
336+
typedef {recurrent_bias_t.name} recurrent_bias_t;
336337
337338
typedef {act_t} ACT_CONFIG_T;
338339
template<class x_T, class y_T, class config_T>
@@ -350,6 +351,9 @@ def format(self, node):
350351
simple_rnn_pytorch_function_template = (
351352
'nnet::simple_rnn_pytorch<{input_t}, {output_t}, {config}>({input}, {output}, {weights});'
352353
)
354+
simple_rnn_pytorch_function_initial_state_template = (
355+
'nnet::simple_rnn_pytorch_init_state<{input_t}, {h_t}, {output_t}, {config}>({input}, {init_state}, {output}, {weights});'
356+
)
353357

354358

355359
class SimpleRNNConfigTemplate(LayerConfigTemplate):
@@ -395,10 +399,17 @@ def __init__(self):
395399

396400
def format(self, node):
    """Render the HLS function call for a SimpleRNN layer (oneAPI backend).

    Selects the PyTorch or Keras template variant and, when an initial
    hidden state is supplied as a second layer input, the initial-state
    template with its extra ``h_t``/``init_state`` parameters.

    :param node: the SimpleRNN layer node being code-generated.
    :return: the formatted ``nnet::simple_rnn*`` call string.
    """
    params = self._default_function_params(node)

    # 'pass_initial_states' is a string flag ('true'/'false'); use .get()
    # so a missing key falls back to the no-initial-state path instead of
    # raising KeyError.
    has_initial_state = params.get('pass_initial_states') == 'true'
    if has_initial_state:
        # The second input tensor carries the initial hidden state.
        params['h_t'] = node.get_input_variable(node.inputs[1]).type.name
        params['init_state'] = node.get_input_variable(node.inputs[1]).name

    if node.get_attr('pytorch', False):
        if has_initial_state:
            template = simple_rnn_pytorch_function_initial_state_template
        else:
            template = simple_rnn_pytorch_function_template
        # PyTorch RNNs carry a recurrent bias, hence the extra br weight.
        params['weights'] = 'w{0}, wr{0}, b{0}, br{0}'.format(str(node.index))
    else:
        template = simple_rnn_function_template
        params['weights'] = 'w{0}, wr{0}, b{0}'.format(str(node.index))
    # Use a local, not self.template: format() must not mutate shared
    # template-instance state between nodes.
    return template.format(**params)

hls4ml/backends/quartus/passes/recurrent_templates.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,9 @@ def format(self, node):
285285
simple_rnn_pytorch_function_template = (
286286
'nnet::simple_rnn_pytorch<{input_t}, {output_t}, {config}>({input}, {output}, {weights});'
287287
)
288+
simple_rnn_pytorch_function_initial_state_template = (
289+
'nnet::simple_rnn_pytorch<{input_t}, {input2_t}, {output_t}, {config}>({input}, {input2}, {output}, {weights});'
290+
)
288291

289292

290293
class SimpleRNNConfigTemplate(LayerConfigTemplate):
@@ -326,13 +329,20 @@ def format(self, node):
326329
class SimpleRNNFunctionTemplate(FunctionCallTemplate):
327330
def __init__(self):
328331
super().__init__(SimpleRNN, include_header=recurrent_include_list)
329-
self.template = simple_rnn_function_template
330332

331333
def format(self, node):
    """Render the HLS function call for a SimpleRNN layer (Quartus backend).

    Selects the PyTorch or Keras template variant and, when an initial
    hidden state is supplied as a second layer input, the initial-state
    template with its extra ``input2_t``/``input2`` parameters.

    :param node: the SimpleRNN layer node being code-generated.
    :return: the formatted ``nnet::simple_rnn*`` call string.
    """
    params = self._default_function_params(node)

    # 'pass_initial_states' is a string flag ('true'/'false'); use .get()
    # so a missing key falls back to the no-initial-state path instead of
    # raising KeyError.
    has_initial_state = params.get('pass_initial_states') == 'true'
    if has_initial_state:
        # The second input tensor carries the initial hidden state.
        params['input2_t'] = node.get_input_variable(node.inputs[1]).type.name
        params['input2'] = node.get_input_variable(node.inputs[1]).name

    if node.get_attr('pytorch', False):
        if has_initial_state:
            template = simple_rnn_pytorch_function_initial_state_template
        else:
            template = simple_rnn_pytorch_function_template
        # PyTorch RNNs carry a recurrent bias, hence the extra br weight.
        params['weights'] = 'w{0}, wr{0}, b{0}, br{0}'.format(str(node.index))
    else:
        template = simple_rnn_function_template
        params['weights'] = 'w{0}, wr{0}, b{0}'.format(str(node.index))
    # Use a local, not self.template: format() must not mutate shared
    # template-instance state between nodes.
    return template.format(**params)

hls4ml/templates/oneapi/firmware/nnet_utils/nnet_recurrent.h

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ void simple_rnn_pytorch_cell(const in_T &inputs, h_T &hidden_state, h_T &hidden_
404404

405405
// Hidden state
406406
[[intel::fpga_register]] accum_array_T hiddenCand;
407-
multiply_U<in_T, accum_array_T, typename CONFIG_T::recurrent_weight_t, CONFIG_T::n_out>(hidden_state, hiddenCand,
407+
multiply_U<h_T, accum_array_T, typename CONFIG_T::recurrent_weight_t, CONFIG_T::n_out>(hidden_state, hiddenCand,
408408
rec_kernel);
409409

410410
// Hidden state bias addition
@@ -460,7 +460,68 @@ void simple_rnn_pytorch(const data_T &data, res_T &res, const typename CONFIG_T:
460460
// Write result
461461
#pragma unroll
462462
for (int x = 0; x < CONFIG_T::n_out; x++) {
463-
hidden_state[x][i + 1] = h[x];
463+
hidden_state[i + 1][x] = h[x];
464+
}
465+
}
466+
467+
if (CONFIG_T::return_sequences == 0) {
468+
// Output when return_sequences is false
469+
#pragma unroll
470+
for (int x = 0; x < CONFIG_T::n_out; x++) {
471+
res[x] = hidden_state[x][CONFIG_T::n_timesteps];
472+
}
473+
} else {
474+
// Output when return_sequences is true
475+
#pragma unroll
476+
for (int x = 0; x < CONFIG_T::n_timesteps; x++) {
477+
#pragma unroll
478+
for (int h = 0; h < CONFIG_T::n_out; h++) {
479+
res[x * CONFIG_T::n_out + h] = hidden_state[h][x + 1];
480+
}
481+
}
482+
}
483+
}
484+
485+
template <class data_T, class h_T, class res_T, typename CONFIG_T>
486+
void simple_rnn_pytorch_init_state(const data_T &data, const h_T& hin, res_T &res, const typename CONFIG_T::weight_t &kernel,
487+
const typename CONFIG_T::recurrent_weight_t &rec_kernel, const typename CONFIG_T::bias_t &bias,
488+
const typename CONFIG_T::recurrent_bias_t &rec_bias) {
489+
490+
using in_T = array<typename data_T::value_type, CONFIG_T::n_in>;
491+
492+
[[intel::fpga_register]] h_T hidden_state[CONFIG_T::n_timesteps + 1];
493+
[[intel::fpga_register]] h_T hidden_state_temp;
494+
[[intel::fpga_register]] h_T h;
495+
[[intel::fpga_register]] in_T in;
496+
497+
// Set initially hidden state (output) to zero
498+
INIT_LOOP:
499+
#pragma unroll
500+
for (int x = 0; x < CONFIG_T::n_out; x++) {
501+
hidden_state[x][0] = hin[x];
502+
}
503+
504+
[[intel::disable_loop_pipelining]] for (int i = 0; i < CONFIG_T::n_timesteps; i++) {
505+
506+
// Data at current time step
507+
#pragma unroll
508+
for (int x = 0; x < CONFIG_T::n_in; x++) {
509+
in[x] = data[x + i * CONFIG_T::n_in];
510+
}
511+
512+
// Hidden state at current time step
513+
#pragma unroll
514+
for (int x = 0; x < CONFIG_T::n_out; x++) {
515+
hidden_state_temp[x] = hidden_state[x][i];
516+
}
517+
518+
// Do SimpleRNN
519+
simple_rnn_pytorch_cell<data_T, res_T, CONFIG_T>(in, hidden_state_temp, h, kernel, rec_kernel, bias, rec_bias);
520+
521+
// Write result
522+
#pragma unroll
523+
for (int x = 0; x < CONFIG_T::n_out; x++) {
524+
hidden_state[i + 1][x] = h[x];
464525
}
465526
}
466527

hls4ml/templates/quartus/firmware/nnet_utils/nnet_recurrent.h

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,68 @@ void simple_rnn_pytorch(data_T data[CONFIG_T::n_timesteps * CONFIG_T::n_in],
490490
}
491491
}
492492

493+
template <class data_T, class data2_T, class res_T, typename CONFIG_T>
494+
void simple_rnn_pytorch(data_T data[CONFIG_T::n_timesteps * CONFIG_T::n_in], data2_T hin[CONFIG_T::n_out],
495+
res_T res[CONFIG_T::n_outputs * CONFIG_T::n_out],
496+
const typename CONFIG_T::weight_t kernel[CONFIG_T::n_in * CONFIG_T::n_out],
497+
const typename CONFIG_T::weight_t rec_kernel[CONFIG_T::n_out * CONFIG_T::n_out],
498+
const typename CONFIG_T::bias_t bias[CONFIG_T::n_out],
499+
const typename CONFIG_T::bias_t rec_bias[CONFIG_T::n_out]) {
500+
data2_T hidden_state[CONFIG_T::n_out][CONFIG_T::n_timesteps + 1] hls_register;
501+
data2_T hidden_state_temp[CONFIG_T::n_out] hls_register;
502+
data2_T h[CONFIG_T::n_out] hls_register;
503+
data_T in[CONFIG_T::n_in] hls_register;
504+
505+
// Set initially hidden state (output) to zero
506+
INIT_LOOP:
507+
#pragma unroll
508+
for (int x = 0; x < CONFIG_T::n_out; x++) {
509+
hidden_state[x][0] = hin[x];
510+
}
511+
512+
#pragma disable_loop_pipelining
513+
for (int i = 0; i < CONFIG_T::n_timesteps; i++) {
514+
515+
// Data at current time step
516+
#pragma unroll
517+
for (int x = 0; x < CONFIG_T::n_in; x++) {
518+
in[x] = data[x + i * CONFIG_T::n_in];
519+
}
520+
521+
// Hidden state at current time step
522+
#pragma unroll
523+
for (int x = 0; x < CONFIG_T::n_out; x++) {
524+
hidden_state_temp[x] = hidden_state[x][i];
525+
}
526+
527+
// Do SimpleRNN
528+
simple_rnn_pytorch_cell<data_T, res_T, CONFIG_T>(in, hidden_state_temp, h, kernel, rec_kernel, bias, rec_bias);
529+
530+
// Write result
531+
#pragma unroll
532+
for (int x = 0; x < CONFIG_T::n_out; x++) {
533+
hidden_state[x][i + 1] = h[x];
534+
}
535+
}
536+
537+
if (CONFIG_T::return_sequences == 0) {
538+
// Output when return_sequences is false
539+
#pragma unroll
540+
for (int x = 0; x < CONFIG_T::n_out; x++) {
541+
res[x] = hidden_state[x][CONFIG_T::n_timesteps];
542+
}
543+
} else {
544+
// Output when return_sequences is true
545+
#pragma unroll
546+
for (int x = 0; x < CONFIG_T::n_timesteps; x++) {
547+
#pragma unroll
548+
for (int h = 0; h < CONFIG_T::n_out; h++) {
549+
res[x * CONFIG_T::n_out + h] = hidden_state[h][x + 1];
550+
}
551+
}
552+
}
553+
}
554+
493555
//----------------------
494556
// LSTM
495557
//----------------------

test/pytest/test_recurrent_pytorch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def forward(self, x, h0):
171171
return output
172172

173173

174-
@pytest.mark.parametrize('backend', ['Quartus'])
174+
@pytest.mark.parametrize('backend', ['Quartus', 'oneAPI'])
175175
@pytest.mark.parametrize('io_type', ['io_parallel'])
176176
def test_rnn(backend, io_type):
177177
if not (backend in ('Quartus', 'oneAPI') and io_type == "io_stream"):

0 commit comments

Comments
 (0)