Skip to content

Commit a0b9007

Browse files
authored
Merge pull request #10 from eki-project/elementwise-binary
Introduce support for generic elementwise binary operations
2 parents fa57176 + af99d03 commit a0b9007

File tree

13 files changed

+2782
-28
lines changed

13 files changed

+2782
-28
lines changed

.isort.cfg

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,6 @@ sections=FUTURE,STDLIB,TEST,THIRDPARTY,FIRSTPARTY,LOCALFOLDER
99
default_section=THIRDPARTY
1010
multi_line_output=3
1111
profile=black
12+
ignore_comments=true
13+
ignore_whitespace=true
14+
honor_noqa=true

src/finn/custom_op/fpgadataflow/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,17 @@ def register_custom_op(cls):
5151
# Disable linting from here, as all imports will be flagged E402 and maybe F401
5252

5353

54+
# Import the submodule containing specializations of ElementwiseBinaryOperation
55+
# Note: This will automatically register all decorated classes into this domain
56+
import finn.custom_op.fpgadataflow.elementwise_binary
57+
5458
# Import the submodule containing the Squeeze operation
5559
# Note: This will automatically register all decorated classes into this domain
5660
import finn.custom_op.fpgadataflow.squeeze
5761

5862
# Import the submodule containing the Unsqueeze operation
5963
import finn.custom_op.fpgadataflow.unsqueeze
64+
6065
from finn.custom_op.fpgadataflow.addstreams import AddStreams
6166
from finn.custom_op.fpgadataflow.channelwise_op import ChannelwiseOp
6267
from finn.custom_op.fpgadataflow.concat import StreamingConcat

src/finn/custom_op/fpgadataflow/elementwise_binary.py

Lines changed: 809 additions & 0 deletions
Large diffs are not rendered by default.

src/finn/custom_op/fpgadataflow/hls/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,17 @@ def register_custom_op(cls):
5454
# flake8: noqa
5555
# Disable linting from here, as all imports will be flagged E402 and maybe F401
5656

57+
# Import the submodule containing specializations of ElementwiseBinaryOperation
58+
# Note: This will automatically register all decorated classes into this domain
59+
import finn.custom_op.fpgadataflow.hls.elementwise_binary_hls
60+
5761
# Import the submodule containing the specialization of the Squeeze operation
5862
# Note: This will automatically register all decorated classes into this domain
5963
import finn.custom_op.fpgadataflow.hls.squeeze_hls
6064

6165
# Import the submodule containing the specialization of the Unsqueeze operation
6266
import finn.custom_op.fpgadataflow.hls.unsqueeze_hls
67+
6368
from finn.custom_op.fpgadataflow.hls.addstreams_hls import AddStreams_hls
6469
from finn.custom_op.fpgadataflow.hls.channelwise_op_hls import ChannelwiseOp_hls
6570
from finn.custom_op.fpgadataflow.hls.checksum_hls import CheckSum_hls

src/finn/custom_op/fpgadataflow/hls/elementwise_binary_hls.py

Lines changed: 766 additions & 0 deletions
Large diffs are not rendered by default.

src/finn/custom_op/fpgadataflow/templates.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
# template for single node execution
3131
docompute_template = """
32+
#define HLS_CONSTEXPR_ENABLE
3233
#define AP_INT_MAX_W $AP_INT_MAX_W$
3334
#include "cnpy.h"
3435
#include "npy2apintstream.hpp"
@@ -108,6 +109,7 @@
108109

109110
# cpp file
110111
ipgen_template = """
112+
#define HLS_CONSTEXPR_ENABLE
111113
#define AP_INT_MAX_W $AP_INT_MAX_W$
112114
113115
#include "bnn-library.h"

src/finn/custom_op/fpgadataflow/thresholding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ def execute_node(self, context, graph):
257257
if act == DataType["BIPOLAR"]:
258258
# binary to bipolar
259259
y = 2 * y - 1
260-
context[node.output[0]] = y
260+
context[node.output[0]] = y.astype(np.float32)
261261

262262
def calc_tmem(self):
263263
"""Calculates and returns TMEM."""

src/finn/transformation/fpgadataflow/convert_to_hw_layers.py

Lines changed: 131 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
import numpy as np
3131
import qonnx.core.data_layout as DataLayout
3232
import warnings
33-
from onnx import TensorProto, helper
33+
from onnx import NodeProto, TensorProto, helper
3434
from qonnx.core.datatype import DataType
3535
from qonnx.core.modelwrapper import ModelWrapper
3636
from qonnx.custom_op.registry import getCustomOp
@@ -41,6 +41,9 @@
4141
from qonnx.util.basic import get_by_name
4242
from qonnx.util.onnx import nchw_to_nhwc
4343

44+
# Module containing specializations of elementwise binary operations
45+
import finn.custom_op.fpgadataflow.elementwise_binary as elementwise_binary
46+
4447
# Base class for all FINN custom ops, here just used for type-hinting
4548
from finn.custom_op.fpgadataflow.hwcustomop import HWCustomOp
4649

@@ -1831,6 +1834,133 @@ def apply(self, model):
18311834
return (model, graph_modified)
18321835

18331836

1837+
# Lifts scalar to rank-1 tensor
1838+
def lift_to_rank1(name: str, model: ModelWrapper):
1839+
# Scalars have a shape of lengths zero
1840+
if len(model.get_tensor_shape(name)) == 0:
1841+
# Lift shape to rank-1 tensor with single element
1842+
model.set_tensor_shape(name, [1])
1843+
# Check whether this tensor has an initializer
1844+
if (tensor := model.get_initializer(name)) is not None:
1845+
# Set new initializer tensor of shape [1]
1846+
model.set_initializer(name, tensor.reshape(1))
1847+
1848+
1849+
# Converts supported elementwise binary operations to their FINN custom
1850+
# operation
1851+
class InferElementwiseBinaryOperation(Transformation):
1852+
# Filter function to filter out the last elementwise Mul operation,
1853+
# typically corresponding to output de-quantization, which should happen
1854+
# off-chip
1855+
@staticmethod
1856+
def reject_output_dequant(model: ModelWrapper, node: NodeProto):
1857+
# The operator must be a Mul and have no successor nodes
1858+
if node.op_type == "Mul" and not model.find_direct_successors(node):
1859+
# If the output is a floating-point tensor, reject this
1860+
if model.get_tensor_datatype(node.output[0]) == "FLOAT32":
1861+
# Filter False rejects this node
1862+
return False
1863+
# Filter True accepts this node
1864+
return True
1865+
1866+
# Filter function to filter out any operation involving any floating-point
1867+
# tensor
1868+
@staticmethod
1869+
def reject_floats(model: ModelWrapper, node: NodeProto):
1870+
# Check for any input being floating-point
1871+
if any(model.get_tensor_datatype(x) == "FLOAT32" for x in node.input):
1872+
# Filter False rejects this node
1873+
return False
1874+
# Check for any output being floating-point
1875+
if any(model.get_tensor_datatype(x) == "FLOAT32" for x in node.output):
1876+
# Filter False rejects this node
1877+
return False
1878+
# Filter True accepts this node
1879+
return True
1880+
1881+
# Initializes the transformation method with an optional filter function
1882+
def __init__(self, _filter=None):
1883+
# Initialize the base class Transformation object
1884+
super().__init__()
1885+
# Register the filter function as attribute
1886+
self._filter = _filter if _filter is not None else lambda *_: True
1887+
1888+
# Applies the transform to a whole model graph
1889+
def apply(self, model: ModelWrapper): # noqa
1890+
# Get the model graph out of the model wrapper object
1891+
graph = model.graph
1892+
# Keep track of whether the graph has been modified
1893+
graph_modified = False
1894+
# Iterate all nodes in the graph keeping track of the index
1895+
for index, node in enumerate(graph.node):
1896+
# Skip transforming nodes rejected by the filter
1897+
if not self._filter(model, node):
1898+
continue
1899+
# If a custom operation with corresponding name is implemented in
1900+
# the module, this operator is supported for conversion
1901+
if f"Elementwise{node.op_type}" in dir(elementwise_binary):
1902+
# Transplant this operator into our FINN domain
1903+
node.domain = "finn.custom_op.fpgadataflow"
1904+
# Adapt the op-type prefixing it with Elementwise
1905+
# TODO: Consider dropping the prefix?
1906+
node.op_type = f"Elementwise{node.op_type}"
1907+
# Now we can get the CustomOp wrapper instance providing easier
1908+
# attribute access
1909+
inst: HWCustomOp = getCustomOp(node)
1910+
# Set the backend attribute to mark this an operation supported
1911+
# to be implemented on an FPGA by FINN
1912+
inst.set_nodeattr("backend", "fpgadataflow")
1913+
# Need to "lift" potential scalar inputs to rank-1 tensors
1914+
lift_to_rank1(node.input[0], model)
1915+
lift_to_rank1(node.input[1], model)
1916+
1917+
# fmt: off
1918+
# Disable formatter. This is deliberately formatted to stay
1919+
# within 80 characters per line. Black, however, formats some
1920+
# lines going beyond this.
1921+
1922+
# Insert data type attributes from "context" into the CustomOp
1923+
# node
1924+
# TODO: Find a way to handle this via data type inference?
1925+
inst.set_nodeattr(
1926+
"lhs_dtype", str(model.get_tensor_datatype(node.input[0]))
1927+
)
1928+
inst.set_nodeattr(
1929+
"rhs_dtype", str(model.get_tensor_datatype(node.input[1]))
1930+
)
1931+
inst.set_nodeattr(
1932+
"out_dtype", str(model.get_tensor_datatype(node.output[0]))
1933+
)
1934+
# Insert shape attributes from "context" into the CustomOp node
1935+
# TODO: Find a way to handle this via shape inference?
1936+
inst.set_nodeattr(
1937+
"lhs_shape", model.get_tensor_shape(node.input[0])
1938+
)
1939+
inst.set_nodeattr(
1940+
"rhs_shape", model.get_tensor_shape(node.input[1])
1941+
)
1942+
inst.set_nodeattr(
1943+
"out_shape", model.get_tensor_shape(node.output[0])
1944+
)
1945+
1946+
# fmt: on
1947+
1948+
# Consider the graph to be modified, triggering exhaustive
1949+
# re-application of this transformation
1950+
graph_modified = True
1951+
# Exiting here triggers type and shape inference and cleanup
1952+
# after each transformed node. This helps QONNX to behave
1953+
# better / more consistent in certain cases...
1954+
break
1955+
# Re-do shape and data type annotations after potential changes to the
1956+
# model graph
1957+
model = model.transform(InferShapes())
1958+
model = model.transform(InferDataTypes())
1959+
# Return the transformed model and indicate whether the graph actually
1960+
# has been transformed
1961+
return model, graph_modified
1962+
1963+
18341964
# Converts the Squeeze operation to the corresponding FINN custom operation
18351965
class InferSqueeze(Transformation):
18361966
# Applies the transform to a whole model graph

src/finn/transformation/fpgadataflow/set_folding.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,17 @@
2727
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2828
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2929

30+
# Inspect information on Python objects like modules
31+
import inspect
3032
import numpy as np
3133
import warnings
3234
from qonnx.custom_op.registry import getCustomOp
3335
from qonnx.transformation.base import Transformation
3436
from qonnx.transformation.general import GiveUniqueNodeNames
3537

38+
# Import the elementwise binary operation module to extract names of all
39+
# specializations (which require PE parallelism to be configured)
40+
import finn.custom_op.fpgadataflow.hls.elementwise_binary_hls as elementwise_binary_hls
3641
from finn.analysis.fpgadataflow.dataflow_performance import dataflow_performance
3742
from finn.transformation.fpgadataflow.annotate_cycles import AnnotateCycles
3843
from finn.util.fpgadataflow import is_hls_node, is_rtl_node
@@ -44,6 +49,15 @@ def divisors(num):
4449
yield x
4550

4651

52+
# Find the op-type names for all HLS specializations of elementwise binary
53+
# operations
54+
ELEMENTWISE_BINARY_OPS = [
55+
op_type
56+
for op_type, cls in inspect.getmembers(elementwise_binary_hls, inspect.isclass)
57+
if issubclass(cls, elementwise_binary_hls.ElementwiseBinaryOperation_hls)
58+
]
59+
60+
4761
class SetFolding(Transformation):
4862
"""Attempt to set parallelism attributes in all nodes to meet a specific
4963
target expressed as cycles per frame target_cycles_per_frame. For each
@@ -106,6 +120,7 @@ def apply(self, model):
106120
"GlobalAccPool_hls",
107121
"Thresholding_hls",
108122
"Thresholding_rtl",
123+
*ELEMENTWISE_BINARY_OPS,
109124
"Squeeze_hls",
110125
"Unsqueeze_hls",
111126
]
@@ -157,7 +172,16 @@ def apply(self, model):
157172
# increase PE until target met or reached max_pe
158173
self.optimize_attribute_val(node_inst, max_pe, "PE")
159174
elif op_type in pe_ops:
160-
max_pe = node_inst.get_nodeattr("NumChannels")
175+
# Note: Keep original behavior for all custom-ops defining the
176+
# NumChannels attribute as it is
177+
try:
178+
max_pe = node_inst.get_nodeattr("NumChannels")
179+
# Note: Some of the recent additions do not define the
180+
# NumChannels attribute
181+
except AttributeError:
182+
# We can extract the channels from the normal, i.e., not
183+
# folded, shape of the input in these cases
184+
max_pe = node_inst.get_normal_input_shape()[-1]
161185
self.optimize_attribute_val(node_inst, max_pe, "PE")
162186
elif op_type == "LabelSelect_hls":
163187
max_pe = node_inst.get_nodeattr("Labels")

src/finn/transformation/qonnx/fold_quant_weights.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,8 @@ def apply(self, model):
149149
mul_tensor = helper.make_tensor_value_info(
150150
model.make_new_valueinfo_name(),
151151
TensorProto.FLOAT,
152-
mul_shape,
152+
mul_shape, # Note: This shape is known exactly as
153+
# it is an initializer with known shape
153154
)
154155
graph.value_info.append(mul_tensor)
155156
model.set_initializer(mul_tensor.name, scale)
@@ -168,7 +169,9 @@ def apply(self, model):
168169
act_mul_tensor = helper.make_tensor_value_info(
169170
model.make_new_valueinfo_name(),
170171
TensorProto.FLOAT,
171-
output_shape,
172+
None, # Note: Explicitly delete the shape
173+
# annotation to be redone by the next shape
174+
# inference
172175
)
173176
graph.value_info.append(act_mul_tensor)
174177
successor.output[0] = act_mul_tensor.name
@@ -186,19 +189,37 @@ def apply(self, model):
186189
div_tensor = helper.make_tensor_value_info(
187190
model.make_new_valueinfo_name(),
188191
TensorProto.FLOAT,
189-
mul_shape,
192+
None, # Note: Explicitly delete the shape
193+
# annotation to be redone by the next shape
194+
# inference
190195
)
191196
graph.value_info.append(div_tensor)
192197
model.set_initializer(div_tensor.name, scale)
193198

194-
succ_input_name = successor.input[0]
199+
# Detect which input of the add-like successor is
200+
# fed by the quantizer node to select the other
201+
# branch to insert the scale factor
202+
if successor.input[0] == node_out:
203+
succ_input_name = successor.input[1]
204+
else:
205+
succ_input_name = successor.input[0]
206+
195207
act_mul_tensor = helper.make_tensor_value_info(
196208
model.make_new_valueinfo_name(),
197209
TensorProto.FLOAT,
198-
output_shape,
210+
None, # Note: Explicitly delete the shape
211+
# annotation to be redone by the next shape
212+
# inference
199213
)
200214
graph.value_info.append(act_mul_tensor)
201-
successor.input[0] = act_mul_tensor.name
215+
216+
# Detect which input of the add-like successor is
217+
# fed by the quantizer node to select the other
218+
# branch to insert the scale factor
219+
if successor.input[0] == node_out:
220+
successor.input[1] = act_mul_tensor.name
221+
else:
222+
successor.input[0] = act_mul_tensor.name
202223

203224
div_node = helper.make_node(
204225
"Div",
@@ -210,6 +231,8 @@ def apply(self, model):
210231
# remove old node
211232
graph.node.remove(n)
212233
graph_modified = True
234+
# Note: Running shape inference is necessary as shape
235+
# annotations have been deleted above
213236
model = model.transform(InferShapes())
214237
return (model, graph_modified)
215238
return (model, graph_modified)

0 commit comments

Comments
 (0)