Skip to content

Commit 568f97f

Browse files
authored
Add support for filt_height==1 for streaming quartus conv2d (fastmachinelearning#886)
1 parent 62d5e03 commit 568f97f

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

hls4ml/templates/quartus/firmware/nnet_utils/nnet_conv2d_stream.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ template <class data_T, typename CONFIG_T>
6969
void shift_line_buffer_2d(
7070
const data_T &in_elem,
7171
nnet::shift_reg<typename data_T::value_type, CONFIG_T::pad_left + CONFIG_T::in_width + CONFIG_T::pad_right>
72-
line_buffer[CONFIG_T::filt_height - 1][CONFIG_T::n_chan],
72+
line_buffer[MAX(CONFIG_T::filt_height - 1, 1)][CONFIG_T::n_chan],
7373
typename data_T::value_type shift_buffer[CONFIG_T::filt_height][CONFIG_T::n_chan]) {
7474
// For every channel, insert the incoming pixel at end of the shift buffer
7575
UpdateBuffer:

test/pytest/test_pointwiseconv.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,8 @@
1111

1212
padds_options = ['same', 'valid']
1313
chans_options = ['channels_last']
14-
io_type_options = ['io_parallel', 'io_stream']
1514
strides1d_options = [(1,), (2,)]
1615
strides2d_options = [(1, 1), (2, 2)]
17-
strategy_options = ['Latency', 'Resource']
1816

1917

2018
@pytest.mark.parametrize('chans', chans_options)
@@ -24,6 +22,7 @@
2422
'backend, io_type, strategy',
2523
[
2624
('Quartus', 'io_parallel', 'resource'),
25+
('Quartus', 'io_stream', 'resource'),
2726
('Vivado', 'io_parallel', 'resource'),
2827
('Vitis', 'io_parallel', 'resource'),
2928
('Vivado', 'io_parallel', 'latency'),
@@ -54,7 +53,7 @@ def test_pointwiseconv1d(chans, padds, strides, backend, io_type, strategy):
5453
X_input = np.random.rand(100, *input_shape)
5554
keras_prediction = model.predict(X_input)
5655

57-
default_precision = 'ac_fixed<32,16,true>' if backend == 'Quartus' else 'ap_fixed<32,16>'
56+
default_precision = 'fixed<32,16>'
5857
config = hls4ml.utils.config_from_keras_model(model, default_precision=default_precision)
5958
config['Model']['Strategy'] = strategy
6059

@@ -70,7 +69,9 @@ def test_pointwiseconv1d(chans, padds, strides, backend, io_type, strategy):
7069
hls_model.compile()
7170
hls_prediction = hls_model.predict(X_input).reshape(keras_prediction.shape)
7271

73-
assert 'Pointwise' in list(hls_model.graph.values())[1].class_name
72+
if not (backend == 'Quartus' and io_type == 'io_stream'):
73+
# Quartus io_stream does not currently have a special pointwise implementation
74+
assert 'Pointwise' in list(hls_model.graph.values())[1].class_name
7475
np.testing.assert_allclose(hls_prediction, keras_prediction, rtol=0, atol=0.001)
7576

7677

@@ -81,6 +82,7 @@ def test_pointwiseconv1d(chans, padds, strides, backend, io_type, strategy):
8182
'backend, io_type, strategy',
8283
[
8384
('Quartus', 'io_parallel', 'resource'),
85+
('Quartus', 'io_stream', 'resource'),
8486
('Vivado', 'io_parallel', 'resource'),
8587
('Vivado', 'io_parallel', 'latency'),
8688
('Vivado', 'io_stream', 'latency'),
@@ -107,7 +109,7 @@ def test_pointwiseconv2d(chans, padds, strides, backend, io_type, strategy):
107109
X_input = np.random.rand(100, *input_shape)
108110
keras_prediction = model.predict(X_input)
109111

110-
default_precision = 'ac_fixed<32, 9, true>' if backend == 'Quartus' else 'ap_fixed<32, 9>'
112+
default_precision = 'fixed<32, 9>'
111113

112114
config = hls4ml.utils.config_from_keras_model(model, default_precision=default_precision)
113115
config['Model']['Strategy'] = strategy
@@ -125,7 +127,9 @@ def test_pointwiseconv2d(chans, padds, strides, backend, io_type, strategy):
125127
hls_model.compile()
126128
hls_prediction = hls_model.predict(X_input).reshape(keras_prediction.shape)
127129

128-
assert 'Pointwise' in list(hls_model.graph.values())[1].class_name
130+
if not (backend == 'Quartus' and io_type == 'io_stream'):
131+
# Quartus io_stream does not currently have a special pointwise implementation
132+
assert 'Pointwise' in list(hls_model.graph.values())[1].class_name
129133
np.testing.assert_allclose(hls_prediction, keras_prediction, rtol=0, atol=0.001)
130134

131135

0 commit comments

Comments
 (0)