Skip to content

Commit 5ade441

Browse files
committed
Remove final reshapes even for Quartus
1 parent 9ca7af2 commit 5ade441

File tree

4 files changed

+23
-19
lines changed

4 files changed

+23
-19
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import numpy as np
2+
3+
from hls4ml.model.optimizer import OptimizerPass
4+
from hls4ml.model.layers import Reshape
5+
6+
7+
class RemoveFinalReshape(OptimizerPass):
8+
''' Remove reshape if final layer '''
9+
def match(self, node):
10+
# match if reshape is final node
11+
return isinstance(node, Reshape) and not node.get_output_nodes()
12+
13+
def transform(self, model, node):
14+
if model.config.get_config_value('IOType') == 'io_parallel':
15+
print('WARNING: Final layer is a Reshape, which does not affect the output for io_parallel; removing it')
16+
# remove, but don't rewire because it's the output layer
17+
model.remove_node(node, rewire=False)
18+
return True
19+
elif model.config.get_config_value('IOType') == 'io_stream':
20+
print('WARNING: Final layer is a Reshape, which may incur a large resource cost for io_stream; consider removing it')
21+
return False

hls4ml/backends/quartus/quartus_backend.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def _register_flows(self):
6060
quantization_flow = register_flow('quantization', quantization_passes, requires=[init_flow], backend=self.name)
6161

6262
optimization_passes = [
63+
'quartus:remove_final_reshape',
6364
'quartus:optimize_pointwise_conv',
6465
]
6566
optimization_flow = register_flow('optimize', optimization_passes, requires=[init_flow], backend=self.name)

hls4ml/backends/vivado/passes/repack_stream.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from hls4ml.model.optimizer import OptimizerPass
44
from hls4ml.model.layers import Layer, Merge, Reshape, Concatenate, register_layer
5-
from hls4ml.backends import get_backend
65
from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
76

87
class Repack(Layer):
@@ -82,7 +81,6 @@ def register_repack_stream(backend):
8281
register_layer('Broadcast', Broadcast)
8382

8483
# Register the optimization passes
85-
backend.register_pass('remove_final_reshape', RemoveFinalReshape)
8684
backend.register_pass('reshape_stream', ReshapeStream)
8785
backend.register_pass('broadcast_stream', BroadcastStream)
8886

@@ -162,19 +160,3 @@ def supported_broadcast(inp_shape, target_shape):
162160
node.inputs[idx] = brdcst_out
163161

164162
return True
165-
166-
class RemoveFinalReshape(OptimizerPass):
167-
''' Remove reshape if final layer '''
168-
def match(self, node):
169-
# match if reshape is final node
170-
return isinstance(node, Reshape) and not node.get_output_nodes()
171-
172-
def transform(self, model, node):
173-
if model.config.get_config_value('IOType') == 'io_parallel':
174-
print('WARNING: Final layer is a Reshape, which does not affect the output for io_parallel; removing it')
175-
# remove, but don't rewire because it's the output layer
176-
model.remove_node(node, rewire=False)
177-
return True
178-
elif model.config.get_config_value('IOType') == 'io_stream':
179-
print('WARNING: Final layer is a Reshape, which may incur a large resource cost for io_stream; consider removing it')
180-
return False

hls4ml/backends/vivado/vivado_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ def _register_flows(self):
3434
init_flow = register_flow('init_layers', initializers, requires=['optimize'], backend=self.name)
3535

3636
streaming_passes = [
37-
'vivado:remove_final_reshape',
3837
'vivado:reshape_stream',
3938
'vivado:clone_output',
4039
'vivado:insert_zero_padding_before_conv1d',
@@ -51,6 +50,7 @@ def _register_flows(self):
5150
quantization_flow = register_flow('quantization', quantization_passes, requires=[init_flow], backend=self.name)
5251

5352
optimization_passes = [
53+
'vivado:remove_final_reshape',
5454
'vivado:optimize_pointwise_conv',
5555
]
5656
optimization_flow = register_flow('optimize', optimization_passes, requires=[init_flow], backend=self.name)

0 commit comments

Comments
 (0)