Commit 3bd3687

committed
allow initial values for the hidden/cell state to be passed for LSTM and GRU models
1 parent 2c17f66 commit 3bd3687
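
For context, here is a minimal sketch (not part of this commit; module and dimension names are illustrative) of the kind of PyTorch model this change targets: the initial hidden and cell states are passed explicitly to the recurrent layer, so the traced graph gains extra inputs that the converter can now pick up instead of discarding.

# Hypothetical example model; names and sizes are illustrative only.
import torch
import torch.nn as nn


class LSTMWithInitialState(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = nn.LSTM(input_size=10, hidden_size=20, batch_first=True)

    def forward(self, x, h0, c0):
        # x: (batch, seq, 10); h0, c0: (1, batch, 20)
        # The initial states are forwarded as a tuple, which the pytorch_to_hls
        # tracer below flattens into separate model inputs.
        output, (hn, cn) = self.rnn(x, (h0, c0))
        return output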

File tree

8 files changed: +337 -35 lines changed

hls4ml/backends/quartus/passes/recurrent_templates.py

Lines changed: 28 additions & 4 deletions
@@ -71,6 +71,9 @@
 }};\n'''
 
 gru_function_template = 'nnet::gru<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {wr}, {b}, {br});'
+gru_function_initial_state_template = (
+    'nnet::gru<{input_t}, {input2_t}, {output_t}, {config}>({input}, {input2}, {output}, {w}, {wr}, {b}, {br});'
+)
 
 
 class GRUConfigTemplate(LayerConfigTemplate):
@@ -137,15 +140,23 @@ def format(self, node):
 class GRUFunctionTemplate(FunctionCallTemplate):
     def __init__(self):
         super().__init__(GRU, include_header=recurrent_include_list)
-        self.template = gru_function_template
 
     def format(self, node):
         params = self._default_function_params(node)
+        if params['pass_initial_states'] == 'true':
+            params['input2_t'] = node.get_input_variable(node.inputs[1]).type.name
+            params['input2'] = node.get_input_variable(node.inputs[1]).name
         params['w'] = node.get_weights('weight').name
         params['b'] = node.get_weights('bias').name
         params['wr'] = node.get_weights('recurrent_weight').name
         params['br'] = node.get_weights('recurrent_bias').name
-        return self.template.format(**params)
+
+        if params['pass_initial_states'] == 'true':
+            template = gru_function_initial_state_template
+        else:
+            template = gru_function_template
+
+        return template.format(**params)
 
 
 ################################################
@@ -174,6 +185,9 @@ def format(self, node):
 }};\n"""
 
 lstm_function_template = 'nnet::lstm<{input_t}, {output_t}, {config}>({input}, {output}, {weights});'
+lstm_function_initial_state_template = (
+    'nnet::lstm<{input_t}, {input2_t}, {input3_t}, {output_t}, {config}>({input}, {input2}, {input3}, {output}, {weights});'
+)
 
 
 class LSTMConfigTemplate(LayerConfigTemplate):
@@ -214,11 +228,16 @@ def format(self, node):
 class LSTMFunctionTemplate(FunctionCallTemplate):
     def __init__(self):
         super().__init__(LSTM, include_header=recurrent_include_list)
-        self.template = lstm_function_template
 
     def format(self, node):
        params = self._default_function_params(node)
 
+        if params['pass_initial_states'] == 'true':
+            params['input2_t'] = node.get_input_variable(node.inputs[1]).type.name
+            params['input2'] = node.get_input_variable(node.inputs[1]).name
+            params['input3'] = node.get_input_variable(node.inputs[2]).name
+            params['input3_t'] = node.get_input_variable(node.inputs[2]).type.name
+
         types = ['i', 'f', 'c', 'o']
         params['weights'] = ''
         for t in types:
@@ -228,7 +247,12 @@ def format(self, node):
         for t in types:
             params['weights'] += 'bias_{}_{}{}'.format(t, str(node.index), ',' if t != 'o' else '')
 
-        return self.template.format(**params)
+        if params['pass_initial_states'] == 'true':
+            template = lstm_function_initial_state_template
+        else:
+            template = lstm_function_template
+
+        return template.format(**params)
 
 
 ################################################
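
As a rough illustration of what the new template produces (all type and variable names below are placeholders, not taken from any generated project), formatting the initial-state GRU template yields a call with one extra type parameter and one extra argument carrying the initial hidden state:

# Hypothetical rendering of the initial-state GRU template; every name in
# params is a stand-in for whatever hls4ml generates for a concrete model.
gru_function_initial_state_template = (
    'nnet::gru<{input_t}, {input2_t}, {output_t}, {config}>({input}, {input2}, {output}, {w}, {wr}, {b}, {br});'
)

params = {
    'input_t': 'input_t', 'input2_t': 'input2_t', 'output_t': 'result_t', 'config': 'config2',
    'input': 'x', 'input2': 'h_init', 'output': 'layer2_out',
    'w': 'w2', 'wr': 'wr2', 'b': 'b2', 'br': 'br2',
}
print(gru_function_initial_state_template.format(**params))
# nnet::gru<input_t, input2_t, result_t, config2>(x, h_init, layer2_out, w2, wr2, b2, br2);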

hls4ml/backends/vivado/passes/recurrent_templates.py

Lines changed: 18 additions & 2 deletions
@@ -67,6 +67,8 @@
 }};\n"""
 
 recr_function_template = 'nnet::{recr_type}_stack<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {wr}, {b}, {br});'
+recr_function_template_initial_states_lstm = 'nnet::{recr_type}_stack<{input_t}, {input2_t}, {input3_t}, {output_t}, {config}>({input}, {input2}, {input3}, {output}, {w}, {wr}, {b}, {br});'  # noqa: E501
+recr_function_template_initial_states_gru = 'nnet::{recr_type}_stack<{input_t}, {input2_t}, {output_t}, {config}>({input}, {input2}, {output}, {w}, {wr}, {b}, {br});'  # noqa: E501
 
 recr_include_list = ['nnet_utils/nnet_recurrent.h']
 
@@ -186,10 +188,16 @@ def format(self, node):
 class RecurrentFunctionTemplate(FunctionCallTemplate):
     def __init__(self):
         super().__init__((LSTM, GRU), include_header=recr_include_list)
-        self.template = recr_function_template
 
     def format(self, node):
         params = self._default_function_params(node)
+        if params['pass_initial_states'] == 'true':
+            params['input2_t'] = node.get_input_variable(node.inputs[1]).type.name
+            params['input2'] = node.get_input_variable(node.inputs[1]).name
+            if node.class_name == 'LSTM':
+                params['input3'] = node.get_input_variable(node.inputs[2]).name
+                params['input3_t'] = node.get_input_variable(node.inputs[2]).type.name
+
         params['w'] = node.get_weights('weight').name
         params['b'] = node.get_weights('bias').name
         params['wr'] = node.get_weights('recurrent_weight').name
@@ -198,4 +206,12 @@ def format(self, node):
         params['recurrent_activation'] = node.get_attr('recurrent_activation')
         params['recr_type'] = node.class_name.lower()
 
-        return self.template.format(**params)
+        if params['pass_initial_states'] == 'true':
+            if node.class_name == 'LSTM':
+                template = recr_function_template_initial_states_lstm
+            else:
+                template = recr_function_template_initial_states_gru
+        else:
+            template = recr_function_template
+
+        return template.format(**params)

hls4ml/converters/keras/recurrent.py

Lines changed: 2 additions & 0 deletions
@@ -47,4 +47,6 @@ def parse_rnn_layer(keras_layer, input_names, input_shapes, data_reader):
     if layer['return_state']:
         raise Exception('"return_state" of {} layer is not yet supported.')
 
+    layer['pass_initial_states'] = False
+
     return layer, output_shape

hls4ml/converters/pytorch/recurrent.py

Lines changed: 11 additions & 11 deletions
@@ -1,5 +1,3 @@
-import warnings
-
 import numpy as np
 
 from hls4ml.converters.pytorch_to_hls import pytorch_handler
@@ -15,14 +13,13 @@ def parse_rnn_layer(operation, layer_name, input_names, input_shapes, node, clas
 
     layer["name"] = layer_name
 
-    layer['inputs'] = [input_names[0]]
-    if len(input_names) > 1:
-        warnings.warn(
-            'hls4ml disregards the initial value of the hidden state passed to the model, assuming that it is all zeros',
-            stacklevel=2,
-        )
+    layer['inputs'] = input_names
+    if 'IOType' in config.keys():
+        if len(input_names) > 1 and config['IOType'] == 'io_stream':
+            raise Exception('Passing initial values for the hidden state is not support for io_stream input type.')
+
     layer['class_name'] = operation
-    if operation == "RNN":
+    if operation == 'RNN':
         layer['class_name'] = 'SimpleRNN'
 
     layer['return_sequences'] = False  # parameter does not exist in pytorch
@@ -31,7 +28,7 @@ def parse_rnn_layer(operation, layer_name, input_names, input_shapes, node, clas
     if layer['class_name'] == 'SimpleRNN':
         layer['activation'] = class_object.nonlinearity  # Default is tanh, can also be ReLU in pytorch
     else:
-        layer['activation'] = "tanh"  # GRU and LSTM are hard-coded to use tanh in pytorch
+        layer['activation'] = 'tanh'  # GRU and LSTM are hard-coded to use tanh in pytorch
 
     if layer['class_name'] == 'GRU' or layer['class_name'] == 'LSTM':
         layer['recurrent_activation'] = 'sigmoid'  # GRU and LSTM are hard-coded to use sigmoid in pytorch
@@ -51,7 +48,6 @@ def parse_rnn_layer(operation, layer_name, input_names, input_shapes, node, clas
 
     if class_object.bidirectional:
         raise Exception('hls4ml does not support birectional RNNs')
-
     if class_object.dropout > 0:
         raise Exception('hls4ml does not support RNNs with dropout')
 
@@ -70,5 +66,9 @@ def parse_rnn_layer(operation, layer_name, input_names, input_shapes, node, clas
     output_shape = [input_shapes[0][0], layer['n_out']]
 
     layer['pytorch'] = True  # need to switch some behaviors to match pytorch implementations
+    if len(input_names) == 1:
+        layer['pass_initial_states'] = False
+    else:
+        layer['pass_initial_states'] = True
 
     return layer, output_shape

hls4ml/converters/pytorch_to_hls.py

Lines changed: 11 additions & 3 deletions
@@ -203,9 +203,17 @@ def parse_pytorch_model(config, verbose=True):
            # parse info from class object
            input_names = [inputs_map.get(str(i), str(i)) for i in node.args]
            if pytorch_class in ["RNN", "GRU", "LSTM"]:
-                # we currently don't support the passing of the initial value of the hidden state to RNN models
-                input_names = [inputs_map.get(str(node.args[0]), str(node.args[0]))]
-                input_shapes = [output_shapes[str(node.args[0])]]
+                input_shapes = []
+                input_names = []
+                for i in node.args:
+                    if isinstance(i, tuple):
+                        for y in i:
+                            input_shapes.append(output_shapes[str(y)])
+                            input_names.append(inputs_map.get(str(y), str(y)))
+                    else:
+                        input_shapes.append(output_shapes[str(i)])
+                        input_names.append(inputs_map.get(str(i), str(i)))
+
            # if a 'getitem' is the input to a node, step back in the graph to find the real source of the input
            elif "getitem" in node.args[0].name:
                for tmp_node in traced_model.graph.nodes:
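
The loop above handles the fact that, in the traced graph, an LSTM call that receives initial states carries them as a tuple argument, roughly node.args == (x, (h0, c0)). A minimal standalone sketch of the same flattening, with placeholder strings in place of real torch.fx nodes (the real code also maps each name through inputs_map):

# Standalone sketch of the tuple flattening done above; the strings stand in
# for torch.fx node arguments of a hypothetical lstm(x, (h0, c0)) call.
args = ('x', ('h0', 'c0'))

input_names = []
for i in args:
    if isinstance(i, tuple):
        for y in i:
            input_names.append(str(y))
    else:
        input_names.append(str(i))

print(input_names)  # ['x', 'h0', 'c0']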

hls4ml/templates/quartus/firmware/nnet_utils/nnet_recurrent.h

Lines changed: 114 additions & 0 deletions
@@ -206,6 +206,47 @@ void gru(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_outputs * CONFIG_T::
     }
 }
 
+template <class data_T, class data2_T, class res_T, typename CONFIG_T>
+void gru(data_T data[CONFIG_T::n_in], data2_T h[CONFIG_T::n_units], res_T res[CONFIG_T::n_outputs * CONFIG_T::n_units],
+         const typename CONFIG_T::weight_t weights[3 * CONFIG_T::n_units * CONFIG_T::n_in],
+         const typename CONFIG_T::weight_t recurrent_weights[3 * CONFIG_T::n_units * CONFIG_T::n_units],
+         const typename CONFIG_T::bias_t bias[3 * CONFIG_T::n_units],
+         const typename CONFIG_T::bias_t recurrent_bias[3 * CONFIG_T::n_units]) {
+
+    hls_register data_T x[CONFIG_T::n_in];
+    // hls_register res_T h[CONFIG_T::n_units];
+
+    // #pragma unroll
+    // for (int i = 0; i < CONFIG_T::n_units; i++) {
+    //     h[i] = 0;
+    // }
+
+    // Loop depedency - cannot pipeline
+    #pragma disable_loop_pipelining
+    for (int t = 0; t < CONFIG_T::n_timesteps; t++) {
+        // Get data at current time step
+        #pragma unroll
+        for (int j = 0; j < CONFIG_T::n_in; j++) {
+            x[j] = data[j + t * CONFIG_T::n_in];
+        }
+
+        nnet::gru_cell<data_T, res_T, CONFIG_T>(x, h, weights, recurrent_weights, bias, recurrent_bias);
+
+        if (CONFIG_T::return_sequences) {
+            #pragma unroll
+            for (int i = 0; i < CONFIG_T::n_units; i++) {
+                res[CONFIG_T::n_units * t + i] = h[i];
+            }
+        }
+    }
+
+    if (!CONFIG_T::return_sequences) {
+        #pragma unroll
+        for (int i = 0; i < (CONFIG_T::n_units); i++) {
+            res[i] = h[i];
+        }
+    }
+}
 //----------------------
 // SimpleRNN
 //----------------------
@@ -711,6 +752,79 @@ void lstm(data_T data[CONFIG_T::n_timesteps * CONFIG_T::n_in], res_T res[CONFIG_
     }
 }
 
+template <class data_T, class data2_T, class data3_T, class res_T, class CONFIG_T>
+void lstm(data_T data[CONFIG_T::n_timesteps * CONFIG_T::n_in], data2_T hidden_state_initial[CONFIG_T::n_out],
+          data3_T cell_state_initial[CONFIG_T::n_out], res_T res[CONFIG_T::n_outputs * CONFIG_T::n_out],
+          const typename CONFIG_T::weight_t WI[CONFIG_T::n_in * CONFIG_T::n_out],
+          const typename CONFIG_T::weight_t WF[CONFIG_T::n_in * CONFIG_T::n_out],
+          const typename CONFIG_T::weight_t WC[CONFIG_T::n_in * CONFIG_T::n_out],
+          const typename CONFIG_T::weight_t WO[CONFIG_T::n_in * CONFIG_T::n_out],
+          const typename CONFIG_T::weight_t RWI[CONFIG_T::n_out * CONFIG_T::n_out],
+          const typename CONFIG_T::weight_t RWF[CONFIG_T::n_out * CONFIG_T::n_out],
+          const typename CONFIG_T::weight_t RWC[CONFIG_T::n_out * CONFIG_T::n_out],
+          const typename CONFIG_T::weight_t RWO[CONFIG_T::n_out * CONFIG_T::n_out],
+          const typename CONFIG_T::bias_t BI[CONFIG_T::n_out], const typename CONFIG_T::bias_t BF[CONFIG_T::n_out],
+          const typename CONFIG_T::bias_t BC[CONFIG_T::n_out], const typename CONFIG_T::bias_t BO[CONFIG_T::n_out]) {
+    res_T hidden_state[CONFIG_T::n_out][CONFIG_T::n_timesteps + 1] hls_register;
+    res_T hidden_state_temp[CONFIG_T::n_out] hls_register;
+    res_T cell_state[CONFIG_T::n_out][CONFIG_T::n_timesteps + 1] hls_register;
+    res_T cell_state_temp[CONFIG_T::n_out] hls_register;
+    res_T h[CONFIG_T::n_out] hls_register;
+    res_T c[CONFIG_T::n_out] hls_register;
+    data_T in[CONFIG_T::n_in] hls_register;
+
+    // Set initially hidden state (output) to zero
+INIT_LOOP:
+    #pragma unroll
+    for (int x = 0; x < CONFIG_T::n_out; x++) {
+        hidden_state[x][0] = hidden_state_initial[x];
+        cell_state[x][0] = cell_state_initial[x];
+    }
+
+    // Input dimension
+    #pragma disable_loop_pipelining
+    for (int i = 0; i < CONFIG_T::n_timesteps; i++) {
+        // Data at current time step
+        for (int x = 0; x < CONFIG_T::n_in; x++) {
+            in[x] = data[x + i * CONFIG_T::n_in];
+        }
+
+        // Hidden state at current time step
+        #pragma unroll
+        for (int x = 0; x < CONFIG_T::n_out; x++) {
+            hidden_state_temp[x] = hidden_state[x][i];
+            cell_state_temp[x] = cell_state[x][i];
+        }
+
+        // Do LSTM
+        lstm_cell<data_T, res_T, CONFIG_T>(in, hidden_state_temp, h, cell_state_temp, c, WI, WF, WC, WO, RWI, RWF, RWC, RWO,
+                                           BI, BF, BC, BO);
+
+        // Write result
+        #pragma unroll
+        for (int x = 0; x < CONFIG_T::n_out; x++) {
+            hidden_state[x][i + 1] = h[x];
+            cell_state[x][i + 1] = c[x];
+        }
+    }
+
+    if (CONFIG_T::return_sequences == 0) {
+        // Output when return_sequences is false
+        #pragma unroll
+        for (int x = 0; x < CONFIG_T::n_out; x++) {
+            res[x] = hidden_state[x][CONFIG_T::n_timesteps];
+        }
+    } else {
+        // Output when return_sequences is true
+        #pragma unroll
+        for (int x = 0; x < CONFIG_T::n_timesteps; x++) {
+            for (int h = 0; h < CONFIG_T::n_out; h++) {
+                res[x * CONFIG_T::n_out + h] = hidden_state[h][x + 1];
+            }
+        }
+    }
+}
+
 } // namespace nnet
 
 #endif
