Skip to content

Commit 0d1bc8b

Browse files
authored
Fix keras model loading issue with loading model with KerasH5 (#664)
* Fix keras model loading issue with loading model with KerasH5 * Add unit test for KerasH5 loader * comment on h5py version causing KerasH5 loading issue
1 parent d54843d commit 0d1bc8b

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

hls4ml/converters/keras_to_hls.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,10 @@ def keras_to_hls(config):
232232
if model_arch is None:
233233
raise ValueError('No model found in config file.')
234234
else:
235-
model_arch = json.loads(model_arch.decode('utf-8'))
235+
# model_arch is string by default since h5py 3.0.0, keeping this condition for compatibility.
236+
if isinstance(model_arch, bytes):
237+
model_arch = model_arch.decode('utf-8')
238+
model_arch = json.loads(model_arch)
236239
reader = KerasFileReader(config)
237240
else:
238241
raise ValueError('No model found in config file.')
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import pytest
2+
import hls4ml
3+
import tensorflow as tf
4+
import numpy as np
5+
from pathlib import Path
6+
7+
8+
test_root_path = Path(__file__).parent
9+
10+
test_root_path = Path('/tmp')
11+
12+
13+
@pytest.mark.parametrize('backend', ['Vivado', 'Quartus'])
14+
def test_keras_h5_loader(backend):
15+
input_shape = (10,)
16+
model = tf.keras.models.Sequential([
17+
tf.keras.layers.InputLayer(input_shape=input_shape),
18+
tf.keras.layers.Activation(activation='relu'),
19+
])
20+
21+
hls_config = hls4ml.utils.config_from_keras_model(model, granularity='name')
22+
23+
config = {'OutputDir': 'KerasH5_loader_test',
24+
'ProjectName': 'KerasH5_loader_test',
25+
'Backend': backend,
26+
'ClockPeriod': 25.0,
27+
'IOType': 'io_parallel',
28+
'HLSConfig': hls_config,
29+
'KerasH5': str(test_root_path / 'KerasH5_loader_test.h5'),
30+
'output_dir': str(test_root_path / 'KerasH5_loader_test')}
31+
32+
model.save(config['KerasH5'])
33+
hls_model = hls4ml.converters.keras_to_hls(config)
34+
hls_model.compile()
35+
data = np.random.rand(1000, 10).astype(np.float32)
36+
pred = hls_model.predict(data)
37+
np.testing.assert_allclose(pred, model.predict(data), rtol=5e-3, atol=5e-3)

0 commit comments

Comments
 (0)