Skip to content

Commit a5d32f2

Browse files
authored
Merge pull request #39 from eki-project/feature/streamline-plus
[Streamline] Introduce StreamlinePlus: Exhaustive streamlining
2 parents 1e3085f + 2ffb095 commit a5d32f2

File tree

7 files changed

+1191
-52
lines changed

7 files changed

+1191
-52
lines changed

setup.cfg

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,8 @@ exclude =
164164
dist
165165
.eggs
166166
docs/conf.py
167+
per-file-ignores =
168+
src/finn/transformation/streamline/streamline_plus.py: F405, F403
167169

168170
[pyscaffold]
169171
# PyScaffold's parameters when the project was created.

src/finn/transformation/streamline/absorb.py

Lines changed: 128 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -29,85 +29,161 @@
2929
import numpy as np
3030
import qonnx.core.data_layout as DataLayout
3131
import warnings
32+
33+
# Protobuf onnx graph node type
34+
from onnx import NodeProto # noqa
3235
from onnx import helper as oh
3336
# Protobuf onnx graph node type
3437
from onnx import NodeProto # noqa
3538
# QONNX wrapper of ONNX model graphs
3639
from qonnx.core.modelwrapper import ModelWrapper
3740
from qonnx.core.datatype import DataType
38-
39-
# QONNX wrapper of ONNX model graphs
4041
from qonnx.core.modelwrapper import ModelWrapper
4142
from qonnx.custom_op.registry import getCustomOp
4243
from qonnx.transformation.base import Transformation
4344
from qonnx.transformation.infer_datatypes import InferDataTypes
4445
from qonnx.transformation.infer_shapes import InferShapes
4546
from qonnx.util.basic import get_by_name
4647

47-
# Protobuf onnx graph node type
48-
from onnx import NodeProto # noqa
48+
from finn.transformation.util import group_inputs_by_category
4949

5050

51+
# Note: Old name kept for compatibility reasons but actually allows to absorb
# any bias irrespective of signedness which might result in changed signedness
# of the output type
class AbsorbSignBiasIntoMultiThreshold(Transformation):
    """Absorb scalar bias originating from signed int export back into
    MultiThreshold and re-evaluate the output datatype."""

    def apply(self, model: ModelWrapper):
        # Get the model graph out of the model wrapper object
        graph = model.graph
        # Keep track of whether the graph has been modified
        graph_modified = False
        # Iterate all nodes in the graph keeping track of the index
        for index, node in enumerate(graph.node):
            # Only non-branching threshold operations are supported
            if (
                node.op_type == "MultiThreshold"
                and not model.is_fork_node(node)
                and not model.is_join_node(node)
            ):
                # We know we are not forking, so there is at most one consumer
                consumer = model.find_consumer(node.output[0])
                # At the end of the graph we might have no consumer. If we have
                # one, only handle Adds, turn Sub into Add first...
                if consumer is not None and consumer.op_type == "Add":
                    # Try to get the parameter tensor for the addition: Sanity
                    # check whether this is present, even though we already
                    # tested for non-joining
                    bias = model.get_initializer(consumer.input[1])

                    # Warn and skip if there is no constant bias present
                    if bias is None:
                        warnings.warn(
                            f"{self.__class__.__name__}: Bias not constant for"
                            f" {consumer.name}, skipping."
                        )
                        # Skip to next node, nothing changed so far, no need to
                        # break here
                        continue

                    # Try to get the parameter tensor for the thresholds: Sanity
                    # check whether this is present, even though we already
                    # tested for non-joining
                    thresholds = model.get_initializer(node.input[1])

                    # Warn and skip if there is no constant thresholds tensor
                    if thresholds is None:
                        warnings.warn(
                            f"{self.__class__.__name__}: Thresholds not"
                            f" constant for {node.name}, skipping."
                        )
                        # Skip to next node, nothing changed so far, no need to
                        # break here
                        continue

                    # Check whether the bias is a scalar as we cannot absorb
                    # full tensors into node attributes
                    if not (bias.ndim == 0 or all(x == 1 for x in bias.shape)):
                        warnings.warn(
                            f"{self.__class__.__name__}: Bias not scalar"
                            f" for {consumer.name}, skipping."
                        )
                        # Skip to next node, nothing changed so far, no need to
                        # break here
                        continue

                    # Flatten effectively scalar bias tensors and extract to
                    # have "plain" scalar
                    bias = bias.flatten()[0]
                    # CustomOp instance of the thresholding node required for
                    # convenient attribute manipulation
                    threshold_op = getCustomOp(node)
                    # Shift the output bias of the thresholding operator
                    out_bias = threshold_op.get_nodeattr("out_bias") + bias
                    # Derive the new output range due to shifting the bias
                    # Note: We count thresholds steps on top of the bias
                    new_min = out_bias
                    new_max = out_bias + thresholds.shape[-1]

                    # Allows the signedness to change depending on the new
                    # output range [new_min,new_max]
                    if abs(new_min) > abs(new_max):
                        odt = DataType.get_smallest_possible(new_min)
                    else:
                        odt = DataType.get_smallest_possible(new_max)

                    # Check whether the new range can be represented with the
                    # derived integer datatype
                    if not (odt.allowed(new_max) and odt.allowed(new_min)):
                        # Cannot be represented, warn and skip transforming
                        warnings.warn(
                            f"{self.__class__.__name__}: Cannot absorb bias"
                            f" from {consumer.name} into {node.name}: {bias}"
                        )
                        # Skip to the next candidate node
                        continue

                    # Remember the old datatype for some further checks and info
                    old_odt = threshold_op.get_nodeattr("out_dtype")

                    # Check whether the datatype changes as this is something
                    # the "user" should be aware of
                    if odt.name != old_odt:
                        warnings.warn(
                            f"{self.__class__.__name__}: Output datatype for"
                            f" {node.name} changing from {old_odt} to {odt}"
                        )

                    # Up until now we did not modify the nodes/graph, just did
                    # some checks and derived the new bias and datatype. Start
                    # inserting this back into the graph now...

                    # Set new bias and datatype attributes into the threshold
                    # operator
                    threshold_op.set_nodeattr("out_bias", out_bias)
                    threshold_op.set_nodeattr("out_dtype", odt.name)
                    # Remove the bias operator and rewire the graph to skip the
                    # now-missing node
                    node.output[0] = consumer.output[0]
                    graph.node.remove(consumer)
                    # Update the datatype at the output of the threshold
                    # operation
                    model.set_tensor_datatype(node.output[0], odt)

                    # Graph modified so we need to apply this transformation
                    # again
                    graph_modified = True
                    # Better break now to clean up and recover annotations first
                    break
        # As we might have changed types and removed nodes better redo some
        # annotations
        model = model.transform(InferDataTypes())
        model = model.transform(InferShapes())
        # Transformed model and indication whether the transformation should be
        # applied again
        return model, graph_modified
111187

112188

113189
# Groups inputs by categories, i.e., groups dynamic inputs first, followed by

src/finn/transformation/streamline/collapse_repeated.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,24 @@
2626
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2727
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2828

29+
# Helper for creating ONNX nodes
2930
from onnx import helper as oh
31+
32+
# QONNX arbitrary precision data types
3033
from qonnx.core.datatype import DataType
34+
35+
# QONNX wrapper of ONNX model graphs
36+
from qonnx.core.modelwrapper import ModelWrapper
37+
38+
# QONNX graph transformation base class
3139
from qonnx.transformation.base import Transformation
40+
41+
# QONNX graph transformations for inferring datatypes and shapes
3242
from qonnx.transformation.infer_shapes import InferShapes
3343

44+
# Gets items from protobuf by name
45+
from qonnx.util.basic import get_by_name
46+
3447

3548
class CollapseRepeatedOp(Transformation):
3649
"""Collapse repeated consecutive operations with constant parameters into
@@ -106,3 +119,94 @@ class CollapseRepeatedMul(CollapseRepeatedOp):
106119

107120
def __init__(self):
108121
super().__init__("Mul", lambda x, y: y * x)
122+
123+
124+
# Collapses repeated transpose operations into a single transpose operation
# having the same effect
class CollapseRepeatedTranspose(Transformation):
    # Applies the transform to a whole model graph
    def apply(self, model: ModelWrapper):  # noqa
        # Get the model graph out of the model wrapper object
        graph = model.graph
        # Keep track of whether the graph has been modified
        graph_modified = False
        # Iterate all nodes in the graph keeping track of the index
        for index, node in enumerate(graph.node):
            # Applies to Transpose operation types
            if node.op_type == "Transpose":
                # Currently does not handle fork- or join-nodes
                if model.is_fork_node(node) or model.is_join_node(node):
                    # Softly skip this node
                    continue
                # As this is not a fork-node, there can be at most one successor
                successor = model.find_direct_successors(node)
                # If Transpose is the final operation in the graph, there might
                # be no successor
                if successor is None:
                    # Softly skip this node
                    continue
                # Now there is exactly one successor which needs to be extracted
                # from the list
                successor = successor[0]
                # Successor must be a Transpose to be collapsed
                if successor.op_type != "Transpose":
                    # Softly skip this node
                    continue
                # Get the (optional) permutation indices of the first transpose
                # in case it is a multi-axis transpose
                perm1 = get_by_name(node.attribute, "perm")
                # Convert permutation indices to list of integers
                perm1 = perm1.ints if perm1 is not None else None

                # Get the (optional) permutation indices of the second transpose
                # in case it is a multi-axis transpose
                perm2 = get_by_name(successor.attribute, "perm")
                # Convert permutation indices to list of integers
                perm2 = perm2.ints if perm2 is not None else None

                # Get the shape of the input tensor
                shape = model.get_tensor_shape(
                    # fmt: off
                    node.input[0], fix_missing_init_shape=True
                    # fmt: on
                )
                # List of dimension indices in order
                dims = range(len(shape))

                # Substitute the permutation indices by the reversed index list
                # if they are not given: This is default behavior, see the docs:
                # https://onnx.ai/onnx/operators/onnx__Transpose.html
                perm1 = list(reversed(dims)) if perm1 is None else perm1
                perm2 = list(reversed(dims)) if perm2 is None else perm2

                # Combined permutation permutes the first permutation of the
                # dimensions according to the second permutation
                perm = [perm1[i] for i in perm2]

                # Create a new Transpose operator replacing the other two
                transpose = oh.make_node(
                    # Name of the operator type
                    "Transpose",
                    # Connect to the inputs to the first transpose
                    inputs=node.input,
                    # Connect to the outputs of the second transpose
                    outputs=successor.output,
                    # Insert the new permutation indices
                    perm=perm,
                )
                # Insert the collapsed transpose operator
                graph.node.insert(index + 2, transpose)
                # Remove the two original transpose operators
                graph.node.remove(node)
                graph.node.remove(successor)
                # Track whether the graph has been modified, never resets to
                # False
                graph_modified = True
                # Break the loop after adding and removing nodes to start over
                # with a clean index
                break
        # Need to redo the shape inference after potentially removing nodes
        model = model.transform(InferShapes())  # noqa: Shadows model
        # Return the transformed model and indicate whether the graph actually
        # has been transformed
        return model, graph_modified

0 commit comments

Comments
 (0)