Skip to content

Commit 22a1054

Browse files
authored
Merge pull request fastmachinelearning#909 from calad0i/fp_write_length_fix
Fix writer precision when fp bits >= 14
2 parents b4693de + 7a25ede commit 22a1054

File tree

2 files changed

+36
-8
lines changed

2 files changed

+36
-8
lines changed

hls4ml/model/types.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -559,14 +559,9 @@ def update_precision(self, new_precision):
559559
if isinstance(new_precision, (IntegerPrecisionType, XnorPrecisionType, ExponentPrecisionType)):
560560
self.precision_fmt = '{:.0f}'
561561
elif isinstance(new_precision, FixedPrecisionType):
562-
if new_precision.fractional > 0:
563-
# Use str to represent the float with digits, get the length
564-
# to right of decimal point
565-
lsb = 2**-new_precision.fractional
566-
decimal_spaces = len(str(lsb).split('.')[1])
567-
self.precision_fmt = f'{{:.{decimal_spaces}f}}'
568-
else:
569-
self.precision_fmt = '{:.0f}'
562+
decimal_spaces = max(0, new_precision.fractional)
563+
self.precision_fmt = f'{{:.{decimal_spaces}f}}'
564+
570565
else:
571566
raise RuntimeError(f"Unexpected new precision type: {new_precision}")
572567

test/pytest/test_weight_writer.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from glob import glob
2+
from pathlib import Path
3+
4+
import keras
5+
import numpy as np
6+
import pytest
7+
8+
import hls4ml
9+
10+
test_root_path = Path(__file__).parent
11+
test_root_path = Path('/tmp/trash')
12+
13+
14+
@pytest.mark.parametrize('k', [0, 1])
15+
@pytest.mark.parametrize('i', [4, 8, 10])
16+
@pytest.mark.parametrize('f', [-2, 0, 2, 7, 14])
17+
def test_weight_writer(k, i, f):
18+
k, b, i = k, k + i + f, k + i
19+
w = np.array([[np.float32(2.0**-f)]])
20+
u = '' if k else 'u'
21+
dtype = f'{u}fixed<{b}, {i}>'
22+
hls_config = {'LayerName': {'dense': {'Precision': {'weight': dtype}}}}
23+
model = keras.Sequential([keras.layers.Dense(1, input_shape=(1,), name='dense')])
24+
model.layers[0].kernel.assign(keras.backend.constant(w))
25+
output_dir = str(test_root_path / f'hls4ml_prj_test_weight_writer_{dtype}')
26+
model_hls = hls4ml.converters.convert_from_keras_model(model, hls_config=hls_config, output_dir=output_dir)
27+
model_hls.write()
28+
w_paths = glob(str(Path(output_dir) / 'firmware/weights/w*.txt'))
29+
print(w_paths[0])
30+
assert len(w_paths) == 1
31+
w_loaded = np.loadtxt(w_paths[0], delimiter=',').reshape(1, 1)
32+
print(f'{w[0,0]:.14}', f'{w_loaded[0,0]:.14}')
33+
assert np.all(w == w_loaded)

0 commit comments

Comments
 (0)