Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 55 additions & 37 deletions hls4ml/model/optimizer/passes/infer_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,54 @@ def _all_supported_types(self, types: Iterable[PrecisionType]):
return False
return True

def _apply_max_precision_constraints(self, node, precision):
"""
Clamps the precision to the node's max_precision constraints.

Logic:
1. Width/Integer: Always constrained to the minimum of inferred vs max.
2. Rounding/Saturation: Inherited from max_precision ONLY if they differ from the defaults
(meaning the user likely set them explicitly).
3. Signedness: max_precision signed arg is always preferred.
"""
max_precision = self._get_maximum_precision(node)

if max_precision is None:
return precision

new_width = min(precision.width, max_precision.width)
new_integer = min(precision.integer, max_precision.integer)

# Default modes defined in FixedPrecisionType
default_type = FixedPrecisionType()
DEFAULT_RND = default_type.rounding_mode
DEFAULT_SAT = default_type.saturation_mode
DEFAULT_SAT_BITS = default_type.saturation_bits

if max_precision.rounding_mode != DEFAULT_RND:
new_rounding_mode = max_precision.rounding_mode
else:
new_rounding_mode = precision.rounding_mode

if max_precision.saturation_mode != DEFAULT_SAT:
new_saturation_mode = max_precision.saturation_mode
else:
new_saturation_mode = precision.saturation_mode

if max_precision.saturation_bits != DEFAULT_SAT_BITS:
new_saturation_bits = max_precision.saturation_bits
else:
new_saturation_bits = precision.saturation_bits

return FixedPrecisionType(
width=new_width,
integer=new_integer,
signed=max_precision.signed,
rounding_mode=new_rounding_mode,
saturation_mode=new_saturation_mode,
saturation_bits=new_saturation_bits,
)

def _infer_default_type(self, node, type_name):
model_config = node.model.config
default_precision = model_config.backend.convert_precision_string(model_config.model_precision['default'])
Expand Down Expand Up @@ -180,14 +228,8 @@ def _infer_common_precision(self, node, types_to_infer, n_ops):
bitwidth = integers + max(frac, bias_width - bias_integers)
signed = signed or bias_signed

# if max_precision is specified, limit the size to be less than max precisoin
max_precision = self._get_maximum_precision(node)
if max_precision is not None:
bitwidth = min(bitwidth, max_precision.width)
integers = min(integers, max_precision.integer)

# Note: this is guaranteed to not overflow or need rounding, so it's sufficient to use the simpler form.
new_type = FixedPrecisionType(bitwidth, integers, signed)
out_precision = FixedPrecisionType(bitwidth, integers, signed)
new_type = self._apply_max_precision_constraints(node, out_precision)
else:
new_type = self._get_default_precision(node)

Expand Down Expand Up @@ -334,15 +376,8 @@ def _infer_bn_precision(self, node, types_to_infer):
out_precision_width = out_precision_integer + max(
after_scale_width - after_scale_integer, bias_precision.fractional
)

# if max_precision is specified, limit the size to be less than max precisoin
max_precision = self._get_maximum_precision(node)
if max_precision is not None:
out_precision_width = min(out_precision_width, max_precision.width)
out_precision_integer = min(out_precision_integer, max_precision.integer)

# Note: this is guaranteed to not overflow or need rounding, so it's sufficient to use the simpler form.
out_precision = FixedPrecisionType(out_precision_width, out_precision_integer, out_precision_signed)
out_precision = self._apply_max_precision_constraints(node, out_precision)

else:
out_precision = self._get_default_precision(node)
Expand Down Expand Up @@ -413,24 +448,17 @@ def _infer_merge_precision(self, node, types_to_infer):
+ 1
)
new_width = new_int + max(input_1.fractional, input_2.fractional)
max_precision = self._get_maximum_precision(node)
if max_precision is not None:
new_width = min(new_width, max_precision.width)
new_int = min(new_int, max_precision.integer)
out_precision = FixedPrecisionType(new_width, new_int, new_signed)
out_precision = self._apply_max_precision_constraints(node, out_precision)
else:
out_precision = self._get_default_precision(node)
elif op == 'multiply':
if self._all_supported_types((input_1, input_2)):
new_signed = input_1.signed or input_2.signed
new_int = input_1.integer + input_2.integer
new_width = input_1.width + input_2.width
# if max_precision is specified, limit the size to be less than max precisoin
max_precision = self._get_maximum_precision(node)
if max_precision is not None:
new_width = min(new_width, max_precision.width)
new_int = min(new_int, max_precision.integer)
out_precision = FixedPrecisionType(new_width, new_int, new_signed)
out_precision = self._apply_max_precision_constraints(node, out_precision)
else:
out_precision = self._get_default_precision(node)
elif op in ('maximum', 'minimum'):
Expand Down Expand Up @@ -487,18 +515,13 @@ def _infer_cat_precision(self, node, types_to_infer):
new_width = max(input_1.fractional, input_2.fractional) + max(input_1_integer, input_2_integer)
new_int = max(input_1_integer, input_2_integer)

# if max_precision is specified, limit the size to be less than max precisoin
max_precision = self._get_maximum_precision(node)
if max_precision is not None:
new_width = min(new_width, max_precision.width)
new_int = min(new_int, max_precision.integer)

# some logic copied from former SetPrecisionConcat optimizer
newrmode = input_1.rounding_mode if input_1.rounding_mode != RoundingMode.TRN else input_2.rounding_mode
newsmode = input_1.saturation_mode if input_1.saturation_mode != SaturationMode.WRAP else input_2.saturation_mode
newsbits = input_1.saturation_bits if input_1.saturation_bits != 0 else input_2.saturation_bits

out_precision = FixedPrecisionType(new_width, new_int, new_signed, newrmode, newsmode, newsbits)
out_precision = self._apply_max_precision_constraints(node, out_precision)
else:
out_precision = self._get_default_precision(node)

Expand All @@ -520,13 +543,8 @@ def _infer_dot_precision(self, node, types_to_infer):
new_width = input_1.width + input_2.width + math.ceil(np.log2(n_in))
new_int = input_1.integer + input_2.integer + math.ceil(np.log2(n_in))

# if max_precision is specified, limit the size to be less than max precisoin
max_precision = self._get_maximum_precision(node)
if max_precision is not None:
new_width = min(new_width, max_precision.width)
new_int = min(new_int, max_precision.integer)

out_precision = FixedPrecisionType(new_width, new_int, new_signed)
out_precision = self._apply_max_precision_constraints(node, out_precision)
else:
out_precision = self._get_default_precision(node)
node.types['result_t'].name = node.name + '_result_t'
Expand Down
Loading