Skip to content

Commit 7dfd7dc

Browse files
authored
Merge pull request #143 from iksnagreb/feature/generalized_multi_threshold_layouts
[MultiThreshold] Generalize data layouts for node execution
2 parents cc2ec17 + 26cd75d commit 7dfd7dc

File tree

1 file changed

+25
-5
lines changed

1 file changed

+25
-5
lines changed

src/qonnx/custom_op/general/multithreshold.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def get_nodeattr_types(self):
100100
"out_dtype": ("s", True, ""),
101101
"out_scale": ("f", False, 1.0),
102102
"out_bias": ("f", False, 0.0),
103-
"data_layout": ("s", False, "NCHW"),
103+
"data_layout": ("s", False, ""),
104104
}
105105

106106
def make_shape_compatible_op(self, model):
@@ -130,12 +130,32 @@ def execute_node(self, context, graph):
130130
# retrieve attributes if output scaling is used
131131
out_scale = self.get_nodeattr("out_scale")
132132
out_bias = self.get_nodeattr("out_bias")
133-
# transpose input if NHWC data layout is chosen
133+
134+
# Consider the data layout for transposing the input into the format
135+
# accepted by the multithreshold function above, i.e, the channel
136+
# dimension is along the axis with index 1.
134137
data_layout = self.get_nodeattr("data_layout")
135-
channels_last = True if data_layout[-1] == "C" else False
136-
# calculate output
138+
# If there is no layout annotation, guess based on rank of the
139+
# tensor
140+
if not data_layout and len(v.shape) < 5:
141+
# Maps tensor rank to layout annotation
142+
rank_to_layout = {0: None, 1: None, 2: "NC", 3: "NWC", 4: "NCHW"}
143+
# Lookup the layout required by this input shape
144+
data_layout = rank_to_layout[len(v.shape)]
145+
# Lookup the index of the channel dimension in the data layout
146+
# Note: Assumes there is at most one "C" which denotes the channel
147+
# dimension
148+
if data_layout is not None:
149+
cdim = data_layout.index("C") if "C" in data_layout else 1
150+
else:
151+
cdim = 1
152+
# Rearrange the input to the expected (N, C, ...) layout
137153
orig_shape = v.shape
138-
output = multithreshold(v, thresholds, out_scale, out_bias, channels_last)
154+
v = v.swapaxes(cdim, 1)
155+
# Now we can use the multithreshold function to calculate output
156+
output = multithreshold(v, thresholds, out_scale, out_bias)
157+
# Rearrange the output back to the original layout
158+
output = output.swapaxes(cdim, 1)
139159
assert output.shape == orig_shape, "Shape changed during thresholding!"
140160
context[node.output[0]] = output
141161

0 commit comments

Comments
 (0)