Skip to content

Commit cc874fc

Browse files
bo3zvloncar
authored andcommitted
Quartus streaming Dense layer
1 parent e34e0c0 commit cc874fc

File tree

3 files changed

+59
-7
lines changed

3 files changed

+59
-7
lines changed

hls4ml/backends/quartus/passes/core_templates.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636

3737
dense_function_template = 'nnet::dense_{strategy}<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});'
3838

39-
dense_include_list = ['nnet_utils/nnet_dense.h', 'nnet_utils/nnet_dense_compressed.h']
39+
dense_include_list = ['nnet_utils/nnet_dense.h', 'nnet_utils/nnet_dense_compressed.h', 'nnet_utils/nnet_dense_stream.h']
4040

4141
class DenseConfigTemplate(LayerConfigTemplate):
4242
def __init__(self):
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#ifndef NNET_DENSE_STREAM_H_
2+
#define NNET_DENSE_STREAM_H_
3+
4+
#include "nnet_common.h"
5+
#include "nnet_types.h"
6+
#include "nnet_dense.h"
7+
8+
namespace nnet {
9+
10+
template<class data_T, class res_T, typename CONFIG_T>
11+
void dense_resource(
12+
stream<data_T> &data_stream,
13+
stream<res_T> &res_stream,
14+
const typename CONFIG_T::weight_t weights[CONFIG_T::n_in*CONFIG_T::n_out],
15+
const typename CONFIG_T::bias_t biases[CONFIG_T::n_out])
16+
{
17+
hls_register typename data_T::value_type data[CONFIG_T::n_in];
18+
hls_register typename res_T::value_type res[CONFIG_T::n_out];
19+
20+
DataPrepare:
21+
#pragma ii 1
22+
for(int i_in = 0; i_in < CONFIG_T::n_in / data_T::size; i_in++) {
23+
data_T data_pack = data_stream.read();
24+
DataPack:
25+
#pragma unroll
26+
for (int i_pack = 0; i_pack < data_T::size; i_pack++) {
27+
data[i_in * data_T::size + i_pack] = data_pack[i_pack];
28+
}
29+
}
30+
31+
dense_resource<typename data_T::value_type, typename res_T::value_type, CONFIG_T>(data, res, weights, biases);
32+
33+
ResWrite:
34+
#pragma ii 1
35+
for(unsigned i_out = 0; i_out < CONFIG_T::n_out / res_T::size; i_out++) {
36+
res_T res_pack;
37+
ResPack:
38+
#pragma unroll
39+
for (int i_pack = 0; i_pack < res_T::size; i_pack++) {
40+
res_pack[i_pack] = res[i_out * res_T::size + i_pack];
41+
}
42+
43+
res_stream.write(res_pack);
44+
}
45+
}
46+
47+
48+
}
49+
50+
#endif

test/pytest/test_keras_api.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
test_root_path = Path(__file__).parent
1717

1818
@pytest.mark.parametrize('backend', ['Vivado', 'Quartus'])
19-
def test_dense(backend):
19+
@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
20+
def test_dense(backend, io_type):
2021
model = tf.keras.models.Sequential()
2122
model.add(Dense(2,
2223
input_shape=(1,),
@@ -37,9 +38,9 @@ def test_dense(backend):
3738
keras_prediction = model.predict(X_input)
3839

3940
config = hls4ml.utils.config_from_keras_model(model)
40-
output_dir = str(test_root_path / f'hls4mlprj_keras_api_dense_{backend}')
41+
output_dir = str(test_root_path / f'hls4mlprj_keras_api_dense_{backend}_{io_type}')
4142

42-
hls_model = hls4ml.converters.convert_from_keras_model(model, hls_config=config, output_dir=output_dir, backend=backend)
43+
hls_model = hls4ml.converters.convert_from_keras_model(model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type)
4344

4445
hls_model.compile()
4546

@@ -66,7 +67,8 @@ def test_dense(backend):
6667
Activation(activation='sigmoid', name='Activation')])
6768
#ThresholdedReLU(theta=1.0)])
6869
@pytest.mark.parametrize('backend', ['Vivado', 'Quartus'])
69-
def test_activations(activation_function, backend):
70+
@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
71+
def test_activations(activation_function, backend, io_type):
7072
model = tf.keras.models.Sequential()
7173
model.add(Dense(64,
7274
input_shape=(1,),
@@ -79,8 +81,8 @@ def test_activations(activation_function, backend):
7981
X_input = np.random.rand(100,1)
8082
keras_prediction = model.predict(X_input)
8183
config = hls4ml.utils.config_from_keras_model(model)
82-
output_dir = str(test_root_path / 'hls4mlprj_keras_api_activations_{}_{}'.format(activation_function.__class__.__name__, backend))
83-
hls_model = hls4ml.converters.convert_from_keras_model(model, hls_config=config, output_dir=output_dir, backend=backend)
84+
output_dir = str(test_root_path / 'hls4mlprj_keras_api_activations_{}_{}_{}'.format(activation_function.__class__.__name__, backend, io_type))
85+
hls_model = hls4ml.converters.convert_from_keras_model(model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type)
8486
hls_model.compile()
8587
hls_prediction = hls_model.predict(X_input)
8688

0 commit comments

Comments
 (0)