
Commit 554c7ad

Merge pull request #6 from jmitrevs/initialRecurr_fix_pytorch_oneAPI
fix pytorch_order for GRU, recurrent bias for simpleNN, oneAPI
2 parents b063de6 + 767a5f8 commit 554c7ad

File tree: 2 files changed, +146 −4 lines changed


hls4ml/backends/oneapi/passes/recurrent_templates.py

Lines changed: 11 additions & 1 deletion
@@ -92,6 +92,7 @@
     using activation_recr = nnet::activation::{recurrent_activation}<x_T, y_T, config_T>;

     static const unsigned reuse_factor = {reuse};
+    static const unsigned pytorch_order = {pytorch};
     static const bool store_weights_in_bram = false;
 }};\n'''


@@ -123,6 +124,7 @@ def format(self, node):
         params['config_mult_h'] = f'config{node.index}_h_mult'
         params['act_t'] = '{}_config{}'.format(node.get_attr('activation'), str(node.index) + '_act')
         params['act_recurrent_t'] = '{}_config{}'.format(node.get_attr('recurrent_activation'), str(node.index) + '_rec_act')
+        params['pytorch'] = 'true' if node.get_attr('pytorch', False) else 'false'
         gru_config = self.gru_template.format(**params)

         # Activation is on candidate hidden state, dimensionality (1, n_units)
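Together with the {pytorch} placeholder added to the GRU config template above, this turns the gate-ordering flag into a compile-time constant of the generated config struct. Below is a minimal, self-contained sketch of how such a flag is consumed downstream in nnet_recurrent.h; the struct names base_config, config2 and the report helper are illustrative stand-ins, not part of this patch:

#include <cstdio>

// Hypothetical stand-in for nnet::gru_config with the new default flag
struct base_config {
    static constexpr bool pytorch_order = false;
};

// Hypothetical stand-in for a generated config: '{pytorch}' expanded to 'true'
// because node.get_attr('pytorch', False) was truthy for a PyTorch-converted layer
struct config2 : base_config {
    static constexpr bool pytorch_order = true;
};

template <typename CONFIG_T> void report() {
    // In the real kernel this constant selects between Keras and PyTorch gate indexing
    if (CONFIG_T::pytorch_order)
        std::printf("PyTorch gate ordering\n");
    else
        std::printf("Keras gate ordering\n");
}

int main() {
    report<base_config>(); // Keras gate ordering
    report<config2>();     // PyTorch gate ordering
    return 0;
}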
@@ -345,6 +347,9 @@ def format(self, node):
 }};\n"""

 simple_rnn_function_template = 'nnet::simple_rnn<{input_t}, {output_t}, {config}>({input}, {output}, {weights});'
+simple_rnn_pytorch_function_template = (
+    'nnet::simple_rnn_pytorch<{input_t}, {output_t}, {config}>({input}, {output}, {weights});'
+)


 class SimpleRNNConfigTemplate(LayerConfigTemplate):
@@ -390,5 +395,10 @@ def __init__(self):

     def format(self, node):
         params = self._default_function_params(node)
-        params['weights'] = 'w{0}, wr{0}, b{0}'.format(str(node.index))
+        if node.get_attr('pytorch', False):
+            self.template = simple_rnn_pytorch_function_template
+            params['weights'] = 'w{0}, wr{0}, b{0}, br{0}'.format(str(node.index))
+        else:
+            self.template = simple_rnn_function_template
+            params['weights'] = 'w{0}, wr{0}, b{0}'.format(str(node.index))
         return self.template.format(**params)
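The extra br{index} weight reflects how PyTorch parameterizes its RNN layers: unlike Keras' SimpleRNN, which carries a single bias vector, torch.nn.RNN keeps separate input and recurrent biases (bias_ih and bias_hh), and this patch passes both through instead of requiring them to be fused in the converter. A tiny numeric sketch of the per-unit pre-activation in the two conventions, with made-up values:

#include <cstdio>

int main() {
    // One unit, one input; all numbers are hypothetical
    float x = 0.5f, h_prev = 0.25f; // current input and previous hidden state
    float W = 1.0f, U = 2.0f;       // kernel and recurrent kernel
    float b_ih = 0.1f, b_hh = 0.2f; // separate PyTorch biases (b{index} and br{index} above)
    float b = b_ih + b_hh;          // fused single bias, Keras style

    // Pre-activation with separate biases, as accumulated in simple_rnn_pytorch_cell
    float pre_pytorch = W * x + b_ih + U * h_prev + b_hh;
    // Pre-activation with one fused bias, as in the single-bias simple_rnn path
    float pre_keras = W * x + U * h_prev + b;

    std::printf("pytorch: %f, keras: %f\n", pre_pytorch, pre_keras); // same value either way
    return 0;
}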

hls4ml/templates/oneapi/firmware/nnet_utils/nnet_recurrent.h

Lines changed: 135 additions & 3 deletions
@@ -89,6 +89,7 @@ struct gru_config {
     // Resource reuse info
     static const unsigned io_type = io_parallel;
     static const unsigned reuse_factor = 1;
+    static const bool pytorch_order = false;
     static const bool store_weights_in_bram = false;

     // Activation
@@ -137,7 +138,10 @@ void gru_cell(const data_T &x, h_T &h, const typename CONFIG_T::weight_t &weight
     [[intel::fpga_register]] h_activ_array_T hadamard_r_h;
 #pragma unroll recurrent_unroll_factor
     for (int i = 0; i < (CONFIG_T::n_units); i++) {
-        hadamard_r_h[i] = z_r_act[i + CONFIG_T::n_units] * mat_mul_h_wr[i + 2 * CONFIG_T::n_units];
+        if (CONFIG_T::pytorch_order)
+            hadamard_r_h[i] = z_r_act[i] * mat_mul_h_wr[i + 2 * CONFIG_T::n_units];
+        else
+            hadamard_r_h[i] = z_r_act[i + CONFIG_T::n_units] * mat_mul_h_wr[i + 2 * CONFIG_T::n_units];
     }

     // The candidate state; X * W_{hx} + hadmard(r(t), h_(t-1)) * W_{hh} + b_{h}
@@ -156,7 +160,11 @@ void gru_cell(const data_T &x, h_T &h, const typename CONFIG_T::weight_t &weight
     // Update state
 #pragma unroll recurrent_unroll_factor
     for (int i = 0; i < (CONFIG_T::n_units); i++) {
-        h[i] = static_cast<typename h_T::value_type>(h_cand_act[i] * (1 - z_r_act[i]) + h[i] * z_r_act[i]);
+        if (CONFIG_T::pytorch_order)
+            h[i] = static_cast<typename h_T::value_type>(h_cand_act[i] * (1 - z_r_act[i + CONFIG_T::n_units]) +
+                                                         h[i] * z_r_act[i + CONFIG_T::n_units]);
+        else
+            h[i] = static_cast<typename h_T::value_type>(h_cand_act[i] * (1 - z_r_act[i]) + h[i] * z_r_act[i]);
     }
 }

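The two gru_cell changes above encode the layout difference between the frameworks: Keras orders the fused recurrent gates update-then-reset ([z | r]) while PyTorch orders them reset-then-update ([r | z]), so with pytorch_order set the reset gate is read from offset 0 and the update gate from offset n_units. A self-contained sketch of that indexing, using standard-library types and illustrative names rather than the backend's:

#include <array>
#include <cstdio>

constexpr int n_units = 2;

// Pick the reset (r) and update (z) gates out of the fused activation buffer,
// mirroring the index choices made in gru_cell above.
template <bool pytorch_order>
void split_gates(const std::array<float, 2 * n_units> &z_r_act,
                 std::array<float, n_units> &z, std::array<float, n_units> &r) {
    for (int i = 0; i < n_units; i++) {
        if (pytorch_order) {
            r[i] = z_r_act[i];          // PyTorch layout: [r | z]
            z[i] = z_r_act[i + n_units];
        } else {
            z[i] = z_r_act[i];          // Keras layout: [z | r]
            r[i] = z_r_act[i + n_units];
        }
    }
}

int main() {
    // Hypothetical fused activations: first block {0.1, 0.2}, second block {0.8, 0.9}
    std::array<float, 2 * n_units> z_r_act = {0.1f, 0.2f, 0.8f, 0.9f};
    std::array<float, n_units> z{}, r{};

    split_gates<false>(z_r_act, z, r); // Keras:   z = {0.1, 0.2}, r = {0.8, 0.9}
    std::printf("keras:   z0=%.1f r0=%.1f\n", z[0], r[0]);

    split_gates<true>(z_r_act, z, r);  // PyTorch: r = {0.1, 0.2}, z = {0.8, 0.9}
    std::printf("pytorch: r0=%.1f z0=%.1f\n", r[0], z[0]);
    return 0;
}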

@@ -328,7 +336,7 @@ void simple_rnn(const data_T &data, res_T &res, const typename CONFIG_T::weight_
         // Write result
 #pragma unroll
         for (int x = 0; x < CONFIG_T::n_out; x++) {
-            hidden_state[i + 1][x] = h[x];
+            hidden_state[x][i + 1] = h[x];
         }
     }

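This one-line fix makes the write consistent with the [unit][timestep] indexing used elsewhere in the function, where the previous state is read as hidden_state[x][i] and the result as hidden_state[x][CONFIG_T::n_timesteps] (the new simple_rnn_pytorch kernel below follows the same pattern). A minimal stand-alone illustration of that layout, with hypothetical sizes and values:

#include <cstdio>

constexpr int n_out = 2;
constexpr int n_timesteps = 3;

int main() {
    // hidden_state[unit][timestep]: column 0 holds the zero-initialised state
    float hidden_state[n_out][n_timesteps + 1] = {};
    float h[n_out] = {0.7f, -0.3f}; // hypothetical cell output after timestep i
    int i = 0;

    for (int x = 0; x < n_out; x++) {
        hidden_state[x][i + 1] = h[x]; // fixed write: unit index first, timestep second
    }

    for (int x = 0; x < n_out; x++) {
        std::printf("unit %d, step %d: %.1f\n", x, i + 1, hidden_state[x][i + 1]);
    }
    return 0;
}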

@@ -350,6 +358,130 @@ void simple_rnn(const data_T &data, res_T &res, const typename CONFIG_T::weight_
     }
 }

+//----------------------
+// SimpleRNN with pytorch biases
+//----------------------
+
+struct simpleRNN_pytorch_config {
+    // Internal data type definitions
+    typedef float weight_t;
+    typedef float bias_t;
+    typedef float accum_t;
+
+    // Layer Sizes
+    static const unsigned n_in = 1;
+    static const unsigned n_out = 1;
+    static const unsigned n_outputs = 1;
+    static const unsigned n_timesteps = 1;
+    static const bool return_sequences = false;
+
+    // Resource reuse info
+    static const unsigned io_type = io_parallel;
+    static const unsigned reuse_factor = 1;
+    static const bool store_weights_in_bram = false;
+
+    // Activation
+    template <class x_T, class y_T, class config_T> using activation_recr = nnet::activation::relu<x_T, y_T, config_T>;
+
+    template <class x_T, class y_T, class config_T> using activation = nnet::activation::relu<x_T, y_T, config_T>;
+};
+
+template <class in_T, class h_T, typename CONFIG_T>
+void simple_rnn_pytorch_cell(const in_T &inputs, h_T &hidden_state, h_T &hidden_state_o,
+                             const typename CONFIG_T::weight_t &kernel,
+                             const typename CONFIG_T::recurrent_weight_t &rec_kernel, const typename CONFIG_T::bias_t &bias,
+                             const typename CONFIG_T::recurrent_bias_t rec_bias) {
+
+    using accum_array_T = array<typename CONFIG_T::accum_t, CONFIG_T::n_out>;
+
+    // Weight multiplication
+    [[intel::fpga_register]] accum_array_T afterW;
+    multiply_W<in_T, accum_array_T, typename CONFIG_T::weight_t, CONFIG_T::n_in, CONFIG_T::n_out>(inputs, afterW, kernel);
+
+    // Bias addition
+    [[intel::fpga_register]] accum_array_T afterBias;
+    add_bias<accum_array_T, accum_array_T, typename CONFIG_T::bias_t, CONFIG_T::n_out>(afterW, afterBias, bias);
+
+    // Hidden state
+    [[intel::fpga_register]] accum_array_T hiddenCand;
+    multiply_U<in_T, accum_array_T, typename CONFIG_T::recurrent_weight_t, CONFIG_T::n_out>(hidden_state, hiddenCand,
+                                                                                            rec_kernel);
+
+    // Hidden state bias addition
+    [[intel::fpga_register]] accum_array_T hiddenBias;
+    add_bias<accum_array_T, accum_array_T, typename CONFIG_T::recurrent_bias_t, CONFIG_T::n_out>(hiddenCand, hiddenBias,
+                                                                                                 rec_bias);
+
+    // Vector addition
+    [[intel::fpga_register]] accum_array_T afterAdd;
+    add_vectors<accum_array_T, accum_array_T, accum_array_T, CONFIG_T::n_out>(afterBias, hiddenBias, afterAdd);
+
+    // Activation
+    CONFIG_T::template activation<accum_array_T, h_T, typename CONFIG_T::ACT_CONFIG_T>::activation(afterAdd, hidden_state_o);
+}
+
+template <class data_T, class res_T, typename CONFIG_T>
+void simple_rnn_pytorch(const data_T &data, res_T &res, const typename CONFIG_T::weight_t &kernel,
+                        const typename CONFIG_T::recurrent_weight_t &rec_kernel, const typename CONFIG_T::bias_t &bias,
+                        const typename CONFIG_T::recurrent_bias_t &rec_bias) {
+
+    using in_T = array<typename data_T::value_type, CONFIG_T::n_in>;
+    using h_T = array<typename res_T::value_type, CONFIG_T::n_out>;
+
+    [[intel::fpga_register]] h_T hidden_state[CONFIG_T::n_timesteps + 1];
+    [[intel::fpga_register]] h_T hidden_state_temp;
+    [[intel::fpga_register]] h_T h;
+    [[intel::fpga_register]] in_T in;
+
+    // Set initially hidden state (output) to zero
+INIT_LOOP:
+#pragma unroll
+    for (int x = 0; x < CONFIG_T::n_out; x++) {
+        hidden_state[x][0] = 0;
+    }
+
+    [[intel::disable_loop_pipelining]] for (int i = 0; i < CONFIG_T::n_timesteps; i++) {
+
+        // Data at current time step
+#pragma unroll
+        for (int x = 0; x < CONFIG_T::n_in; x++) {
+            in[x] = data[x + i * CONFIG_T::n_in];
+        }
+
+        // Hidden state at current time step
+#pragma unroll
+        for (int x = 0; x < CONFIG_T::n_out; x++) {
+            hidden_state_temp[x] = hidden_state[x][i];
+        }
+
+        // Do SimpleRNN
+        simple_rnn_pytorch_cell<data_T, res_T, CONFIG_T>(in, hidden_state_temp, h, kernel, rec_kernel, bias, rec_bias);
+
+        // Write result
+#pragma unroll
+        for (int x = 0; x < CONFIG_T::n_out; x++) {
+            hidden_state[x][i + 1] = h[x];
+        }
+    }
+
+    if (CONFIG_T::return_sequences == 0) {
+        // Output when return_sequences is false
+#pragma unroll
+        for (int x = 0; x < CONFIG_T::n_out; x++) {
+            res[x] = hidden_state[x][CONFIG_T::n_timesteps];
+        }
+    } else {
+        // Output when return_sequences is true
+#pragma unroll
+        for (int x = 0; x < CONFIG_T::n_timesteps; x++) {
+#pragma unroll
+            for (int h = 0; h < CONFIG_T::n_out; h++) {
+                res[x * CONFIG_T::n_out + h] = hidden_state[h][x + 1];
+            }
+        }
+    }
+}
+
 //----------------------
 // LSTM
 //----------------------
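Tying the two files together: depending on the layer's pytorch attribute, the writer now emits either the original simple_rnn call or the new simple_rnn_pytorch call with the extra recurrent bias. A hypothetical expansion of the two function templates for a layer with index 4 follows; the type aliases and variable names are placeholders generated elsewhere by the backend, not part of this patch:

// Keras-trained model (pytorch attribute false): single bias
nnet::simple_rnn<input_t, result_t, config4>(layer3_out, layer4_out, w4, wr4, b4);

// PyTorch-trained model (pytorch attribute true): kernel bias plus recurrent bias
nnet::simple_rnn_pytorch<input_t, result_t, config4>(layer3_out, layer4_out, w4, wr4, b4, br4);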
