|
30 | 30 | import numpy as np |
31 | 31 | import qonnx.core.data_layout as DataLayout |
32 | 32 | import warnings |
33 | | -from onnx import TensorProto, helper |
| 33 | +from onnx import NodeProto, TensorProto, helper |
34 | 34 | from qonnx.core.datatype import DataType |
35 | 35 | from qonnx.core.modelwrapper import ModelWrapper |
36 | 36 | from qonnx.custom_op.registry import getCustomOp |
|
41 | 41 | from qonnx.util.basic import get_by_name |
42 | 42 | from qonnx.util.onnx import nchw_to_nhwc |
43 | 43 |
|
| 44 | +# Module containing specializations of elementwise binary operations |
| 45 | +import finn.custom_op.fpgadataflow.elementwise_binary as elementwise_binary |
| 46 | + |
44 | 47 | # Base class for all FINN custom ops, here just used for type-hinting |
45 | 48 | from finn.custom_op.fpgadataflow.hwcustomop import HWCustomOp |
46 | 49 |
|
@@ -1831,6 +1834,133 @@ def apply(self, model): |
1831 | 1834 | return (model, graph_modified) |
1832 | 1835 |
|
1833 | 1836 |
|
# Promotes a scalar (rank-0) tensor to a single-element rank-1 tensor
def lift_to_rank1(name: str, model: ModelWrapper):
    # Rank-0 tensors are identified by an empty shape; anything else is
    # already at least rank-1 and needs no treatment
    shape = model.get_tensor_shape(name)
    if len(shape) != 0:
        return
    # Replace the empty shape by a one-element shape
    model.set_tensor_shape(name, [1])
    # An initializer backing this tensor must be reshaped to match the new
    # shape annotation
    init = model.get_initializer(name)
    if init is not None:
        model.set_initializer(name, init.reshape(1))
| 1847 | + |
| 1848 | + |
# Converts supported elementwise binary operations to their FINN custom
# operation
class InferElementwiseBinaryOperation(Transformation):
    # Filter function rejecting the final Mul operation of the graph if it
    # produces a floating-point output, as this typically corresponds to the
    # output de-quantization, which should happen off-chip
    @staticmethod
    def reject_output_dequant(model: ModelWrapper, node: NodeProto):
        # A candidate is a Mul without any successor node in the graph
        is_last_mul = node.op_type == "Mul" and not model.find_direct_successors(node)
        # Reject (filter False) when such a node yields a float output
        if is_last_mul and model.get_tensor_datatype(node.output[0]) == "FLOAT32":
            return False
        # Accept (filter True) everything else
        return True

    # Filter function rejecting any operation touching a floating-point
    # tensor on either side
    @staticmethod
    def reject_floats(model: ModelWrapper, node: NodeProto):
        # Collect all tensors connected to this node, inputs and outputs
        connected = [*node.input, *node.output]
        # Accept only if none of them is annotated as floating-point
        return not any(
            model.get_tensor_datatype(name) == "FLOAT32" for name in connected
        )

    # Initializes the transformation method with an optional filter function
    def __init__(self, _filter=None):
        # Initialize the base class Transformation object
        super().__init__()
        # Without an explicit filter, accept every node unconditionally
        self._filter = (lambda *_: True) if _filter is None else _filter

    # Applies the transform to a whole model graph
    def apply(self, model: ModelWrapper):  # noqa
        # Operate on the graph wrapped by the model object
        graph = model.graph
        # Tracks whether any node has actually been converted
        graph_modified = False
        # Inspect each node of the graph in order
        for node in graph.node:
            # Nodes rejected by the registered filter are left untouched
            if not self._filter(model, node):
                continue
            # Conversion is supported only if the specialization module
            # implements a custom op of this name
            specialized = f"Elementwise{node.op_type}"
            if specialized not in dir(elementwise_binary):
                continue
            # Transplant this operator into our FINN domain
            node.domain = "finn.custom_op.fpgadataflow"
            # Adapt the op-type prefixing it with Elementwise
            # TODO: Consider dropping the prefix?
            node.op_type = specialized
            # Get the CustomOp wrapper instance providing easier attribute
            # access
            inst: HWCustomOp = getCustomOp(node)
            # Mark this operation as implementable on an FPGA by FINN
            inst.set_nodeattr("backend", "fpgadataflow")
            # Potential scalar inputs must be "lifted" to rank-1 tensors
            lift_to_rank1(node.input[0], model)
            lift_to_rank1(node.input[1], model)
            # Shorthands for the tensors connected to this node
            lhs, rhs, out = node.input[0], node.input[1], node.output[0]
            # Insert data type attributes from "context" into the CustomOp
            # node
            # TODO: Find a way to handle this via data type inference?
            inst.set_nodeattr("lhs_dtype", str(model.get_tensor_datatype(lhs)))
            inst.set_nodeattr("rhs_dtype", str(model.get_tensor_datatype(rhs)))
            inst.set_nodeattr("out_dtype", str(model.get_tensor_datatype(out)))
            # Insert shape attributes from "context" into the CustomOp node
            # TODO: Find a way to handle this via shape inference?
            inst.set_nodeattr("lhs_shape", model.get_tensor_shape(lhs))
            inst.set_nodeattr("rhs_shape", model.get_tensor_shape(rhs))
            inst.set_nodeattr("out_shape", model.get_tensor_shape(out))
            # Consider the graph to be modified, triggering exhaustive
            # re-application of this transformation
            graph_modified = True
            # Exiting here triggers type and shape inference and cleanup
            # after each transformed node. This helps QONNX to behave
            # better / more consistent in certain cases...
            break
        # Re-do shape and data type annotations after potential changes to the
        # model graph
        model = model.transform(InferShapes())
        model = model.transform(InferDataTypes())
        # Return the transformed model and indicate whether the graph actually
        # has been transformed
        return model, graph_modified
| 1962 | + |
| 1963 | + |
1834 | 1964 | # Converts the Squeeze operation to the corresponding FINN custom operation |
1835 | 1965 | class InferSqueeze(Transformation): |
1836 | 1966 | # Applies the transform to a whole model graph |
|
0 commit comments