
Commit dce3dbf

nghielme, JanFSchulte, and jmitrevs authored and committed
Automatic type inference for param_t in Parametrised Activations (fastmachinelearning#1139)
* Added automatic inference of `param_t` constant for parametrised activations
* pre-commit fixes
* Fix the case the param is a power of 2
* Fix for a specific case related to no bits in the mantissa
* Update subproject commit reference in example-models
* first, untested version of constant precision
* try using Fxp for precision setting
* fix bug in max attribute of unsigned FixedPrecisionType
* add unit test for precision from constant
* integrate suggested test_precision_from_constant_unit change

Co-authored-by: Jan-Frederik Schulte <jschulte@cern.ch>
Co-authored-by: Jovan Mitrevski <jmitrevs@fnal.gov>
Co-authored-by: Jovan Mitrevski <j.p.mitrevski@gmail.com>
1 parent 2a2fdf1 commit dce3dbf

File tree: 4 files changed (+73, −5 lines)

hls4ml/model/optimizer/passes/infer_precision.py

Lines changed: 42 additions & 3 deletions
@@ -2,6 +2,7 @@
 from collections.abc import Iterable

 import numpy as np
+from fxpmath import Fxp

 from hls4ml.model.optimizer import ConfigurableOptimizerPass
 from hls4ml.model.types import (
@@ -573,9 +574,17 @@ def _infer_par_act_precision(self, node, types_to_infer):
         # For threshold relu, set the parameter precision to be the input precision by default;
         # for other parametrized activations, just allow the default precision to be used.
         # Can override these values in the configuration by explicitly setting them.
-        if 'param_t' in types_to_infer and node.get_attr('activation').lower() == 'thresholdedrelu':
-            in_type = node.get_input_variable().type.precision
-            node.attributes['param_t'].precision = in_type
+        if 'param_t' in types_to_infer:
+            if node.get_attr('activation').lower() == 'thresholdedrelu':
+                # For threshold relu, set the parameter precision to be the input precision by default;
+                in_type = node.get_input_variable().type.precision
+                node.attributes['param_t'].precision = in_type
+                inferred_types.append('param_t')
+            else:
+                # find a constant to represent the values
+                param = node.get_attr('activ_param')
+                precision = _get_precision_from_constant(param)
+                node.attributes['param_t'].precision = precision
             inferred_types.append('param_t')

         return inferred_types
@@ -594,3 +603,33 @@ def _infer_prelu_act_precision(self, node, types_to_infer):
         inferred_types.append('param_t')

         return inferred_types
+
+
+def _get_precision_from_constant(value: int | float, max_width=8):
+    """A utility function to find a fixed type to store the constant
+
+    Arguments:
+        value (int or float): the constant value
+        max_width (int, optional): the maximum fixed width (+ 1 if signed). Defaults to 8
+
+    Returns:
+        FixedPrecisionType: the type to use
+    """
+    if value == 0:
+        return FixedPrecisionType(width=1, integer=1, signed=False)
+
+    signed = value < 0
+    absval = abs(value)
+    # check if power of 2
+    mantissa, exp = np.frexp(absval)
+    if mantissa == 0.5:  # is it a power of 2?
+        # One could consider returning an ExponentPrecisionType here.
+        # Decided on FixedPrecisionType everywhere since ExponentPrecisionType is less supported
+        return FixedPrecisionType(1 + signed, exp, signed)
+
+    # now is the general case. First try Fxp
+    fxpval = Fxp(value, signed=signed)
+    if isinstance(fxpval.n_word, int) and fxpval.n_word <= max_width:
+        return FixedPrecisionType(fxpval.n_word, signed + fxpval.n_int, signed)
+
+    return FixedPrecisionType(signed + max_width, signed + exp, signed)
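For orientation, a minimal sketch (editor's addition, not part of the commit) of what the new helper returns for a few of the constants exercised by the unit test added below; it assumes this branch of hls4ml and fxpmath are importable:

    from hls4ml.model.optimizer.passes.infer_precision import _get_precision_from_constant

    for val in (0.03125, 1.25, -1.25, 1.1):
        fp = _get_precision_from_constant(val, max_width=8)
        # 0.03125 is a power of two, so a single bit at weight 2**-5 suffices (width 1);
        # 1.25 is 1.01 in binary, so 3 bits; -1.25 adds a sign bit (width 4);
        # 1.1 has no exact short fixed-point form, so the width saturates at max_width.
        print(val, fp.width, fp.integer, fp.signed)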

hls4ml/model/types.py

Lines changed: 1 addition & 1 deletion
@@ -270,7 +270,7 @@ def min(self):

     @property
     def max(self):
-        return 2.0 ** (self.integer - 1) - 2.0**-self.fractional
+        return 2.0 ** (self.integer - self.signed) - 2.0**-self.fractional


 class XnorPrecisionType(PrecisionType):
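The one-line fix above matters only for unsigned types, where self.signed is False (0): the old formula reused the signed-type expression and halved the representable maximum. A small illustration with an assumed unsigned 8-bit type with 4 integer bits (values are the editor's, not from the commit):

    from hls4ml.model.types import FixedPrecisionType

    u = FixedPrecisionType(width=8, integer=4, signed=False)
    # with the fix: 2.0 ** (4 - 0) - 2.0 ** -4 == 15.9375
    # previously:   2.0 ** (4 - 1) - 2.0 ** -4 ==  7.9375 (correct only for signed types)
    print(u.max)

    s = FixedPrecisionType(width=8, integer=4, signed=True)
    print(s.max)  # 7.9375, unchanged: self.signed is 1, so the exponent is still integer - 1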

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ classifiers = [
   "Topic :: Software Development :: Libraries :: Python Modules",
 ]
 dynamic = [ "version" ]
-dependencies = [ "h5py", "numpy", "pydigitalwavetools==1.1", "pyyaml", "quantizers" ]
+dependencies = [ "fxpmath", "h5py", "numpy", "pydigitalwavetools==1.1", "pyyaml", "quantizers" ]

 optional-dependencies.da = [ "da4ml>=0.2.1,<=0.4" ]
 optional-dependencies.doc = [
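The new fxpmath dependency is used only by the inference pass above, to size the general (non power-of-two) case. A minimal sketch of the Fxp auto-sizing behaviour the pass relies on; the example value is the editor's, and the attribute names are the ones the pass reads:

    from fxpmath import Fxp

    x = Fxp(1.25, signed=False)
    # fxpmath picks the smallest layout that stores the value exactly:
    # 1.25 == 0b1.01, i.e. 1 integer bit and 2 fractional bits
    print(x.n_word, x.n_int, x.n_frac)  # expected: 3 1 2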

test/pytest/test_auto_precision.py

Lines changed: 29 additions & 0 deletions
@@ -17,6 +17,7 @@
 from tensorflow.keras.models import Sequential

 import hls4ml
+from hls4ml.model.optimizer.passes.infer_precision import _get_precision_from_constant

 test_root_path = Path(__file__).parent

@@ -254,3 +255,31 @@ def test_auto_precision_dense(keras_model_dense, data_1d, io_type, backend):
     y_keras = model.predict(data).flatten()
     y_hls = hls_model.predict(data).flatten()
     np.testing.assert_allclose(y_keras, y_hls, rtol=2e-2, atol=5e-2, verbose=True)
+
+
+@pytest.mark.parametrize(
+    "val, expected_width",
+    [
+        (0, 1),
+        (-1024, 2),
+        (1024, 1),
+        (0.03125, 1),
+        (-0.03125, 2),
+        (1.25, 3),
+        (-1.25, 4),
+        (1.1, 8),
+        (-1.1, 9),
+    ],
+)
+def test_precision_from_constant_unit(val, expected_width):
+    """Test determining precision needed for a constant."""
+    max_width = 8
+    fp = _get_precision_from_constant(val, max_width)
+
+    assert fp.min <= val <= fp.max
+    assert fp.width == expected_width
+    assert fp.signed == (val < 0)
+
+    quantum = 2.0**-fp.fractional
+    if expected_width < max_width:
+        assert val % quantum == 0
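As a quick sanity check on one row (the editor's walkthrough, not part of the commit): -1.25 is 1.01 in binary with a sign, so the helper should allocate a sign bit, one integer bit and two fractional bits, giving width 4 and a quantum of 0.25 that divides -1.25 exactly; the (-1.1, 9) row instead exercises the fallback path, where the width saturates at max_width plus a sign bit. The standalone check below assumes the same import as the test:

    from hls4ml.model.optimizer.passes.infer_precision import _get_precision_from_constant

    fp = _get_precision_from_constant(-1.25, max_width=8)
    print(fp.width, fp.integer, fp.signed)  # expected: 4 2 True

The whole parametrised test can be run on its own with pytest -k test_precision_from_constant_unit.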
