Skip to content

Commit ef2e8f4

Browse files
committed
allow io_stream if used as model output
1 parent 7b58c1d commit ef2e8f4

File tree

4 files changed

+21
-6
lines changed

4 files changed

+21
-6
lines changed

hls4ml/backends/fpga/passes/inplace_stream_flatten.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,18 @@ class InplaceStreamFlatten(OptimizerPass):
1111
"""
1212

1313
def match(self, node):
14-
# Reshape acts as a Flatten layer when the result has 1 dimension
14+
# Layers requiring flattened data can gather it from the stream; no need for repacking.
15+
# Reshape acts as a Flatten layer when the result has 1 dimension. Make it an inplace tensor if this happens.
16+
17+
if node.model.config.get_config_value('IOType') != 'io_stream':
18+
return False
1519
if not (isinstance(node, Reshape) and len(node.get_output_variable().shape) == 1):
16-
# Reshape with multiple outputs will be kept as is, or repack cannot handle different shapes
20+
# If it is not a flatten layer
21+
return False
22+
if node.name in node.model.outputs:
23+
# If used as a model output, the output shape shall be preserved in this case.
1724
return False
18-
return node.model.config.get_config_value('IOType') == 'io_stream'
25+
return True
1926

2027
def transform(self, model, node):
2128
outvar = node.get_output_variable()

hls4ml/backends/fpga/passes/repack_stream.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@ class ReshapeStream(OptimizerPass):
4949

5050
def match(self, node):
5151
# do not run optimizer pass for a flatten layer (1 output dimension)
52-
return isinstance(node, Reshape) and len(node.get_output_variable().shape) > 1
52+
if not isinstance(node, Reshape):
53+
return False
54+
return len(node.get_output_variable().shape) > 1 or node.name in node.model.outputs
5355

5456
def transform(self, model, node):
5557
if model.config.get_config_value('IOType') != 'io_stream':

hls4ml/model/graph.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,8 @@ def remove_node(self, node, rewire=True):
545545
if len(inputs) == 1:
546546
# Connect inputs -> $outputs
547547
if node.name in self.outputs:
548+
msg = f'Removing leaf node {node.name} will connect its input node {inputs[0]} to the output, but it already is an output.'
549+
assert inputs[0] not in self.outputs, msg
548550
self.outputs = [inputs[0] if name == node.name else name for name in self.outputs]
549551

550552
if len(outputs) == 1 and len(inputs) == 1:

test/pytest/test_multiout_network.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ def model_corner_cases():
2828
out2 = keras.layers.Dense(16, activation='relu')(out1)
2929
out2 = keras.layers.Add()([out2, in2])
3030
out3 = keras.layers.Dense(2)(out1)
31-
model = keras.models.Model(inputs=[in1, in2], outputs=[out1, out2, out3])
31+
out4 = keras.layers.Dense(2)(out2)
32+
out4 = keras.layers.Flatten()(out4)
33+
model = keras.models.Model(inputs=[in1, in2], outputs=[out1, out2, out3, out4])
3234
return model
3335

3436

@@ -76,7 +78,8 @@ def test_multi_output_nn_corner_cases(model_corner_cases, data_corner_cases, bac
7678
- when a node removal/insertion is triggered internally
7779
- a reshape in io_parallel, or flatten in io_stream layer's output is used multiple times
7880
- and as layer output
79-
- and by layer taking multiple inputs
81+
- and by layer taking multiple inputs
82+
- a Flatten layer outputs to the model output in io_stream
8083
"""
8184
output_dir = str(test_root_path / f'hls4mlprj_multiout_network_2_{backend}_{io_type}_{strategy}')
8285
hls_config = {'Model': {'Precision': 'fixed<32,5>', 'ReuseFactor': 1}, 'Strategy': strategy}
@@ -92,3 +95,4 @@ def test_multi_output_nn_corner_cases(model_corner_cases, data_corner_cases, bac
9295
assert np.allclose(r_hls[0], r_keras[0], atol=1e-5, rtol=0)
9396
assert np.allclose(r_hls[1], r_keras[1], atol=1e-5, rtol=0)
9497
assert np.allclose(r_hls[2], r_keras[2], atol=1e-5, rtol=0)
98+
assert np.allclose(r_hls[3], r_keras[3], atol=1e-5, rtol=0)

0 commit comments

Comments
 (0)