Skip to content

Commit 752850f

Browse files
authored
Merge pull request #229 from iksnagreb/fix/layout-inference
Try to propagate input data layout to outputs for FINN ops
2 parents 6044098 + 2c116e0 commit 752850f

File tree

1 file changed

+27
-4
lines changed

1 file changed

+27
-4
lines changed

src/qonnx/transformation/infer_data_layouts.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,33 @@ def _dims_to_layout(model, node, ndims):
5353
return DataLayout.NC
5454
else:
5555
return DataLayout.UNKNOWN
56+
else:
57+
# Also try to propagate input data layout for "FINN ops" (if number of dims matches)
58+
input_ndims = len(model.get_tensor_shape(node.input[0]))
59+
if input_ndims == ndims and (layout := model.get_tensor_layout(node.input[0])):
60+
# TODO: There are multi-input operations, why should the first
61+
# determine the output layout?
62+
return layout
63+
else:
64+
# Fallback: guess based on number of output dims
65+
if ndims == 4:
66+
return DataLayout.NHWC
67+
elif ndims == 3:
68+
return DataLayout.NWC
69+
elif ndims == 2:
70+
return DataLayout.NC
71+
else:
72+
return DataLayout.UNKNOWN
73+
else:
74+
# Check whether there is a layout annotation for the first input
75+
# TODO: There are multi-input operations, why should the first
76+
# determine the output layout?
77+
# TODO: Shouldn't we at least check that the number of dims matches?
78+
if layout := model.get_tensor_layout(node.input[0]):
79+
# If annotation present: propagate input layout to output
80+
# TODO: this won't work for concat, squeeze/unsqueeze/reshape...
81+
return layout
82+
# Fallback to the same defaults as for the FINN-Ops above
5683
else:
5784
if ndims == 4:
5885
return DataLayout.NHWC
@@ -62,10 +89,6 @@ def _dims_to_layout(model, node, ndims):
6289
return DataLayout.NC
6390
else:
6491
return DataLayout.UNKNOWN
65-
else:
66-
# propagate input layout to output
67-
# TODO this won't work for concat, squeeze/unsqueeze/reshape...
68-
return model.get_tensor_layout(node.input[0])
6992

7093

7194
def _infer_node_data_layout(model, node):

0 commit comments

Comments
 (0)