|
29 | 29 | import numpy as np |
30 | 30 | import qonnx.core.data_layout as DataLayout |
31 | 31 | import warnings |
| 32 | + |
| 33 | +# Protobuf onnx graph node type |
| 34 | +from onnx import NodeProto # noqa |
32 | 35 | from onnx import helper as oh |
33 | 36 | # Protobuf onnx graph node type |
34 | 37 | from onnx import NodeProto # noqa |
35 | 38 | # QONNX wrapper of ONNX model graphs |
36 | 39 | from qonnx.core.modelwrapper import ModelWrapper |
37 | 40 | from qonnx.core.datatype import DataType |
38 | | - |
39 | | -# QONNX wrapper of ONNX model graphs |
40 | 41 | from qonnx.core.modelwrapper import ModelWrapper |
41 | 42 | from qonnx.custom_op.registry import getCustomOp |
42 | 43 | from qonnx.transformation.base import Transformation |
43 | 44 | from qonnx.transformation.infer_datatypes import InferDataTypes |
44 | 45 | from qonnx.transformation.infer_shapes import InferShapes |
45 | 46 | from qonnx.util.basic import get_by_name |
46 | 47 |
|
47 | | -# Protobuf onnx graph node type |
48 | | -from onnx import NodeProto # noqa |
| 48 | +from finn.transformation.util import group_inputs_by_category |
49 | 49 |
|
50 | 50 |
|
| 51 | +# Note: Old name kept for compatibility reasons but actually allows to absorb |
| 52 | +# any bias irrespective of signedness which might result in changed signedness |
| 53 | +# of the output type |
51 | 54 | class AbsorbSignBiasIntoMultiThreshold(Transformation): |
52 | 55 | """Absorb scalar bias originating from signed int export back into |
53 | 56 | MultiThreshold and re-evaluate the output datatype.""" |
54 | 57 |
|
55 | | - def apply(self, model): |
| 58 | + def apply(self, model: ModelWrapper): |
| 59 | + # Get the model graph out of the model wrapper object |
56 | 60 | graph = model.graph |
57 | | - node_ind = 0 |
| 61 | + # Keep track of whether the graph has been modified |
58 | 62 | graph_modified = False |
59 | | - for n in graph.node: |
60 | | - # search for (MultiThreshold, Add) pair |
61 | | - node_ind += 1 |
| 63 | + # Iterate all nodes in the graph keeping track of the index |
| 64 | + for index, node in enumerate(graph.node): |
| 65 | + # Only non-branching threshold operations are supported |
62 | 66 | if ( |
63 | | - n.op_type == "MultiThreshold" |
64 | | - and not model.is_fork_node(n) |
65 | | - and not model.is_join_node(n) |
| 67 | + node.op_type == "MultiThreshold" |
| 68 | + and not model.is_fork_node(node) |
| 69 | + and not model.is_join_node(node) |
66 | 70 | ): |
67 | | - consumer = model.find_consumer(n.output[0]) |
| 71 | + # We now we are not forking, so there is at most one consumer |
| 72 | + consumer = model.find_consumer(node.output[0]) |
| 73 | + # At the end of the graph we might have no consumer. If we have |
| 74 | + # one, only handle Adds, turn Sub into Add first... |
68 | 75 | if consumer is not None and consumer.op_type == "Add": |
69 | | - mt_node = n |
70 | | - add_node = consumer |
71 | | - threshold_name = mt_node.input[1] |
72 | | - add_weight_name = add_node.input[1] |
73 | | - T = model.get_initializer(threshold_name) |
74 | | - A = model.get_initializer(add_weight_name) |
75 | | - if (A is None) or (T is None): |
76 | | - warnings.warn("Threshold or add bias not constant, skipping") |
| 76 | + # Try to get the parameter tensor for the addition: Sanity |
| 77 | + # check whether this is present, even though we already |
| 78 | + # tested for non-joining |
| 79 | + bias = model.get_initializer(consumer.input[1]) |
| 80 | + |
| 81 | + # Warn and skip if there is no constant bias present |
| 82 | + if bias is None: |
| 83 | + warnings.warn( |
| 84 | + f"{self.__class__.__name__}: Bias not constant for" |
| 85 | + f" {consumer.name}, skipping." |
| 86 | + ) |
| 87 | + # Skip to next node, nothing changed so far, no need to |
| 88 | + # break here |
77 | 89 | continue |
78 | | - end_name = add_node.output[0] |
79 | | - # we can only absorb scalar adds |
80 | | - is_scalar = A.ndim == 0 or all(x == 1 for x in A.shape) |
81 | | - if not is_scalar: |
| 90 | + |
| 91 | + # Try to get the parameter tensor for the thresholds: Sanity |
| 92 | + # check whether this is present, even though we already |
| 93 | + # tested for non-joining |
| 94 | + thresholds = model.get_initializer(node.input[1]) |
| 95 | + |
| 96 | + # Warn and skip if there is no constant bias present |
| 97 | + if thresholds is None: |
| 98 | + warnings.warn( |
| 99 | + f"{self.__class__.__name__}: Thresholds not" |
| 100 | + f" constant for {node.name}, skipping." |
| 101 | + ) |
| 102 | + # Skip to next node, nothing changed so far, no need to |
| 103 | + # break here |
82 | 104 | continue |
83 | | - bias = A.flatten()[0] |
84 | | - # set MultiThreshold bias property |
85 | | - mt_inst = getCustomOp(mt_node) |
86 | | - bias += mt_inst.get_nodeattr("out_bias") |
87 | | - mt_inst.set_nodeattr("out_bias", bias) |
| 105 | + |
| 106 | + # Check whether the bias is as scalar as we cannot absorb |
| 107 | + # full tensors into node attributes |
| 108 | + if not (bias.ndim == 0 or all(x == 1 for x in bias.shape)): |
| 109 | + warnings.warn( |
| 110 | + f"{self.__class__.__name__}: Bias not scalar" |
| 111 | + f" for {consumer.name}, skipping." |
| 112 | + ) |
| 113 | + # Skip to next node, nothing changed so far, no need to |
| 114 | + # break here |
| 115 | + continue |
| 116 | + |
| 117 | + # Flatten effectively scalar bias tensors and extract to |
| 118 | + # have "plain" scalar |
| 119 | + bias = bias.flatten()[0] |
| 120 | + # CustomOp instance of the thresholding node required for |
| 121 | + # convenient attribute manipulation |
| 122 | + threshold_op = getCustomOp(node) |
| 123 | + # Shift the output bias of the thresholding operator |
| 124 | + out_bias = threshold_op.get_nodeattr("out_bias") + bias |
| 125 | + # Derive the new output range due to shifting the bias |
| 126 | + # Note: We count thresholds steps on top of the bias |
| 127 | + new_min = out_bias |
| 128 | + new_max = out_bias + thresholds.shape[-1] |
| 129 | + |
| 130 | + # Allows the signedness to change depending on the new |
| 131 | + # output range [new_min,new_max] |
| 132 | + if abs(new_min) > abs(new_max): |
| 133 | + odt = DataType.get_smallest_possible(new_min) |
| 134 | + else: |
| 135 | + odt = DataType.get_smallest_possible(new_max) |
| 136 | + |
| 137 | + # Check whether the new range can be represented with the |
| 138 | + # derived integer datatype |
| 139 | + if not (odt.allowed(new_max) and odt.allowed(new_min)): |
| 140 | + # Cannot be represented, warn and skip transforming |
| 141 | + warnings.warn( |
| 142 | + f"{self.__class__.__name__}: Cannot absorb bias" |
| 143 | + f" from {consumer.name} into {node.name}: {bias}" |
| 144 | + ) |
| 145 | + # Skip to the next candidate node |
| 146 | + continue |
| 147 | + |
| 148 | + # Remember the old datatype for some further checks and info |
| 149 | + old_odt = threshold_op.get_nodeattr("out_dtype") |
| 150 | + |
| 151 | + # Check whether the datatype changes as this is something |
| 152 | + # the "user" should be aware of |
| 153 | + if odt.name != old_odt: |
| 154 | + warnings.warn( |
| 155 | + f"{self.__class__.__name__}: Output datatype for" |
| 156 | + f" {node.name} changing from {old_odt} to {odt}" |
| 157 | + ) |
| 158 | + |
| 159 | + # Up until now we did not modify the nodes/grap, just did |
| 160 | + # some checks and derive the new bias and datatype. Start |
| 161 | + # inserting this back into the graph now... |
| 162 | + |
| 163 | + # Set new bias and datatype attributes into the threshold |
| 164 | + # operator |
| 165 | + threshold_op.set_nodeattr("out_bias", out_bias) |
| 166 | + threshold_op.set_nodeattr("out_dtype", odt.name) |
| 167 | + # Remove the bias operator and rewire the graph to skip the |
| 168 | + # now-missing node |
| 169 | + node.output[0] = consumer.output[0] |
| 170 | + graph.node.remove(consumer) |
| 171 | + # Update the datatype at the output of the threshold |
| 172 | + # operation |
| 173 | + model.set_tensor_datatype(node.output[0], odt) |
| 174 | + |
| 175 | + # Graph modified so we need to apply this transformation |
| 176 | + # again |
88 | 177 | graph_modified = True |
89 | | - # compute new DataType for MultiThreshold output |
90 | | - steps = T.shape[-1] |
91 | | - new_min = bias |
92 | | - new_max = steps + bias |
93 | | - odt = DataType.get_smallest_possible(steps).name.replace("UINT", "INT") |
94 | | - odt = DataType[odt] |
95 | | - assert odt.allowed(new_max) and odt.allowed( |
96 | | - new_min |
97 | | - ), """Could |
98 | | - not compute new MultiThreshold DataType (min = %d max = %d)""" % ( |
99 | | - new_min, |
100 | | - new_max, |
101 | | - ) |
102 | | - mt_inst.set_nodeattr("out_dtype", odt.name) |
103 | | - # remove Add node, rewire MultiThreshold |
104 | | - graph.node.remove(add_node) |
105 | | - mt_node.output[0] = end_name |
106 | | - # set datatype |
107 | | - model.set_tensor_datatype(end_name, odt) |
108 | | - if graph_modified: |
109 | | - model = model.transform(InferDataTypes()) |
110 | | - return (model, graph_modified) |
| 178 | + # Better break now to clean up and recover annotations first |
| 179 | + break |
| 180 | + # As we might have changes types and removed nodes better redo some |
| 181 | + # annotations |
| 182 | + model = model.transform(InferDataTypes()) |
| 183 | + model = model.transform(InferShapes()) |
| 184 | + # Transformed model and indication whether the transformation should be |
| 185 | + # applied again |
| 186 | + return model, graph_modified |
111 | 187 |
|
112 | 188 |
|
113 | 189 | # Groups inputs by categories, i.e., groups dynamic inputs first, followed by |
|
0 commit comments