Skip to content

Commit fe9d3e7

Browse files
authored
skip BatchNorm fusion when input/output is used multiple times (#481)
* skip bn fusion when input/output is used multiple times * format
1 parent 13b3a0a commit fe9d3e7

File tree

4 files changed

+30
-10
lines changed

4 files changed

+30
-10
lines changed

hls4ml/backends/fpga/passes/clone.py

Lines changed: 1 addition & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -55,13 +55,7 @@ def transform(self, model, node):
5555
if model.config.get_config_value('IOType') != 'io_stream':
5656
return False
5757

58-
output_map = {}
59-
for output in node.outputs:
60-
output_map[output] = []
61-
for layer in model.get_layers():
62-
for inp in layer.inputs:
63-
if output == inp:
64-
output_map[output].append(layer)
58+
output_map = node.get_output_use_map()
6559

6660
transformed = False
6761
for output in node.outputs:

hls4ml/model/layers.py

Lines changed: 21 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -131,10 +131,28 @@ def get_input_variable(self, input_name=None):
131131
else:
132132
return self.model.get_layer_output_variable(self.inputs[0])
133133

134+
def get_output_use_map(self):
135+
output_map = {}
136+
for output in self.outputs:
137+
output_map[output] = []
138+
for layer in self.model.get_layers():
139+
for inp in layer.inputs:
140+
if output == inp:
141+
output_map[output].append(layer)
142+
return output_map
143+
134144
def get_output_nodes(self, output_name=None):
135-
if output_name is None:
136-
output_name = self.outputs[0]
137-
return [node for node in self.model.graph.values() if node.inputs[0] == output_name]
145+
output_nodes = []
146+
if output_name is not None:
147+
outputs = [output_name]
148+
else:
149+
outputs = self.outputs
150+
for output in outputs:
151+
for layer in self.model.get_layers():
152+
for inp in layer.inputs:
153+
if output == inp:
154+
output_nodes.append(layer)
155+
return output_nodes
138156

139157
def get_output_variable(self, output_name=None):
140158
if output_name is not None:

hls4ml/model/optimizer/passes/bn_fuse.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,10 @@ def match(self, node):
1212
def transform(self, model, node):
1313
# Fuse weight and bias of Dense/Conv1D/Conv2D layer with BN values
1414
parent_node = node.get_input_node()
15+
parent_map = parent_node.get_output_use_map()
16+
node_map = node.get_output_use_map()
17+
if len(parent_map[parent_node.name]) > 1 or len(node_map[node.name]) > 1:
18+
return False
1519

1620
parent_weight = parent_node.weights['weight']
1721
parent_bias = parent_node.weights['bias']

hls4ml/model/optimizer/passes/qkeras.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -226,6 +226,10 @@ def match(self, node):
226226
def transform(self, model, node):
227227
bn0 = node.get_input_node()
228228
bn1 = node
229+
bn0_map = bn0.get_output_use_map()
230+
bn1_map = bn1.get_output_use_map()
231+
if len(bn0_map[bn0.name]) > 1 or len(bn1_map[bn1.name]) > 1:
232+
return False
229233

230234
s0 = bn0.weights['scale'].data
231235
b0 = bn0.weights['bias'].data

0 commit comments

Comments
 (0)