Skip to content

Commit 62d5e03

Browse files
qberthetQuentin Berthet
andauthored
Fix profiling SeparableConv1D and SeparableConv2D (fastmachinelearning#891)
* Profiling: Fix suffixes for SeparableConv1D&2D * Profiling: transform list to dict where dict is expected --------- Co-authored-by: Quentin Berthet <[email protected]>
1 parent b67e730 commit 62d5e03

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

hls4ml/model/profiling.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import seaborn as sb
1111

1212
from hls4ml.model.graph import ModelGraph
13-
from hls4ml.model.layers import GRU, LSTM
13+
from hls4ml.model.layers import GRU, LSTM, SeparableConv1D, SeparableConv2D
1414

1515
try:
1616
import qkeras
@@ -184,6 +184,8 @@ def types_hlsmodel(model):
184184
for layer in model.get_layers():
185185
if isinstance(layer, GRU) or isinstance(layer, LSTM):
186186
suffix = ['w', 'rw', 'b', 'rb']
187+
elif isinstance(layer, SeparableConv1D) or isinstance(layer, SeparableConv2D):
188+
suffix = ['dw', 'pw', 'db', 'pb']
187189
else:
188190
suffix = ['w', 'b']
189191
for iw, weight in enumerate(layer.get_weights()):
@@ -225,6 +227,8 @@ def weights_hlsmodel(model, fmt='longform', plot='boxplot'):
225227
for layer in model.get_layers():
226228
if isinstance(layer, GRU) or isinstance(layer, LSTM):
227229
suffix = ['w', 'rw', 'b', 'rb']
230+
elif isinstance(layer, SeparableConv1D) or isinstance(layer, SeparableConv2D):
231+
suffix = ['dw', 'pw', 'db', 'pb']
228232
else:
229233
suffix = ['w', 'b']
230234
name = layer.name
@@ -346,6 +350,7 @@ def activations_keras(model, X, fmt='longform', plot='boxplot'):
346350
outputs = _get_outputs(
347351
[layer for layer in model.layers if not isinstance(layer, keras.layers.InputLayer)], X, model.input
348352
)
353+
outputs = dict(zip([layer.name for layer in model.layers if not isinstance(layer, keras.layers.InputLayer)], outputs))
349354
for layer_name, y in outputs.items():
350355
print(f" {layer_name}")
351356
y = y.flatten()

0 commit comments

Comments
 (0)