Skip to content

Commit ebd4de3

Browse files
authored
Merge pull request #4 from jmitrevs/initialRecurrOneAPI
initial state rnns for oneAPI
2 parents 75b5dca + 96b6903 commit ebd4de3

File tree

3 files changed

+142
-10
lines changed

3 files changed

+142
-10
lines changed

hls4ml/backends/oneapi/passes/recurrent_templates.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@
9696
}};\n'''
9797

9898
gru_function_template = 'nnet::gru<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {wr}, {b}, {br});'
99+
gru_function_initial_state_template = (
100+
'nnet::gru_init_state<{input_t}, {h_t}, {output_t}, {config}>({input}, {init_state}, {output}, {w}, {wr}, {b}, {br});'
101+
)
99102
gru_task_sequence_template = 'task_sequence<nnet::gru_stream<{input_pipe}, {output_pipe}, {config}>> {name};'
100103
gru_stream_function_template = '{name}.async({w}, {wr}, {b}, {br});'
101104

@@ -163,15 +166,23 @@ def format(self, node):
163166
class GRUFunctionTemplate(FunctionCallTemplate):
164167
def __init__(self):
165168
super().__init__(GRU, include_header=recurrent_include_list)
166-
self.template = gru_function_template
167169

168170
def format(self, node):
169171
params = self._default_function_params(node)
172+
if params['pass_initial_states'] == 'true':
173+
params['h_t'] = node.get_input_variable(node.inputs[1]).type.name
174+
params['init_state'] = node.get_input_variable(node.inputs[1]).name
170175
params['w'] = node.get_weights('weight').name
171176
params['b'] = node.get_weights('bias').name
172177
params['wr'] = node.get_weights('recurrent_weight').name
173178
params['br'] = node.get_weights('recurrent_bias').name
174-
return self.template.format(**params)
179+
180+
if params['pass_initial_states'] == 'true':
181+
template = gru_function_initial_state_template
182+
else:
183+
template = gru_function_template
184+
185+
return template.format(**params)
175186

176187

177188
class GRUTaskSequenceTemplate(TaskSequenceTemplate):
@@ -235,6 +246,10 @@ def format(self, node):
235246
}};\n"""
236247

237248
lstm_function_template = 'nnet::lstm<{input_t}, {output_t}, {config}>({input}, {output}, {weights});'
249+
lstm_function_initial_state_template = (
250+
'nnet::lstm_init_state<{input_t}, {h_t}, {hc_t}, {output_t}, {config}>'
251+
'({input}, {init_state}, {init_cell}, {output}, {weights});'
252+
)
238253

239254

240255
class LSTMConfigTemplate(LayerConfigTemplate):
@@ -275,11 +290,16 @@ def format(self, node):
275290
class LSTMFunctionTemplate(FunctionCallTemplate):
276291
def __init__(self):
277292
super().__init__(LSTM, include_header=recurrent_include_list)
278-
self.template = lstm_function_template
279293

280294
def format(self, node):
281295
params = self._default_function_params(node)
282296

297+
if params['pass_initial_states'] == 'true':
298+
params['h_t'] = node.get_input_variable(node.inputs[1]).type.name
299+
params['init_state'] = node.get_input_variable(node.inputs[1]).name
300+
params['init_cell'] = node.get_input_variable(node.inputs[2]).name
301+
params['hc_t'] = node.get_input_variable(node.inputs[2]).type.name
302+
283303
types = ['i', 'f', 'c', 'o']
284304
params['weights'] = ''
285305
for t in types:
@@ -289,7 +309,12 @@ def format(self, node):
289309
for t in types:
290310
params['weights'] += 'bias_{}_{}{}'.format(t, str(node.index), ',' if t != 'o' else '')
291311

292-
return self.template.format(**params)
312+
if params['pass_initial_states'] == 'true':
313+
template = lstm_function_initial_state_template
314+
else:
315+
template = lstm_function_template
316+
317+
return template.format(**params)
293318

294319

295320
################################################

hls4ml/templates/oneapi/firmware/nnet_utils/nnet_recurrent.h

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,41 @@ void gru(const data_T &data, res_T &res, const typename CONFIG_T::weight_t &weig
200200
}
201201
}
202202

203+
template <class data_T, class h_T, class res_T, typename CONFIG_T>
204+
void gru_init_state(const data_T &data, const h_T &hin, res_T &res, const typename CONFIG_T::weight_t &weights,
205+
const typename CONFIG_T::recurrent_weight_t &recurrent_weights, const typename CONFIG_T::bias_t &bias,
206+
const typename CONFIG_T::recurrent_bias_t &recurrent_bias) {
207+
208+
[[intel::fpga_register]] data_T x;
209+
210+
[[intel::fpga_register]] h_T h = hin;
211+
212+
// Loop dependency - cannot pipeline
213+
[[intel::disable_loop_pipelining]] for (int t = 0; t < CONFIG_T::n_timesteps; t++) {
214+
// Get data at current time step
215+
#pragma unroll
216+
for (int j = 0; j < CONFIG_T::n_in; j++) {
217+
x[j] = data[j + t * CONFIG_T::n_in];
218+
}
219+
220+
nnet::gru_cell<data_T, h_T, CONFIG_T>(x, h, weights, recurrent_weights, bias, recurrent_bias);
221+
222+
if (CONFIG_T::return_sequences) {
223+
#pragma unroll
224+
for (int i = 0; i < CONFIG_T::n_units; i++) {
225+
res[CONFIG_T::n_units * t + i] = h[i];
226+
}
227+
}
228+
}
229+
230+
if (!CONFIG_T::return_sequences) {
231+
#pragma unroll
232+
for (int i = 0; i < (CONFIG_T::n_units); i++) {
233+
res[i] = h[i];
234+
}
235+
}
236+
}
237+
203238
//----------------------
204239
// SimpleRNN
205240
//----------------------
@@ -561,6 +596,78 @@ void lstm(const data_T &data, res_T &res, const typename CONFIG_T::weight_i_t &W
561596
}
562597
}
563598

599+
template <class data_T, class h_T, class hc_T, class res_T, class CONFIG_T>
600+
void lstm_init_state(const data_T &data, const h_T &hidden_state_initial, const hc_T &cell_state_initial, res_T &res,
601+
const typename CONFIG_T::weight_i_t &WI, const typename CONFIG_T::weight_f_t &WF,
602+
const typename CONFIG_T::weight_c_t &WC, const typename CONFIG_T::weight_o_t &WO,
603+
const typename CONFIG_T::recurrent_weight_i_t &RWI, const typename CONFIG_T::recurrent_weight_f_t &RWF,
604+
const typename CONFIG_T::recurrent_weight_c_t &RWC, const typename CONFIG_T::recurrent_weight_o_t &RWO,
605+
const typename CONFIG_T::bias_i_t &BI, const typename CONFIG_T::bias_f_t &BF,
606+
const typename CONFIG_T::bias_c_t &BC, const typename CONFIG_T::bias_o_t &BO) {
607+
608+
// Note: currently this does not support recurrent bias
609+
610+
using in_T = array<typename data_T::value_type, CONFIG_T::n_in>;
611+
612+
[[intel::fpga_register]] h_T hidden_state[CONFIG_T::n_timesteps + 1];
613+
[[intel::fpga_register]] h_T hidden_state_temp;
614+
[[intel::fpga_register]] h_T cell_state[CONFIG_T::n_timesteps + 1];
615+
[[intel::fpga_register]] h_T cell_state_temp; // should this be updated to a different type
616+
[[intel::fpga_register]] h_T h;
617+
[[intel::fpga_register]] h_T c;
618+
[[intel::fpga_register]] in_T in;
619+
620+
// Initialize the hidden and cell state from the supplied initial values
621+
INIT_LOOP:
622+
#pragma unroll
623+
for (int x = 0; x < CONFIG_T::n_out; x++) {
624+
hidden_state[0][x] = hidden_state_initial[x];
625+
cell_state[0][x] = cell_state_initial[x];
626+
}
627+
628+
// Input dimension
629+
[[intel::disable_loop_pipelining]] for (int i = 0; i < CONFIG_T::n_timesteps; i++) {
630+
// Data at current time step
631+
for (int x = 0; x < CONFIG_T::n_in; x++) {
632+
in[x] = data[x + i * CONFIG_T::n_in];
633+
}
634+
635+
// Hidden state at current time step
636+
#pragma unroll
637+
for (int x = 0; x < CONFIG_T::n_out; x++) {
638+
hidden_state_temp[x] = hidden_state[i][x];
639+
cell_state_temp[x] = cell_state[i][x];
640+
}
641+
642+
// Do LSTM
643+
lstm_cell<in_T, h_T, CONFIG_T>(in, hidden_state_temp, h, cell_state_temp, c, WI, WF, WC, WO, RWI, RWF, RWC, RWO, BI,
644+
BF, BC, BO);
645+
646+
// Write result
647+
#pragma unroll
648+
for (int x = 0; x < CONFIG_T::n_out; x++) {
649+
hidden_state[i + 1][x] = h[x];
650+
cell_state[i + 1][x] = c[x];
651+
}
652+
}
653+
654+
if (CONFIG_T::return_sequences == 0) {
655+
// Output when return_sequences is false
656+
#pragma unroll
657+
for (int x = 0; x < CONFIG_T::n_out; x++) {
658+
res[x] = hidden_state[CONFIG_T::n_timesteps][x];
659+
}
660+
} else {
661+
// Output when return_sequences is true
662+
#pragma unroll
663+
for (int x = 0; x < CONFIG_T::n_timesteps; x++) {
664+
for (int h = 0; h < CONFIG_T::n_out; h++) {
665+
res[x * CONFIG_T::n_out + h] = hidden_state[x + 1][h];
666+
}
667+
}
668+
}
669+
}
670+
564671
} // namespace nnet
565672

566673
#endif

test/pytest/test_recurrent_pytorch.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def forward(self, x):
3131
return output
3232

3333

34-
@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus'])
34+
@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'oneAPI'])
3535
@pytest.mark.parametrize('io_type', ['io_parallel'])
3636
def test_gru(backend, io_type):
3737
model = GRUNet()
@@ -56,7 +56,7 @@ def test_gru(backend, io_type):
5656
np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=0, atol=1e-1)
5757

5858

59-
@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus'])
59+
@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'oneAPI'])
6060
@pytest.mark.parametrize('io_type', ['io_stream'])
6161
def test_gru_stream(backend, io_type):
6262
model = GRUNetStream()
@@ -98,7 +98,7 @@ def forward(self, x):
9898
return output
9999

100100

101-
@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus'])
101+
@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'oneAPI'])
102102
@pytest.mark.parametrize('io_type', ['io_parallel'])
103103
def test_lstm(backend, io_type):
104104
model = LSTM()
@@ -132,10 +132,10 @@ def test_lstm(backend, io_type):
132132
np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=0, atol=1e-1)
133133

134134

135-
@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus'])
135+
@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'oneAPI'])
136136
@pytest.mark.parametrize('io_type', ['io_stream'])
137137
def test_lstm_stream(backend, io_type):
138-
if not (backend == "Quartus" and io_type == "io_stream"):
138+
if not (backend in ('Quartus', 'oneAPI') and io_type == "io_stream"):
139139
model = LSTMStream()
140140
model.eval()
141141

@@ -174,7 +174,7 @@ def forward(self, x, h0):
174174
@pytest.mark.parametrize('backend', ['Quartus'])
175175
@pytest.mark.parametrize('io_type', ['io_parallel'])
176176
def test_rnn(backend, io_type):
177-
if not (backend == "Quartus" and io_type == "io_stream"):
177+
if not (backend in ('Quartus', 'oneAPI') and io_type == "io_stream"):
178178
model = RNN()
179179
model.eval()
180180

0 commit comments

Comments
 (0)