Skip to content

Commit 3fa2902

Browse files
authored
Merge pull request fastmachinelearning#1161 from JanFSchulte/transposefix
Bug fixes for channel-last conversions in pytorch
2 parents 4b7e12d + 31219e3 commit 3fa2902

File tree

3 files changed

+8
-3
lines changed

3 files changed

+8
-3
lines changed

hls4ml/model/optimizer/passes/convert_to_channels_last.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,12 +97,17 @@ def transform(self, model, node):
9797
if (
9898
isinstance(node, Reshape)
9999
and len(node.attributes['target_shape']) == 1
100-
and not model.config.config['HLSConfig']['Model']['ChannelsLastConversion'] == "internal"
100+
and not model.config.config['HLSConfig']['Model']['ChannelsLastConversion'] == "off"
101101
):
102102
previous_node = node.get_input_node(node.inputs[0])
103103
input = previous_node.name
104104
outshape = previous_node.get_output_variable().shape
105105

106+
if (model.config.config['IOType'] == 'io_stream') and len(outshape) == 3:
107+
raise Exception(
108+
'No 3D transpose available in io_stream, this model cannot be converted to channels-last'
109+
)
110+
106111
if len(outshape) == 2:
107112
attributes = {'perm': [1, 0]}
108113
else:

hls4ml/utils/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ def config_from_pytorch_model(
283283
default_precision='ap_fixed<16,6>',
284284
default_reuse_factor=1,
285285
channels_last_conversion='full',
286-
transpose_outputs=True,
286+
transpose_outputs=False,
287287
max_precision=None,
288288
):
289289
"""Create an HLS conversion config given the PyTorch model.

test/pytest/test_pytorch_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,7 @@ def test_pooling(pooling, padds, backend):
498498
model.eval()
499499
pytorch_prediction = model(torch.Tensor(X_input)).detach().numpy()
500500

501-
config = config_from_pytorch_model(model, input_shape_forHLS)
501+
config = config_from_pytorch_model(model, input_shape_forHLS, transpose_outputs=True)
502502
output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_pooling_{pooling.__name__}_padds_{padds}_backend_{backend}')
503503
hls_model = convert_from_pytorch_model(model, hls_config=config, output_dir=output_dir, backend=backend)
504504
hls_model.compile()

0 commit comments

Comments
 (0)