Skip to content

Commit 90ada94

Browse files
Authored by bo3z, with pre-commit-ci[bot] and vloncar
Fixes for data type inconsistencies in quantised RNNs (fastmachinelearning#1171)
* Fixes for quantised RNNs

* [pre-commit.ci] auto fixes from pre-commit hooks

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Vladimir <[email protected]>
1 parent 0261075 commit 90ada94

File tree

3 files changed

+59
-34
lines changed

3 files changed

+59
-34
lines changed

hls4ml/backends/vivado/passes/recurrent_templates.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
# recurrent multiplication template
66

7-
recr_mult_config_template = """struct config{index} : nnet::dense_config {{
7+
recr_mult_config_template_1 = """struct config{index} : nnet::dense_config {{
88
static const unsigned n_in = {n_in};
99
static const unsigned n_out = {n_out};
1010
static const unsigned strategy = nnet::{strategy};
@@ -22,6 +22,24 @@
2222
using product = nnet::product::{product_type}<x_T, y_T>;
2323
}};\n"""
2424

25+
recr_mult_config_template_2 = """struct config{index} : nnet::dense_config {{
26+
static const unsigned n_in = {n_in};
27+
static const unsigned n_out = {n_out};
28+
static const unsigned strategy = nnet::{strategy};
29+
static const unsigned reuse_factor = {reuse};
30+
static const unsigned n_zeros = {nzeros};
31+
static const unsigned n_nonzeros = {nonzeros};
32+
static const unsigned multiplier_limit = DIV_ROUNDUP(n_in * n_out, reuse_factor) - n_zeros / reuse_factor;
33+
static const bool store_weights_in_bram = false;
34+
typedef {accum_t.name} accum_t;
35+
typedef {recurrent_bias_t.name} bias_t;
36+
typedef {recurrent_weight_t.name} weight_t;
37+
template<class data_T, class res_T, class CONFIG_T>
38+
using kernel = nnet::{dense_function}<data_T, res_T, CONFIG_T>;
39+
template<class x_T, class y_T>
40+
using product = nnet::product::{product_type}<x_T, y_T>;
41+
}};\n"""
42+
2543
# activation templates
2644

2745
activ_config_template = """struct {type}_config{index} : nnet::activ_config {{
@@ -45,7 +63,9 @@
4563
recr_config_template = """struct config{index} : nnet::{recr_type}_config {{
4664
typedef {accum_t.name} accum_t;
4765
typedef {weight_t.name} weight_t; // Matrix
66+
typedef {recurrent_weight_t.name} recurrent_weight_t; // Matrix
4867
typedef {bias_t.name} bias_t; // Vector
68+
typedef {recurrent_bias_t.name} recurrent_bias_t; // Vector
4969
typedef {config_mult_t1} mult_config1;
5070
typedef {config_mult_t2} mult_config2;
5171
typedef {recr_act_t} ACT_CONFIG_{RECR_TYPE};
@@ -77,8 +97,8 @@ def __init__(self):
7797
self.template = recr_config_template
7898
self.act_template = activ_config_template
7999
self.recr_act_template = recr_activ_config_template
80-
self.mult1_template = recr_mult_config_template
81-
self.mult2_template = recr_mult_config_template
100+
self.mult1_template = recr_mult_config_template_1
101+
self.mult2_template = recr_mult_config_template_2
82102

83103
def format(self, node):
84104
params = self._default_config_params(node)

hls4ml/converters/keras/qkeras.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def parse_qrnn_layer(keras_layer, input_names, input_shapes, data_reader):
7878
layer, output_shape = parse_rnn_layer(keras_layer, input_names, input_shapes, data_reader)
7979

8080
layer['weight_quantizer'] = get_quantizer_from_config(keras_layer, 'kernel')
81-
layer['recurrent_quantizer'] = get_quantizer_from_config(keras_layer, 'recurrent')
81+
layer['recurrent_weight_quantizer'] = get_quantizer_from_config(keras_layer, 'recurrent')
8282
layer['bias_quantizer'] = get_quantizer_from_config(keras_layer, 'bias')
8383

8484
return layer, output_shape

hls4ml/templates/vivado/nnet_utils/nnet_recurrent.h

Lines changed: 35 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@ namespace nnet {
1212
struct lstm_config {
1313
// Internal data type definitions
1414
typedef float weight_t;
15+
typedef float recurrent_weight_t;
1516
typedef float bias_t;
17+
typedef float recurrent_bias_t;
18+
typedef float accum_t;
1619

1720
// Layer Sizes
1821
static const unsigned n_in = 2;
@@ -47,9 +50,9 @@ struct lstm_config {
4750
template <class data_T, class res_T, typename CONFIG_T>
4851
void lstm(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG_T::n_state],
4952
res_T s_newstate[CONFIG_T::n_state], typename CONFIG_T::weight_t param[CONFIG_T::n_state * 4 * CONFIG_T::n_in],
50-
typename CONFIG_T::weight_t param_r[CONFIG_T::n_state * 4 * CONFIG_T::n_state],
53+
typename CONFIG_T::recurrent_weight_t param_r[CONFIG_T::n_state * 4 * CONFIG_T::n_state],
5154
typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 4],
52-
typename CONFIG_T::bias_t param_br[CONFIG_T::n_state * 4]) {
55+
typename CONFIG_T::recurrent_bias_t param_br[CONFIG_T::n_state * 4]) {
5356
// Initialize the state variable -- will maintain state between function calls
5457

5558
typename CONFIG_T::accum_t tmpres[CONFIG_T::n_state * 4];
@@ -86,20 +89,20 @@ void lstm(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG
8689
inputacc_c[iacc] = tmpres[index] + tmpres_state[index];
8790
}
8891

89-
CONFIG_T::template activation_recr<data_T, typename CONFIG_T::weight_t, typename CONFIG_T::ACT_CONFIG_LSTM>::activation(
90-
inputacc_ifo, tmpres_ifo);
92+
CONFIG_T::template activation_recr<typename CONFIG_T::accum_t, typename CONFIG_T::accum_t,
93+
typename CONFIG_T::ACT_CONFIG_LSTM>::activation(inputacc_ifo, tmpres_ifo);
9194

9295
// Now for the confusion matrix
93-
CONFIG_T::template activation<data_T, typename CONFIG_T::weight_t, typename CONFIG_T::ACT_CONFIG_T>::activation(
94-
inputacc_c, tmpres_c);
96+
CONFIG_T::template activation<typename CONFIG_T::accum_t, typename CONFIG_T::accum_t,
97+
typename CONFIG_T::ACT_CONFIG_T>::activation(inputacc_c, tmpres_c);
9598

9699
// Operation: s=g*i+sold*f (update state with buffer to avoid timing issues)
97100
for (int iacc = 0; iacc < (CONFIG_T::n_state); iacc++) {
98101
#pragma HLS UNROLL
99102
s_newstate[iacc] = tmpres_c[iacc] * tmpres_ifo[iacc] + s_newstate[iacc] * tmpres_ifo[iacc + (CONFIG_T::n_state)];
100103
}
101104
// Operation: h=act(s)*o
102-
CONFIG_T::template activation<data_T, typename CONFIG_T::weight_t, typename CONFIG_T::ACT_CONFIG_T>::activation(
105+
CONFIG_T::template activation<res_T, typename CONFIG_T::accum_t, typename CONFIG_T::ACT_CONFIG_T>::activation(
103106
s_newstate, s_actstate);
104107

105108
for (int iacc = 0; iacc < CONFIG_T::n_state; iacc++) {
@@ -112,9 +115,9 @@ template <class data_T, class res_T, typename CONFIG_T>
112115
void lstm_static(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG_T::n_state],
113116
res_T s_newstate[CONFIG_T::n_state],
114117
typename CONFIG_T::weight_t param[CONFIG_T::n_state * 4 * CONFIG_T::n_in],
115-
typename CONFIG_T::weight_t param_r[CONFIG_T::n_state * 4 * CONFIG_T::n_state],
118+
typename CONFIG_T::recurrent_weight_t param_r[CONFIG_T::n_state * 4 * CONFIG_T::n_state],
116119
typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 4],
117-
typename CONFIG_T::bias_t param_br[CONFIG_T::n_state * 4]) {
120+
typename CONFIG_T::recurrent_bias_t param_br[CONFIG_T::n_state * 4]) {
118121
static res_T h_state[CONFIG_T::n_state];
119122
static res_T s_state[CONFIG_T::n_state];
120123
// Initialize the state variable -- will maintain state between function calls
@@ -163,12 +166,12 @@ void lstm_static(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate
163166
inputacc_c[iacc] = tmpres[index] + tmpres_state[index];
164167
}
165168

166-
CONFIG_T::template activation_recr<data_T, typename CONFIG_T::weight_t, typename CONFIG_T::ACT_CONFIG_LSTM>::activation(
167-
inputacc_ifo, tmpres_ifo);
169+
CONFIG_T::template activation_recr<typename CONFIG_T::accum_t, typename CONFIG_T::accum_t,
170+
typename CONFIG_T::ACT_CONFIG_LSTM>::activation(inputacc_ifo, tmpres_ifo);
168171

169172
// Now for the confusion matrix
170-
CONFIG_T::template activation<data_T, typename CONFIG_T::weight_t, typename CONFIG_T::ACT_CONFIG_T>::activation(
171-
inputacc_c, tmpres_c);
173+
CONFIG_T::template activation<typename CONFIG_T::accum_t, typename CONFIG_T::accum_t,
174+
typename CONFIG_T::ACT_CONFIG_T>::activation(inputacc_c, tmpres_c);
172175

173176
// Operation: s=g*i+sold*f (update state with buffer to avoid timing issues)
174177
for (int iacc = 0; iacc < (CONFIG_T::n_state); iacc++) {
@@ -177,7 +180,7 @@ void lstm_static(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate
177180
s_newstate[iacc] = s_state[iacc];
178181
}
179182
// Operation: h=act(s)*o
180-
CONFIG_T::template activation<data_T, typename CONFIG_T::weight_t, typename CONFIG_T::ACT_CONFIG_T>::activation(
183+
CONFIG_T::template activation<res_T, typename CONFIG_T::accum_t, typename CONFIG_T::ACT_CONFIG_T>::activation(
181184
s_state, s_actstate);
182185

183186
for (int iacc = 0; iacc < CONFIG_T::n_state; iacc++) {
@@ -190,9 +193,9 @@ void lstm_static(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate
190193
template <class data_T, class res_T, typename CONFIG_T>
191194
void lstm_stack(data_T data[CONFIG_T::n_sequence * CONFIG_T::n_in], res_T res[CONFIG_T::n_sequence_out * CONFIG_T::n_state],
192195
typename CONFIG_T::weight_t param[CONFIG_T::n_state * 4 * CONFIG_T::n_in],
193-
typename CONFIG_T::weight_t param_r[CONFIG_T::n_state * 4 * CONFIG_T::n_state],
196+
typename CONFIG_T::recurrent_weight_t param_r[CONFIG_T::n_state * 4 * CONFIG_T::n_state],
194197
typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 4],
195-
typename CONFIG_T::bias_t param_br[CONFIG_T::n_state * 4]) {
198+
typename CONFIG_T::recurrent_bias_t param_br[CONFIG_T::n_state * 4]) {
196199

197200
res_T h_newstate[CONFIG_T::n_state];
198201
res_T s_newstate[CONFIG_T::n_state];
@@ -235,9 +238,9 @@ void lstm_stack(data_T data[CONFIG_T::n_sequence * CONFIG_T::n_in], res_T res[CO
235238
template <class data_T, class res_T, typename CONFIG_T>
236239
void lstm_stack(hls::stream<data_T> &data_stream, hls::stream<res_T> &res_stream,
237240
typename CONFIG_T::weight_t param[CONFIG_T::n_state * 4 * CONFIG_T::n_in],
238-
typename CONFIG_T::weight_t param_r[CONFIG_T::n_state * 4 * CONFIG_T::n_state],
241+
typename CONFIG_T::recurrent_weight_t param_r[CONFIG_T::n_state * 4 * CONFIG_T::n_state],
239242
typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 4],
240-
typename CONFIG_T::bias_t param_br[CONFIG_T::n_state * 4]) {
243+
typename CONFIG_T::recurrent_bias_t param_br[CONFIG_T::n_state * 4]) {
241244

242245
typename res_T::value_type h_newstate[CONFIG_T::n_state];
243246
typename res_T::value_type s_newstate[CONFIG_T::n_state];
@@ -300,7 +303,9 @@ void lstm_stack(hls::stream<data_T> &data_stream, hls::stream<res_T> &res_stream
300303
struct gru_config {
301304
// Internal data type definitions
302305
typedef float weight_t;
306+
typedef float recurrent_weight_t;
303307
typedef float bias_t;
308+
typedef float recurrent_bias_t;
304309
typedef float accum_t;
305310

306311
// Layer Sizes
@@ -327,9 +332,9 @@ template <class data_T, class res_T, typename CONFIG_T>
327332
void gru(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG_T::n_state],
328333
typename CONFIG_T::weight_t param[CONFIG_T::n_state * 3 * CONFIG_T::n_in], // TODO - Check the layout of the param
329334
// weights - refer page in copy!!
330-
typename CONFIG_T::weight_t param_zr[CONFIG_T::n_state * 3 * CONFIG_T::n_state],
335+
typename CONFIG_T::recurrent_weight_t param_zr[CONFIG_T::n_state * 3 * CONFIG_T::n_state],
331336
typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 3],
332-
typename CONFIG_T::bias_t param_br[CONFIG_T::n_state * 3]) {
337+
typename CONFIG_T::recurrent_bias_t param_br[CONFIG_T::n_state * 3]) {
333338
// Initialize the state variable -- will maintain state between function calls
334339
typename CONFIG_T::accum_t tmpres[CONFIG_T::n_state * 3];
335340
typename CONFIG_T::accum_t tmpres_state_zr[CONFIG_T::n_state * 3];
@@ -361,7 +366,7 @@ void gru(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG_
361366
}
362367

363368
// Activation function Sub layer -- START
364-
CONFIG_T::template activation_recr<typename CONFIG_T::accum_t, typename CONFIG_T::weight_t,
369+
CONFIG_T::template activation_recr<typename CONFIG_T::accum_t, typename CONFIG_T::accum_t,
365370
typename CONFIG_T::ACT_CONFIG_GRU>::activation(inputacc_zr, tmpres_zr);
366371

367372
// Activation function Sub layer -- END
@@ -383,7 +388,7 @@ void gru(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG_
383388
}
384389

385390
// Now run the activation on this guy
386-
CONFIG_T::template activation<typename CONFIG_T::accum_t, typename CONFIG_T::weight_t,
391+
CONFIG_T::template activation<typename CONFIG_T::accum_t, typename CONFIG_T::accum_t,
387392
typename CONFIG_T::ACT_CONFIG_T>::activation(inputacc_h, tmpres_h);
388393

389394
// Mix the stat with the previous state
@@ -400,9 +405,9 @@ void gru(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG_
400405
template <class data_T, class res_T, typename CONFIG_T>
401406
void gru_static(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG_T::n_state],
402407
typename CONFIG_T::weight_t param[CONFIG_T::n_state * 3 * CONFIG_T::n_in],
403-
typename CONFIG_T::weight_t param_zr[CONFIG_T::n_state * 3 * CONFIG_T::n_state],
408+
typename CONFIG_T::recurrent_weight_t param_zr[CONFIG_T::n_state * 3 * CONFIG_T::n_state],
404409
typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 3],
405-
typename CONFIG_T::bias_t param_br[CONFIG_T::n_state * 3]) {
410+
typename CONFIG_T::recurrent_bias_t param_br[CONFIG_T::n_state * 3]) {
406411
// Initialize the state variable -- will maintain state between function calls
407412

408413
static res_T h_state[CONFIG_T::n_state];
@@ -444,7 +449,7 @@ void gru_static(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[
444449
}
445450

446451
// Activation function Sub layer -- START
447-
CONFIG_T::template activation_recr<typename CONFIG_T::accum_t, typename CONFIG_T::weight_t,
452+
CONFIG_T::template activation_recr<typename CONFIG_T::accum_t, typename CONFIG_T::accum_t,
448453
typename CONFIG_T::ACT_CONFIG_GRU>::activation(inputacc_zr, tmpres_zr);
449454

450455
// Activation function Sub layer -- END
@@ -466,7 +471,7 @@ void gru_static(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[
466471
}
467472

468473
// Now run the activation on this guy
469-
CONFIG_T::template activation<typename CONFIG_T::accum_t, typename CONFIG_T::weight_t,
474+
CONFIG_T::template activation<typename CONFIG_T::accum_t, typename CONFIG_T::accum_t,
470475
typename CONFIG_T::ACT_CONFIG_T>::activation(inputacc_h, tmpres_h);
471476

472477
// Mix the stat with the previous state
@@ -484,9 +489,9 @@ void gru_static(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[
484489
template <class data_T, class res_T, typename CONFIG_T>
485490
void gru_stack(data_T data[CONFIG_T::n_sequence * CONFIG_T::n_in], res_T res[CONFIG_T::n_sequence_out * CONFIG_T::n_state],
486491
typename CONFIG_T::weight_t param[CONFIG_T::n_state * 3 * CONFIG_T::n_in],
487-
typename CONFIG_T::weight_t param_zr[CONFIG_T::n_state * 3 * CONFIG_T::n_state],
492+
typename CONFIG_T::recurrent_weight_t param_zr[CONFIG_T::n_state * 3 * CONFIG_T::n_state],
488493
typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 3],
489-
typename CONFIG_T::bias_t param_br[CONFIG_T::n_state * 3]) {
494+
typename CONFIG_T::recurrent_bias_t param_br[CONFIG_T::n_state * 3]) {
490495

491496
res_T h_state[CONFIG_T::n_state];
492497
data_T data_in[CONFIG_T::n_in];
@@ -525,9 +530,9 @@ void gru_stack(data_T data[CONFIG_T::n_sequence * CONFIG_T::n_in], res_T res[CON
525530
template <class data_T, class res_T, typename CONFIG_T>
526531
void gru_stack(hls::stream<data_T> &data_stream, hls::stream<res_T> &res_stream,
527532
typename CONFIG_T::weight_t param[CONFIG_T::n_state * 3 * CONFIG_T::n_in],
528-
typename CONFIG_T::weight_t param_zr[CONFIG_T::n_state * 3 * CONFIG_T::n_state],
533+
typename CONFIG_T::recurrent_weight_t param_zr[CONFIG_T::n_state * 3 * CONFIG_T::n_state],
529534
typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 3],
530-
typename CONFIG_T::bias_t param_br[CONFIG_T::n_state * 3]) {
535+
typename CONFIG_T::recurrent_bias_t param_br[CONFIG_T::n_state * 3]) {
531536

532537
typename res_T::value_type h_newstate[CONFIG_T::n_state];
533538
#pragma HLS ARRAY_PARTITION variable=h_newstate complete

0 commit comments

Comments (0)