Skip to content

Commit de20cfc

Browse files
authored
Merge branch 'main' into fix_repack_precision
2 parents 3aa9130 + 0d48aff commit de20cfc

File tree

7 files changed

+133
-62
lines changed

7 files changed

+133
-62
lines changed

.pre-commit-config.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@ exclude: (^hls4ml\/templates\/(vivado|quartus)\/(ap_types|ac_types)\/|^test/pyte
22

33
repos:
44
- repo: https://github.com/psf/black
5-
rev: 23.9.1
5+
rev: 23.10.0
66
hooks:
77
- id: black
88
language_version: python3
99
args: ['--line-length=125',
1010
'--skip-string-normalization']
1111

1212
- repo: https://github.com/pre-commit/pre-commit-hooks
13-
rev: v4.4.0
13+
rev: v4.5.0
1414
hooks:
1515
- id: check-added-large-files
1616
- id: check-case-conflict
@@ -30,7 +30,7 @@ repos:
3030
args: ["--profile", "black", --line-length=125]
3131

3232
- repo: https://github.com/asottile/pyupgrade
33-
rev: v3.14.0
33+
rev: v3.15.0
3434
hooks:
3535
- id: pyupgrade
3636
args: ["--py36-plus"]

hls4ml/backends/vivado/passes/convolution_templates.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
static const unsigned out_width = {out_width};
4242
static const unsigned reuse_factor = {reuse};
4343
static const unsigned n_zeros = {nzeros};
44+
static const unsigned multiplier_limit =
45+
DIV_ROUNDUP(kernel_size * n_chan * n_filt, reuse_factor) - n_zeros / reuse_factor;
4446
static const bool store_weights_in_bram = false;
4547
static const unsigned strategy = nnet::{strategy};
4648
static const nnet::conv_implementation implementation = nnet::conv_implementation::{implementation};

hls4ml/model/profiling.py

Lines changed: 48 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import seaborn as sb
1111

1212
from hls4ml.model.graph import ModelGraph
13-
from hls4ml.model.layers import GRU, LSTM
13+
from hls4ml.model.layers import GRU, LSTM, SeparableConv1D, SeparableConv2D
1414

1515
try:
1616
import qkeras
@@ -184,6 +184,8 @@ def types_hlsmodel(model):
184184
for layer in model.get_layers():
185185
if isinstance(layer, GRU) or isinstance(layer, LSTM):
186186
suffix = ['w', 'rw', 'b', 'rb']
187+
elif isinstance(layer, SeparableConv1D) or isinstance(layer, SeparableConv2D):
188+
suffix = ['dw', 'pw', 'db', 'pb']
187189
else:
188190
suffix = ['w', 'b']
189191
for iw, weight in enumerate(layer.get_weights()):
@@ -225,6 +227,8 @@ def weights_hlsmodel(model, fmt='longform', plot='boxplot'):
225227
for layer in model.get_layers():
226228
if isinstance(layer, GRU) or isinstance(layer, LSTM):
227229
suffix = ['w', 'rw', 'b', 'rb']
230+
elif isinstance(layer, SeparableConv1D) or isinstance(layer, SeparableConv2D):
231+
suffix = ['dw', 'pw', 'db', 'pb']
228232
else:
229233
suffix = ['w', 'b']
230234
name = layer.name
@@ -343,21 +347,23 @@ def activations_keras(model, X, fmt='longform', plot='boxplot'):
343347
# return summary statistics for matplotlib.axes.Axes.bxp
344348
# or histogram bin edges and heights
345349
data = []
346-
347-
for layer in model.layers:
348-
print(f" {layer.name}")
349-
if not isinstance(layer, keras.layers.InputLayer):
350-
y = _get_output(layer, X, model.input).flatten()
351-
y = abs(y[y != 0])
352-
if len(y) == 0:
353-
print(f'Activations for {layer.name} are only zeros, ignoring.')
354-
continue
355-
if fmt == 'longform':
356-
data['x'].extend(y.tolist())
357-
data['weight'].extend([layer.name for i in range(len(y))])
358-
elif fmt == 'summary':
359-
data.append(array_to_summary(y, fmt=plot))
360-
data[-1]['weight'] = layer.name
350+
outputs = _get_outputs(
351+
[layer for layer in model.layers if not isinstance(layer, keras.layers.InputLayer)], X, model.input
352+
)
353+
outputs = dict(zip([layer.name for layer in model.layers if not isinstance(layer, keras.layers.InputLayer)], outputs))
354+
for layer_name, y in outputs.items():
355+
print(f" {layer_name}")
356+
y = y.flatten()
357+
y = abs(y[y != 0])
358+
if len(y) == 0:
359+
print(f'Activations for {layer_name} are only zeros, ignoring.')
360+
continue
361+
if fmt == 'longform':
362+
data['x'].extend(y.tolist())
363+
data['weight'].extend([layer_name for i in range(len(y))])
364+
elif fmt == 'summary':
365+
data.append(array_to_summary(y, fmt=plot))
366+
data[-1]['weight'] = layer_name
361367

362368
if fmt == 'longform':
363369
data = pandas.DataFrame(data)
@@ -544,10 +550,10 @@ def _is_ignored_layer(layer):
544550
return False
545551

546552

547-
def _get_output(layer, X, model_input):
548-
"""Get output of partial model"""
549-
partial_model = keras.models.Model(inputs=model_input, outputs=layer.output)
550-
y = partial_model.predict(X)
553+
def _get_outputs(layers, X, model_input):
554+
"""Get outputs of intermediate layers"""
555+
partial_models = keras.models.Model(inputs=model_input, outputs=[layer.output for layer in layers])
556+
y = partial_models.predict(X)
551557
return y
552558

553559

@@ -562,37 +568,30 @@ def get_ymodel_keras(keras_model, X):
562568
Returns:
563569
dict: A dictionary in the form {"layer_name": ouput array of layer}.
564570
"""
565-
566571
ymodel = {}
567-
572+
traced_layers = []
573+
layer_names = []
568574
for layer in keras_model.layers:
569-
print(f"Processing {layer.name} in Keras model...")
570-
if not _is_ignored_layer(layer):
571-
# If the layer has activation integrated then separate them
572-
# Note that if the layer is a standalone activation layer then skip this
573-
if hasattr(layer, 'activation') and not (
574-
isinstance(layer, keras.layers.Activation) or isinstance(layer, qkeras.qlayers.QActivation)
575-
):
576-
if layer.activation:
577-
if layer.activation.__class__.__name__ == "linear":
578-
ymodel[layer.name] = _get_output(layer, X, keras_model.input)
579-
580-
else:
581-
temp_activation = layer.activation
582-
layer.activation = None
583-
# Get output for layer without activation
584-
ymodel[layer.name] = _get_output(layer, X, keras_model.input)
585-
586-
# Add the activation back
587-
layer.activation = temp_activation
588-
# Get ouput for activation
589-
ymodel[layer.name + f"_{temp_activation.__class__.__name__}"] = _get_output(
590-
layer, X, keras_model.input
591-
)
592-
else:
593-
ymodel[layer.name] = _get_output(layer, X, keras_model.input)
594-
else:
595-
ymodel[layer.name] = _get_output(layer, X, keras_model.input)
575+
if _is_ignored_layer(layer):
576+
continue
577+
# If the layer has activation integrated then separate them
578+
# Note that if the layer is a standalone activation layer then skip this
579+
name = layer.name
580+
if (
581+
hasattr(layer, "activation")
582+
and layer.activation.__name__ != "linear"
583+
and not isinstance(layer, (keras.layers.Activation, qkeras.qlayers.QActivation))
584+
):
585+
tmp_activation = layer.activation
586+
layer.activation = None
587+
ymodel.update({layer.name: _get_outputs([layer], X, keras_model.input)})
588+
layer.activation = tmp_activation
589+
name = layer.name + f"_{tmp_activation.__name__}"
590+
traced_layers.append(layer)
591+
layer_names.append(name)
592+
outputs = _get_outputs(traced_layers, X, keras_model.input)
593+
for name, output in zip(layer_names, outputs):
594+
ymodel[name] = output
596595
print("Done taking outputs for Keras model.")
597596
return ymodel
598597

hls4ml/templates/quartus/firmware/nnet_utils/nnet_conv2d_stream.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ template <class data_T, typename CONFIG_T>
6969
void shift_line_buffer_2d(
7070
const data_T &in_elem,
7171
nnet::shift_reg<typename data_T::value_type, CONFIG_T::pad_left + CONFIG_T::in_width + CONFIG_T::pad_right>
72-
line_buffer[CONFIG_T::filt_height - 1][CONFIG_T::n_chan],
72+
line_buffer[MAX(CONFIG_T::filt_height - 1, 1)][CONFIG_T::n_chan],
7373
typename data_T::value_type shift_buffer[CONFIG_T::filt_height][CONFIG_T::n_chan]) {
7474
// For every channel, insert the incoming pixel at end of the shift buffer
7575
UpdateBuffer:

test/pytest/test_pointwiseconv.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,8 @@
1111

1212
padds_options = ['same', 'valid']
1313
chans_options = ['channels_last']
14-
io_type_options = ['io_parallel', 'io_stream']
1514
strides1d_options = [(1,), (2,)]
1615
strides2d_options = [(1, 1), (2, 2)]
17-
strategy_options = ['Latency', 'Resource']
1816

1917

2018
@pytest.mark.parametrize('chans', chans_options)
@@ -24,6 +22,7 @@
2422
'backend, io_type, strategy',
2523
[
2624
('Quartus', 'io_parallel', 'resource'),
25+
('Quartus', 'io_stream', 'resource'),
2726
('Vivado', 'io_parallel', 'resource'),
2827
('Vitis', 'io_parallel', 'resource'),
2928
('Vivado', 'io_parallel', 'latency'),
@@ -54,7 +53,7 @@ def test_pointwiseconv1d(chans, padds, strides, backend, io_type, strategy):
5453
X_input = np.random.rand(100, *input_shape)
5554
keras_prediction = model.predict(X_input)
5655

57-
default_precision = 'ac_fixed<32,16,true>' if backend == 'Quartus' else 'ap_fixed<32,16>'
56+
default_precision = 'fixed<32,16>'
5857
config = hls4ml.utils.config_from_keras_model(model, default_precision=default_precision)
5958
config['Model']['Strategy'] = strategy
6059

@@ -70,7 +69,9 @@ def test_pointwiseconv1d(chans, padds, strides, backend, io_type, strategy):
7069
hls_model.compile()
7170
hls_prediction = hls_model.predict(X_input).reshape(keras_prediction.shape)
7271

73-
assert 'Pointwise' in list(hls_model.graph.values())[1].class_name
72+
if not (backend == 'Quartus' and io_type == 'io_stream'):
73+
# Quartus io_stream does not currently have a special pointwise implementation
74+
assert 'Pointwise' in list(hls_model.graph.values())[1].class_name
7475
np.testing.assert_allclose(hls_prediction, keras_prediction, rtol=0, atol=0.001)
7576

7677

@@ -81,6 +82,7 @@ def test_pointwiseconv1d(chans, padds, strides, backend, io_type, strategy):
8182
'backend, io_type, strategy',
8283
[
8384
('Quartus', 'io_parallel', 'resource'),
85+
('Quartus', 'io_stream', 'resource'),
8486
('Vivado', 'io_parallel', 'resource'),
8587
('Vivado', 'io_parallel', 'latency'),
8688
('Vivado', 'io_stream', 'latency'),
@@ -107,7 +109,7 @@ def test_pointwiseconv2d(chans, padds, strides, backend, io_type, strategy):
107109
X_input = np.random.rand(100, *input_shape)
108110
keras_prediction = model.predict(X_input)
109111

110-
default_precision = 'ac_fixed<32, 9, true>' if backend == 'Quartus' else 'ap_fixed<32, 9>'
112+
default_precision = 'fixed<32, 9>'
111113

112114
config = hls4ml.utils.config_from_keras_model(model, default_precision=default_precision)
113115
config['Model']['Strategy'] = strategy
@@ -125,7 +127,9 @@ def test_pointwiseconv2d(chans, padds, strides, backend, io_type, strategy):
125127
hls_model.compile()
126128
hls_prediction = hls_model.predict(X_input).reshape(keras_prediction.shape)
127129

128-
assert 'Pointwise' in list(hls_model.graph.values())[1].class_name
130+
if not (backend == 'Quartus' and io_type == 'io_stream'):
131+
# Quartus io_stream does not currently have a special pointwise implementation
132+
assert 'Pointwise' in list(hls_model.graph.values())[1].class_name
129133
np.testing.assert_allclose(hls_prediction, keras_prediction, rtol=0, atol=0.001)
130134

131135

test/pytest/test_sepconv1d.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from pathlib import Path
2+
3+
import numpy as np
4+
import pytest
5+
import tensorflow as tf
6+
from tensorflow.keras.layers import SeparableConv1D
7+
8+
import hls4ml
9+
10+
test_root_path = Path(__file__).parent
11+
12+
keras_conv1d = [SeparableConv1D]
13+
padds_options = ['same', 'valid']
14+
chans_options = ['channels_last']
15+
io_type_options = ['io_stream']
16+
strides_options = [(1), (2)]
17+
kernel_options = [(1), (3)]
18+
bias_options = [False]
19+
20+
21+
@pytest.mark.parametrize('conv1d', keras_conv1d)
22+
@pytest.mark.parametrize('chans', chans_options)
23+
@pytest.mark.parametrize('padds', padds_options)
24+
@pytest.mark.parametrize('strides', strides_options)
25+
@pytest.mark.parametrize('kernels', kernel_options)
26+
@pytest.mark.parametrize('bias', bias_options)
27+
@pytest.mark.parametrize('io_type', io_type_options)
28+
@pytest.mark.parametrize('backend', ['Vivado', 'Vitis'])
29+
def test_sepconv1d(conv1d, chans, padds, strides, kernels, bias, io_type, backend):
30+
model = tf.keras.models.Sequential()
31+
input_shape = (28, 3)
32+
model.add(
33+
conv1d(
34+
filters=32,
35+
kernel_size=kernels,
36+
strides=strides,
37+
padding=padds,
38+
input_shape=input_shape,
39+
kernel_initializer='normal',
40+
use_bias=bias,
41+
data_format=chans,
42+
)
43+
)
44+
45+
model.compile(optimizer='adam', loss='mse')
46+
X_input = np.random.rand(100, *input_shape)
47+
keras_prediction = model.predict(X_input)
48+
config = hls4ml.utils.config_from_keras_model(model, default_precision='ap_fixed<32,16>')
49+
stride_cfg = str(strides).replace(', ', '_').replace('(', '').replace(')', '')
50+
kernel_cfg = str(kernels).replace(', ', '_').replace('(', '').replace(')', '')
51+
output_dir = str(
52+
test_root_path
53+
/ 'hls4mlprj_{}_{}_strides_{}_kernels_{}_{}_padding_{}_{}'.format(
54+
conv1d.__name__.lower(), chans, stride_cfg, kernel_cfg, padds, backend, io_type
55+
)
56+
)
57+
hls_model = hls4ml.converters.convert_from_keras_model(
58+
model, hls_config=config, output_dir=output_dir, io_type=io_type, backend=backend
59+
)
60+
hls_model.compile()
61+
hls_prediction = hls_model.predict(X_input).reshape(keras_prediction.shape)
62+
63+
np.testing.assert_allclose(hls_prediction, keras_prediction, rtol=0, atol=0.001)

test/pytest/test_trace.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,16 @@
1212

1313

1414
@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus'])
15-
def test_trace(backend):
15+
@pytest.mark.parametrize('activation', ['relu', None])
16+
def test_trace(backend, activation):
1617
'''Test the tracing feature with a simple Keras model.'''
1718
model = tf.keras.models.Sequential()
1819
model.add(
1920
Dense(
2021
2,
2122
input_shape=(1,),
2223
name='Dense',
24+
activation=activation,
2325
use_bias=True,
2426
kernel_initializer=tf.keras.initializers.RandomUniform(minval=1, maxval=10),
2527
bias_initializer='zeros',
@@ -48,6 +50,7 @@ def test_trace(backend):
4850
hls_model.compile()
4951
hls4ml_pred, hls4ml_trace = hls_model.trace(X_input)
5052
keras_trace = hls4ml.model.profiling.get_ymodel_keras(model, X_input)
51-
52-
np.testing.assert_allclose(hls4ml_trace['Dense'], keras_trace['Dense'], rtol=1e-2, atol=0.01)
53+
assert keras_trace.keys() == hls4ml_trace.keys()
54+
for key in hls4ml_trace.keys():
55+
np.testing.assert_allclose(hls4ml_trace[key], keras_trace[key], rtol=1e-2, atol=0.01)
5356
np.testing.assert_allclose(hls4ml_pred, keras_prediction, rtol=1e-2, atol=0.01)

0 commit comments

Comments
 (0)