@@ -53,6 +53,33 @@ def _dims_to_layout(model, node, ndims):
5353 return DataLayout .NC
5454 else :
5555 return DataLayout .UNKNOWN
56+ else :
57+ # Also try to propagate input data layout for "FINN ops" (if number of dims matches)
58+ input_ndims = len (model .get_tensor_shape (node .input [0 ]))
59+ if input_ndims == ndims and (layout := model .get_tensor_layout (node .input [0 ])):
60+ # TODO: There are multi-input operations, why should the first
61+ # determine the output layout?
62+ return layout
63+ else :
64+ # Fallback: guess based on number of output dims
65+ if ndims == 4 :
66+ return DataLayout .NHWC
67+ elif ndims == 3 :
68+ return DataLayout .NWC
69+ elif ndims == 2 :
70+ return DataLayout .NC
71+ else :
72+ return DataLayout .UNKNOWN
73+ else :
74+ # Check whether there is a layout annotation for the first input
75+ # TODO: There are multi-input operations, why should the first
76+ # determine the output layout?
77+ # TODO: Shouldn't we at least check that the number of dims matches?
78+ if layout := model .get_tensor_layout (node .input [0 ]):
79+ # If annotation present: propagate input layout to output
80+ # TODO: this won't work for concat, squeeze/unsqueeze/reshape...
81+ return layout
82+ # Fallback to the same defaults as for the FINN-Ops above
5683 else :
5784 if ndims == 4 :
5885 return DataLayout .NHWC
@@ -62,10 +89,6 @@ def _dims_to_layout(model, node, ndims):
6289 return DataLayout .NC
6390 else :
6491 return DataLayout .UNKNOWN
65- else :
66- # propagate input layout to output
67- # TODO this won't work for concat, squeeze/unsqueeze/reshape...
68- return model .get_tensor_layout (node .input [0 ])
6992
7093
7194def _infer_node_data_layout (model , node ):
0 commit comments