@@ -12,7 +12,10 @@ namespace nnet {
1212struct lstm_config {
1313 // Internal data type definitions
1414 typedef float weight_t ;
15+ typedef float recurrent_weight_t ;
1516 typedef float bias_t ;
17+ typedef float recurrent_bias_t ;
18+ typedef float accum_t ;
1619
1720 // Layer Sizes
1821 static const unsigned n_in = 2 ;
@@ -47,9 +50,9 @@ struct lstm_config {
4750template <class data_T , class res_T , typename CONFIG_T>
4851void lstm (bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG_T::n_state],
4952 res_T s_newstate[CONFIG_T::n_state], typename CONFIG_T::weight_t param[CONFIG_T::n_state * 4 * CONFIG_T::n_in],
50- typename CONFIG_T::weight_t param_r[CONFIG_T::n_state * 4 * CONFIG_T::n_state],
53+ typename CONFIG_T::recurrent_weight_t param_r[CONFIG_T::n_state * 4 * CONFIG_T::n_state],
5154 typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 4 ],
52- typename CONFIG_T::bias_t param_br[CONFIG_T::n_state * 4 ]) {
55+ typename CONFIG_T::recurrent_bias_t param_br[CONFIG_T::n_state * 4 ]) {
5356 // Initialize the state variable -- will maintain state between function calls
5457
5558 typename CONFIG_T::accum_t tmpres[CONFIG_T::n_state * 4 ];
@@ -86,20 +89,20 @@ void lstm(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG
8689 inputacc_c[iacc] = tmpres[index] + tmpres_state[index];
8790 }
8891
89- CONFIG_T::template activation_recr<data_T, typename CONFIG_T::weight_t , typename CONFIG_T::ACT_CONFIG_LSTM>:: activation (
90- inputacc_ifo, tmpres_ifo);
92+ CONFIG_T::template activation_recr<typename CONFIG_T::accum_t , typename CONFIG_T::accum_t ,
93+ typename CONFIG_T::ACT_CONFIG_LSTM>:: activation ( inputacc_ifo, tmpres_ifo);
9194
9295 // Now for the confusion matrix
93- CONFIG_T::template activation<data_T, typename CONFIG_T::weight_t , typename CONFIG_T::ACT_CONFIG_T>:: activation (
94- inputacc_c, tmpres_c);
96+ CONFIG_T::template activation<typename CONFIG_T::accum_t , typename CONFIG_T::accum_t ,
97+ typename CONFIG_T::ACT_CONFIG_T>:: activation ( inputacc_c, tmpres_c);
9598
9699 // Operation: s=g*i+sold*f (update state with buffer to avoid timing issues)
97100 for (int iacc = 0 ; iacc < (CONFIG_T::n_state); iacc++) {
98101 #pragma HLS UNROLL
99102 s_newstate[iacc] = tmpres_c[iacc] * tmpres_ifo[iacc] + s_newstate[iacc] * tmpres_ifo[iacc + (CONFIG_T::n_state)];
100103 }
101104 // Operation: h=act(s)*o
102- CONFIG_T::template activation<data_T , typename CONFIG_T::weight_t , typename CONFIG_T::ACT_CONFIG_T>::activation (
105+ CONFIG_T::template activation<res_T , typename CONFIG_T::accum_t , typename CONFIG_T::ACT_CONFIG_T>::activation (
103106 s_newstate, s_actstate);
104107
105108 for (int iacc = 0 ; iacc < CONFIG_T::n_state; iacc++) {
@@ -112,9 +115,9 @@ template <class data_T, class res_T, typename CONFIG_T>
112115void lstm_static (bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG_T::n_state],
113116 res_T s_newstate[CONFIG_T::n_state],
114117 typename CONFIG_T::weight_t param[CONFIG_T::n_state * 4 * CONFIG_T::n_in],
115- typename CONFIG_T::weight_t param_r[CONFIG_T::n_state * 4 * CONFIG_T::n_state],
118+ typename CONFIG_T::recurrent_weight_t param_r[CONFIG_T::n_state * 4 * CONFIG_T::n_state],
116119 typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 4 ],
117- typename CONFIG_T::bias_t param_br[CONFIG_T::n_state * 4 ]) {
120+ typename CONFIG_T::recurrent_bias_t param_br[CONFIG_T::n_state * 4 ]) {
118121 static res_T h_state[CONFIG_T::n_state];
119122 static res_T s_state[CONFIG_T::n_state];
120123 // Initialize the state variable -- will maintain state between function calls
@@ -163,12 +166,12 @@ void lstm_static(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate
163166 inputacc_c[iacc] = tmpres[index] + tmpres_state[index];
164167 }
165168
166- CONFIG_T::template activation_recr<data_T, typename CONFIG_T::weight_t , typename CONFIG_T::ACT_CONFIG_LSTM>:: activation (
167- inputacc_ifo, tmpres_ifo);
169+ CONFIG_T::template activation_recr<typename CONFIG_T::accum_t , typename CONFIG_T::accum_t ,
170+ typename CONFIG_T::ACT_CONFIG_LSTM>:: activation ( inputacc_ifo, tmpres_ifo);
168171
169172 // Now for the confusion matrix
170- CONFIG_T::template activation<data_T, typename CONFIG_T::weight_t , typename CONFIG_T::ACT_CONFIG_T>:: activation (
171- inputacc_c, tmpres_c);
173+ CONFIG_T::template activation<typename CONFIG_T::accum_t , typename CONFIG_T::accum_t ,
174+ typename CONFIG_T::ACT_CONFIG_T>:: activation ( inputacc_c, tmpres_c);
172175
173176 // Operation: s=g*i+sold*f (update state with buffer to avoid timing issues)
174177 for (int iacc = 0 ; iacc < (CONFIG_T::n_state); iacc++) {
@@ -177,7 +180,7 @@ void lstm_static(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate
177180 s_newstate[iacc] = s_state[iacc];
178181 }
179182 // Operation: h=act(s)*o
180- CONFIG_T::template activation<data_T , typename CONFIG_T::weight_t , typename CONFIG_T::ACT_CONFIG_T>::activation (
183+ CONFIG_T::template activation<res_T , typename CONFIG_T::accum_t , typename CONFIG_T::ACT_CONFIG_T>::activation (
181184 s_state, s_actstate);
182185
183186 for (int iacc = 0 ; iacc < CONFIG_T::n_state; iacc++) {
@@ -190,9 +193,9 @@ void lstm_static(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate
190193template <class data_T , class res_T , typename CONFIG_T>
191194void lstm_stack (data_T data[CONFIG_T::n_sequence * CONFIG_T::n_in], res_T res[CONFIG_T::n_sequence_out * CONFIG_T::n_state],
192195 typename CONFIG_T::weight_t param[CONFIG_T::n_state * 4 * CONFIG_T::n_in],
193- typename CONFIG_T::weight_t param_r[CONFIG_T::n_state * 4 * CONFIG_T::n_state],
196+ typename CONFIG_T::recurrent_weight_t param_r[CONFIG_T::n_state * 4 * CONFIG_T::n_state],
194197 typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 4 ],
195- typename CONFIG_T::bias_t param_br[CONFIG_T::n_state * 4 ]) {
198+ typename CONFIG_T::recurrent_bias_t param_br[CONFIG_T::n_state * 4 ]) {
196199
197200 res_T h_newstate[CONFIG_T::n_state];
198201 res_T s_newstate[CONFIG_T::n_state];
@@ -235,9 +238,9 @@ void lstm_stack(data_T data[CONFIG_T::n_sequence * CONFIG_T::n_in], res_T res[CO
235238template <class data_T , class res_T , typename CONFIG_T>
236239void lstm_stack (hls::stream<data_T> &data_stream, hls::stream<res_T> &res_stream,
237240 typename CONFIG_T::weight_t param[CONFIG_T::n_state * 4 * CONFIG_T::n_in],
238- typename CONFIG_T::weight_t param_r[CONFIG_T::n_state * 4 * CONFIG_T::n_state],
241+ typename CONFIG_T::recurrent_weight_t param_r[CONFIG_T::n_state * 4 * CONFIG_T::n_state],
239242 typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 4 ],
240- typename CONFIG_T::bias_t param_br[CONFIG_T::n_state * 4 ]) {
243+ typename CONFIG_T::recurrent_bias_t param_br[CONFIG_T::n_state * 4 ]) {
241244
242245 typename res_T::value_type h_newstate[CONFIG_T::n_state];
243246 typename res_T::value_type s_newstate[CONFIG_T::n_state];
@@ -300,7 +303,9 @@ void lstm_stack(hls::stream<data_T> &data_stream, hls::stream<res_T> &res_stream
300303struct gru_config {
301304 // Internal data type definitions
302305 typedef float weight_t ;
306+ typedef float recurrent_weight_t ;
303307 typedef float bias_t ;
308+ typedef float recurrent_bias_t ;
304309 typedef float accum_t ;
305310
306311 // Layer Sizes
@@ -327,9 +332,9 @@ template <class data_T, class res_T, typename CONFIG_T>
327332void gru (bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG_T::n_state],
328333 typename CONFIG_T::weight_t param[CONFIG_T::n_state * 3 * CONFIG_T::n_in], // TODO - Check the layout of the param
329334 // weights - refer page in copy!!
330- typename CONFIG_T::weight_t param_zr[CONFIG_T::n_state * 3 * CONFIG_T::n_state],
335+ typename CONFIG_T::recurrent_weight_t param_zr[CONFIG_T::n_state * 3 * CONFIG_T::n_state],
331336 typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 3 ],
332- typename CONFIG_T::bias_t param_br[CONFIG_T::n_state * 3 ]) {
337+ typename CONFIG_T::recurrent_bias_t param_br[CONFIG_T::n_state * 3 ]) {
333338 // Initialize the state variable -- will maintain state between function calls
334339 typename CONFIG_T::accum_t tmpres[CONFIG_T::n_state * 3 ];
335340 typename CONFIG_T::accum_t tmpres_state_zr[CONFIG_T::n_state * 3 ];
@@ -361,7 +366,7 @@ void gru(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG_
361366 }
362367
363368 // Activation function Sub layer -- START
364- CONFIG_T::template activation_recr<typename CONFIG_T::accum_t , typename CONFIG_T::weight_t ,
369+ CONFIG_T::template activation_recr<typename CONFIG_T::accum_t , typename CONFIG_T::accum_t ,
365370 typename CONFIG_T::ACT_CONFIG_GRU>::activation (inputacc_zr, tmpres_zr);
366371
367372 // Activation function Sub layer -- END
@@ -383,7 +388,7 @@ void gru(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG_
383388 }
384389
385390 // Now run the activation on this guy
386- CONFIG_T::template activation<typename CONFIG_T::accum_t , typename CONFIG_T::weight_t ,
391+ CONFIG_T::template activation<typename CONFIG_T::accum_t , typename CONFIG_T::accum_t ,
387392 typename CONFIG_T::ACT_CONFIG_T>::activation (inputacc_h, tmpres_h);
388393
389394 // Mix the stat with the previous state
@@ -400,9 +405,9 @@ void gru(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG_
400405template <class data_T , class res_T , typename CONFIG_T>
401406void gru_static (bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG_T::n_state],
402407 typename CONFIG_T::weight_t param[CONFIG_T::n_state * 3 * CONFIG_T::n_in],
403- typename CONFIG_T::weight_t param_zr[CONFIG_T::n_state * 3 * CONFIG_T::n_state],
408+ typename CONFIG_T::recurrent_weight_t param_zr[CONFIG_T::n_state * 3 * CONFIG_T::n_state],
404409 typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 3 ],
405- typename CONFIG_T::bias_t param_br[CONFIG_T::n_state * 3 ]) {
410+ typename CONFIG_T::recurrent_bias_t param_br[CONFIG_T::n_state * 3 ]) {
406411 // Initialize the state variable -- will maintain state between function calls
407412
408413 static res_T h_state[CONFIG_T::n_state];
@@ -444,7 +449,7 @@ void gru_static(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[
444449 }
445450
446451 // Activation function Sub layer -- START
447- CONFIG_T::template activation_recr<typename CONFIG_T::accum_t , typename CONFIG_T::weight_t ,
452+ CONFIG_T::template activation_recr<typename CONFIG_T::accum_t , typename CONFIG_T::accum_t ,
448453 typename CONFIG_T::ACT_CONFIG_GRU>::activation (inputacc_zr, tmpres_zr);
449454
450455 // Activation function Sub layer -- END
@@ -466,7 +471,7 @@ void gru_static(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[
466471 }
467472
468473 // Now run the activation on this guy
469- CONFIG_T::template activation<typename CONFIG_T::accum_t , typename CONFIG_T::weight_t ,
474+ CONFIG_T::template activation<typename CONFIG_T::accum_t , typename CONFIG_T::accum_t ,
470475 typename CONFIG_T::ACT_CONFIG_T>::activation (inputacc_h, tmpres_h);
471476
472477 // Mix the stat with the previous state
@@ -484,9 +489,9 @@ void gru_static(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[
484489template <class data_T , class res_T , typename CONFIG_T>
485490void gru_stack (data_T data[CONFIG_T::n_sequence * CONFIG_T::n_in], res_T res[CONFIG_T::n_sequence_out * CONFIG_T::n_state],
486491 typename CONFIG_T::weight_t param[CONFIG_T::n_state * 3 * CONFIG_T::n_in],
487- typename CONFIG_T::weight_t param_zr[CONFIG_T::n_state * 3 * CONFIG_T::n_state],
492+ typename CONFIG_T::recurrent_weight_t param_zr[CONFIG_T::n_state * 3 * CONFIG_T::n_state],
488493 typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 3 ],
489- typename CONFIG_T::bias_t param_br[CONFIG_T::n_state * 3 ]) {
494+ typename CONFIG_T::recurrent_bias_t param_br[CONFIG_T::n_state * 3 ]) {
490495
491496 res_T h_state[CONFIG_T::n_state];
492497 data_T data_in[CONFIG_T::n_in];
@@ -525,9 +530,9 @@ void gru_stack(data_T data[CONFIG_T::n_sequence * CONFIG_T::n_in], res_T res[CON
525530template <class data_T , class res_T , typename CONFIG_T>
526531void gru_stack (hls::stream<data_T> &data_stream, hls::stream<res_T> &res_stream,
527532 typename CONFIG_T::weight_t param[CONFIG_T::n_state * 3 * CONFIG_T::n_in],
528- typename CONFIG_T::weight_t param_zr[CONFIG_T::n_state * 3 * CONFIG_T::n_state],
533+ typename CONFIG_T::recurrent_weight_t param_zr[CONFIG_T::n_state * 3 * CONFIG_T::n_state],
529534 typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 3 ],
530- typename CONFIG_T::bias_t param_br[CONFIG_T::n_state * 3 ]) {
535+ typename CONFIG_T::recurrent_bias_t param_br[CONFIG_T::n_state * 3 ]) {
531536
532537 typename res_T::value_type h_newstate[CONFIG_T::n_state];
533538 #pragma HLS ARRAY_PARTITION variable=h_newstate complete
0 commit comments