Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion hls4ml/backends/fpga/passes/clone.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def transform(self, model, node):
)
for i in range(len(output_map[output])):
key = output + '_cpy' + str(i + 1)
clone_layer.attributes[key].type = node.attributes['result_t']
clone_layer.attributes[key].type = node.get_output_variable().type
model.insert_node(clone_layer)
transformed = True

Expand Down
2 changes: 2 additions & 0 deletions hls4ml/converters/onnx_to_hls.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def get_input_shape(graph, node):
"""
rv = []
for inp in node.input:
if inp == '':
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When is this used? When are you asking for the input node of a node with no inputs?

Copy link
Author

@nghielme nghielme Sep 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It happens with Resize node, https://onnx.ai/onnx/operators/onnx__Resize.html
I agree it is not super clean, but I didn't figure out another solution. The handling of the RoI field should also be considered at the QONNX level.

continue
try:
value_info_idx = next((i for i, x in enumerate(graph.value_info) if x.name == inp))
dim = list(d.dim_value for d in graph.value_info[value_info_idx].type.tensor_type.shape.dim)
Expand Down
2 changes: 1 addition & 1 deletion hls4ml/model/optimizer/passes/batchnorm_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def transform(self, model, node):
const_prec = const_node.get_output_variable().type.precision

new_val = (
const_node.attributes['value'] * node.weights['scale'].data_unquantized + node.weights['bias'].data_unquantized
const_node.get_attr('value') * node.weights['scale'].data_unquantized + node.weights['bias'].data_unquantized
)

const_node.set_attr('value', new_val)
Expand Down
30 changes: 28 additions & 2 deletions hls4ml/model/optimizer/passes/linear.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,17 @@
from hls4ml.model.layers import Activation, BatchNormalization, Conv1D, Conv2D, Dense
from hls4ml.model.layers import (
Activation,
BatchNormalization,
Concatenate,
Conv1D,
Conv2D,
Dense,
DepthwiseConv1D,
DepthwiseConv2D,
Input,
Pooling1D,
Pooling2D,
Resize,
)
from hls4ml.model.optimizer import OptimizerPass
from hls4ml.model.types import UnspecifiedPrecisionType

Expand All @@ -15,7 +28,20 @@ def transform(self, model, node):
return True


_safe_parents = (Dense, Conv1D, Conv2D, BatchNormalization, Activation)
_safe_parents = (
Input,
Dense,
Conv1D,
Conv2D,
DepthwiseConv1D,
DepthwiseConv2D,
BatchNormalization,
Activation,
Pooling1D,
Pooling2D,
Resize,
Concatenate,
)


class MergeLinearActivation(OptimizerPass):
Expand Down
4 changes: 2 additions & 2 deletions hls4ml/model/optimizer/passes/quant_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def transform(self, model, node):
integer = bitwidth
scale = node.get_attr('scale')
if _ALSO_MATCH_PO2 and not (scale == np.ones_like(scale)).all():
_, exp = np.frexp(np.squeeze(scale))
_, exp = np.frexp(np.unique(scale.ravel()).item())
integer = bitwidth + exp - 1

precision, quantizer = _calculate_precision_quantizer(bitwidth, integer, signed, narrow, rounding_mode)
Expand Down Expand Up @@ -336,7 +336,7 @@ def transform(self, model, node):

inshape = node.get_input_variable().shape

attributes_rescale = {'n_filt': -1}
attributes_rescale = {'n_filt': -1, 'quantizer': quantizer}

rescale_config = copy.deepcopy(model.config.get_layer_config(node))
rescale_name = f'{node.name}_rescale'
Expand Down