diff --git a/docs/overview.rst b/docs/overview.rst index 8e2002d7..161d1e49 100644 --- a/docs/overview.rst +++ b/docs/overview.rst @@ -45,6 +45,20 @@ Custom Operations/Nodes QONNX uses many custom operations (op_type in ONNX NodeProto) that are not defined in the ONNX operator schema. These custom nodes are marked with domain="qonnx.*" in the protobuf to identify them as such. These nodes can represent specific operations that we need for low-bit networks, or operations that are specific to a particular hardware backend. To get more familiar with custom operations and how they are created, please take a look in the Jupyter notebook about CustomOps (see chapter :ref:`tutorials` for details) or directly in the module :py:mod:`qonnx.custom_op`. +Custom ops are automatically discovered through Python module namespaces. +Simply define your CustomOp subclass in the appropriate domain module +(e.g., ``qonnx.custom_op.general`` for general ops) and it will be automatically +available through ``getCustomOp``. + +For dynamic registration and querying, use the registry functions: + +* ``getCustomOp(node)`` - Get a custom op instance from an ONNX node +* ``is_custom_op(domain, op_type=None)`` - Check if a specific op or domain has custom ops +* ``add_op_to_domain(domain, op_class)`` - Register an op at runtime (for testing) +* ``get_ops_in_domain(domain)`` - List all ops available in a domain +* ``add_domain_alias(domain, module_path)`` - Map a domain to a different module path +* ``hasCustomOp(domain, op_type)`` - Check if an op exists in a domain + Custom ONNX Execution Flow ========================== diff --git a/notebooks/3_custom_op.ipynb b/notebooks/3_custom_op.ipynb index d0cd10fd..1b822163 100644 --- a/notebooks/3_custom_op.ipynb +++ b/notebooks/3_custom_op.ipynb @@ -129,35 +129,26 @@ { "cell_type": "markdown", "metadata": {}, - "source": [ - "To make sure our custom op is available, it needs to be registered. The best practice for this is to create a submodule under `qonnx.custom_op` which includes a `custom_op` dictionary that maps strings (op names) to classes (op implementations). Since we're in a Jupyter notebook we'll just hijack it at runtime like this:" - ] + "source": "To make sure our custom op is available, we need to add it to the domain. For production code, you would place your CustomOp class directly in the appropriate module file (e.g., in a file under `qonnx/custom_op/general/`). For testing and experimentation like in this notebook, we can use the `add_op_to_domain` function:" }, { "cell_type": "code", - "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "import qonnx.custom_op.general as general\n", - "general.custom_op[\"MyPythonPowerOp\"] = MyPythonPowerOp" - ] + "source": "from qonnx.custom_op.registry import add_op_to_domain\n\n# Add our custom op to the general domain\nadd_op_to_domain(\"qonnx.custom_op.general\", MyPythonPowerOp)", + "execution_count": null }, { "cell_type": "markdown", "metadata": {}, - "source": [ - "We can see which custom ops are registered under this submodule by looking at the dictionary:" - ] + "source": "We can see which custom ops are available in a domain by using the registry function:" }, { "cell_type": "code", - "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "general.custom_op" - ] + "source": "from qonnx.custom_op.registry import get_ops_in_domain, is_custom_op\n\n# See all ops in the general domain\nops = get_ops_in_domain(\"qonnx.custom_op.general\")\nprint(f\"Available ops: {[op[0] for op in ops]}\")\n\n# Check if our op is there\nprint(f\"MyPythonPowerOp available: {is_custom_op('qonnx.custom_op.general', 'MyPythonPowerOp')}\")", + "execution_count": null }, { "cell_type": "markdown", @@ -462,17 +453,10 @@ }, { "cell_type": "code", - "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "# register our new op\n", - "general.custom_op[\"MyMixedPowerOp\"] = MyMixedPowerOp\n", - "\n", - "# make graph with new op\n", - "mixedop_graph = make_graph(input_shape, 2, op_type = \"MyMixedPowerOp\")\n", - "mixedop_graph.graph.node" - ] + "source": "# register our new op\nadd_op_to_domain(\"qonnx.custom_op.general\", MyMixedPowerOp)\n\n# make graph with new op\nmixedop_graph = make_graph(input_shape, 2, op_type = \"MyMixedPowerOp\")\nmixedop_graph.graph.node", + "execution_count": null }, { "cell_type": "markdown", @@ -744,4 +728,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index c08278d9..db7c6f87 100644 --- a/setup.cfg +++ b/setup.cfg @@ -49,7 +49,7 @@ install_requires = importlib-metadata attrs>=22.2.0 clize>=5.0.1 - protobuf==3.20.3 + protobuf>=3.20.3 bitstring>=3.1.7 numpy>=1.24.1 onnx>=1.13.0 @@ -106,6 +106,10 @@ console_scripts = qonnx-tensor-stats = qonnx.analysis.tensor_stats:main pytest_randomly.random_seeder = qonnx = qonnx.util.random_reseed:reseed +# entry points for custom op modules +qonnx_custom_ops = + qonnx = qonnx.custom_op.general + qonnx_channels_last = qonnx.custom_op.channels_last # Add here console scripts like: # console_scripts = # script_name = qonnx.module:function diff --git a/src/qonnx/__init__.py b/src/qonnx/__init__.py index e69de29b..217648b8 100644 --- a/src/qonnx/__init__.py +++ b/src/qonnx/__init__.py @@ -0,0 +1,25 @@ +"""QONNX package initialization.""" + +import warnings +from importlib import metadata + + +def _load_custom_op_entry_points(): + """Import modules registered under the ``qonnx_custom_ops`` entry point.""" + + try: + eps = metadata.entry_points() + if hasattr(eps, "select"): + eps = eps.select(group="qonnx_custom_ops") + else: + eps = eps.get("qonnx_custom_ops", []) + for ep in eps: + try: + ep.load() + except Exception as e: # pragma: no cover - import failure warning + warnings.warn(f"Failed to load custom op entry point {ep.name}: {e}") + except Exception as e: # pragma: no cover - metadata failure warning + warnings.warn(f"Failed to query custom op entry points: {e}") + + +_load_custom_op_entry_points() diff --git a/src/qonnx/custom_op/__init__.py b/src/qonnx/custom_op/__init__.py index e69de29b..7c38a8df 100644 --- a/src/qonnx/custom_op/__init__.py +++ b/src/qonnx/custom_op/__init__.py @@ -0,0 +1,30 @@ +# Copyright (c) 2020 Xilinx, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of Xilinx nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# Domain aliases are automatically handled by the registry +# The onnx.brevitas -> qonnx.custom_op.general mapping is built into the registry \ No newline at end of file diff --git a/src/qonnx/custom_op/channels_last/__init__.py b/src/qonnx/custom_op/channels_last/__init__.py index 77a048e7..eceb9783 100644 --- a/src/qonnx/custom_op/channels_last/__init__.py +++ b/src/qonnx/custom_op/channels_last/__init__.py @@ -8,4 +8,4 @@ "Conv": Conv, "MaxPool": MaxPool, "BatchNormalization": BatchNormalization, -} \ No newline at end of file +} diff --git a/src/qonnx/custom_op/general/__init__.py b/src/qonnx/custom_op/general/__init__.py index 2f3896de..6c38ada7 100644 --- a/src/qonnx/custom_op/general/__init__.py +++ b/src/qonnx/custom_op/general/__init__.py @@ -54,4 +54,4 @@ "Trunc": Trunc, "BipolarQuant": BipolarQuant, "FloatQuant": FloatQuant, -} \ No newline at end of file +} diff --git a/src/qonnx/custom_op/general/quant.py b/src/qonnx/custom_op/general/quant.py index 3d448dc3..a7356a8f 100644 --- a/src/qonnx/custom_op/general/quant.py +++ b/src/qonnx/custom_op/general/quant.py @@ -26,12 +26,12 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from qonnx.custom_op.general.intquant import IntQuant as Quant +# Import IntQuant to create alias +from qonnx.custom_op.general.intquant import IntQuant + +# Re-export functions from intquant for backward compatibility from qonnx.custom_op.general.intquant import int_quant as quant from qonnx.custom_op.general.intquant import max_int, min_int, resolve_rounding_mode -Quant = Quant -quant = quant -max_int = max_int -min_int = min_int -resolve_rounding_mode = resolve_rounding_mode +# Create alias for backward compatibility - Quant is just IntQuant +Quant = IntQuant \ No newline at end of file diff --git a/src/qonnx/transformation/channels_last.py b/src/qonnx/transformation/channels_last.py index 175af058..5d585d0c 100644 --- a/src/qonnx/transformation/channels_last.py +++ b/src/qonnx/transformation/channels_last.py @@ -44,7 +44,8 @@ from qonnx.util.onnx import is_eltwise_optype # Standard ONNX nodes which require a ChannelsLast data format to function properly -_channelsLast_node_types = list(channels_last.custom_op.keys()) +# use the list of exported op names from the channels_last package +_channelsLast_node_types = list(channels_last.__all__) # Nodes, which do not modify the shape of the tensor # And modify all values in the same way. diff --git a/src/qonnx/transformation/extract_quant_scale_zeropt.py b/src/qonnx/transformation/extract_quant_scale_zeropt.py index 58863f08..614df416 100644 --- a/src/qonnx/transformation/extract_quant_scale_zeropt.py +++ b/src/qonnx/transformation/extract_quant_scale_zeropt.py @@ -69,6 +69,8 @@ def apply(self, model: ModelWrapper): ) graph.value_info.append(inp_scaled) inp_scale_node = helper.make_node("Div", [running_input, scale_nm], [inp_scaled_nm]) + if hasattr(node, "metadata_props"): + inp_scale_node.metadata_props.extend(node.metadata_props) graph.node.append(inp_scale_node) # create new Mul node # remove scale from Quant node @@ -87,6 +89,8 @@ def apply(self, model: ModelWrapper): ) graph.value_info.append(inp_zeropt) inp_zeropt_node = helper.make_node("Add", [running_input, zeropt_nm], [inp_zeropt_nm]) + if hasattr(node, "metadata_props"): + inp_zeropt_node.metadata_props.extend(node.metadata_props) graph.node.append(inp_zeropt_node) # remove zeropt from Quant node new_zeropt_nm = model.make_new_valueinfo_name() @@ -108,6 +112,8 @@ def apply(self, model: ModelWrapper): ) graph.value_info.append(out_zeropt) out_zeropt_node = helper.make_node("Sub", [out_zeropt_nm, zeropt_nm], [final_output]) + if hasattr(node, "metadata_props"): + out_zeropt_node.metadata_props.extend(node.metadata_props) last_node.output[0] = out_zeropt_nm graph.node.append(out_zeropt_node) # important: when tracking a pointer to newly added nodes, @@ -127,6 +133,8 @@ def apply(self, model: ModelWrapper): last_node.output[0] = out_scale_nm graph.value_info.append(out_scale) out_scale_node = helper.make_node("Mul", [out_scale_nm, scale_nm], [final_output]) + if hasattr(node, "metadata_props"): + out_scale_node.metadata_props.extend(node.metadata_props) graph.node.append(out_scale_node) if extract_scale or extract_zeropt: diff --git a/src/qonnx/transformation/gemm_to_matmul.py b/src/qonnx/transformation/gemm_to_matmul.py index 5396a7d6..1298f3d6 100644 --- a/src/qonnx/transformation/gemm_to_matmul.py +++ b/src/qonnx/transformation/gemm_to_matmul.py @@ -76,6 +76,8 @@ def apply(self, model): ) graph.value_info.append(inp_trans_out) inp_trans_node = helper.make_node("Transpose", [n.input[0]], [inp_trans_out.name]) + if hasattr(n, "metadata_props"): + inp_trans_node.metadata_props.extend(n.metadata_props) graph.node.insert(running_node_index, inp_trans_node) running_node_index += 1 dt = model.get_tensor_datatype(n.input[0]) @@ -98,6 +100,8 @@ def apply(self, model): ) graph.value_info.append(inp_trans_out) inp_trans_node = helper.make_node("Transpose", [n.input[1]], [inp_trans_out.name]) + if hasattr(n, "metadata_props"): + inp_trans_node.metadata_props.extend(n.metadata_props) graph.node.insert(running_node_index, inp_trans_node) running_node_index += 1 # Copy over the datatype @@ -109,6 +113,8 @@ def apply(self, model): # Insert MatMul: A * B matMul_node = helper.make_node("MatMul", [n.input[0], n.input[1]], [n.output[0]]) + if hasattr(n, "metadata_props"): + matMul_node.metadata_props.extend(n.metadata_props) graph.node.insert(running_node_index, matMul_node) matMul_node = graph.node[running_node_index] running_node_index += 1 @@ -144,6 +150,8 @@ def apply(self, model): [act_mul_tensor.name, mul_tensor.name], [n.output[0]], ) + if hasattr(n, "metadata_props"): + mul_node.metadata_props.extend(n.metadata_props) graph.node.insert(running_node_index, mul_node) mul_node_main_branch = graph.node[running_node_index] running_node_index += 1 @@ -175,6 +183,8 @@ def apply(self, model): [n.input[2], mul_tensor.name], [act_mul_tensor.name], ) + if hasattr(n, "metadata_props"): + mul_node.metadata_props.extend(n.metadata_props) graph.node.insert(running_node_index, mul_node) running_node_index += 1 dt = model.get_tensor_datatype(n.input[2]) @@ -196,7 +206,8 @@ def apply(self, model): [act_add_tensor.name, n.input[2]], [n.output[0]], ) - + if hasattr(n, "metadata_props"): + add_node.metadata_props.extend(n.metadata_props) graph.node.insert(running_node_index, add_node) running_node_index += 1 diff --git a/src/qonnx/util/config.py b/src/qonnx/util/config.py index 63661862..2f6383d3 100644 --- a/src/qonnx/util/config.py +++ b/src/qonnx/util/config.py @@ -27,13 +27,15 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import json +import onnx from qonnx.custom_op.registry import getCustomOp - -def extract_model_config_to_json(model, json_filename, attr_names_to_extract): - """Create a json file with layer name -> attribute mappings extracted from the - model. The created json file can be later applied on a model with +# update this code to handle export configs from subgraphs +# where the subgraph is found in a node's attribute as a graph type +def extract_model_config(model, attr_names_to_extract): + """Create a dictionary with layer name -> attribute mappings extracted from the + model. The created dictionary can be later applied on a model with qonnx.transform.general.ApplyConfig.""" cfg = dict() @@ -41,12 +43,22 @@ def extract_model_config_to_json(model, json_filename, attr_names_to_extract): for n in model.graph.node: oi = getCustomOp(n) layer_dict = dict() - for attr in attr_names_to_extract: - try: - layer_dict[attr] = oi.get_nodeattr(attr) - except AttributeError: - pass + for attr in n.attribute: + if attr.type == onnx.AttributeProto.GRAPH: # Graph type + # If the attribute is a graph, we need to extract the attributes from the subgraph + cfg.update(extract_model_config(model.make_subgraph_modelwrapper(attr.g), attr_names_to_extract)) + elif attr.name in attr_names_to_extract: + # If the attribute name is in the list, we can add it directly + layer_dict[attr.name] = oi.get_nodeattr(attr.name) if len(layer_dict) > 0: cfg[n.name] = layer_dict + return cfg + + +def extract_model_config_to_json(model, json_filename, attr_names_to_extract): + """Create a json file with layer name -> attribute mappings extracted from the + model. The created json file can be later applied on a model with + qonnx.transform.general.ApplyConfig.""" + with open(json_filename, "w") as f: - json.dump(cfg, f, indent=2) + json.dump(extract_model_config(model, attr_names_to_extract), f, indent=2) diff --git a/tests/custom_op/test_attr.py b/tests/custom_op/test_attr.py index cde5a321..ac4f7a5c 100644 --- a/tests/custom_op/test_attr.py +++ b/tests/custom_op/test_attr.py @@ -29,10 +29,9 @@ import numpy as np import onnx.parser as oprs -import qonnx.custom_op.general as general from qonnx.core.modelwrapper import ModelWrapper from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import getCustomOp +from qonnx.custom_op.registry import getCustomOp, add_op_to_domain class AttrTestOp(CustomOp): @@ -60,7 +59,9 @@ def verify_node(self): def test_attr(): - general.custom_op["AttrTestOp"] = AttrTestOp + # Add the test op to the domain + add_op_to_domain("qonnx.custom_op.general", AttrTestOp) + ishp = (1, 10) wshp = (1, 3) oshp = wshp @@ -87,6 +88,8 @@ def test_attr(): """ model = oprs.parse_model(input) model = ModelWrapper(model) + + # Now getCustomOp should find it through the manual registry inst = getCustomOp(model.graph.node[0]) w_prod = inst.get_nodeattr("tensor_attr") diff --git a/tests/transformation/test_channelslast.py b/tests/transformation/test_channelslast.py index 24e64b4f..30382c64 100644 --- a/tests/transformation/test_channelslast.py +++ b/tests/transformation/test_channelslast.py @@ -43,11 +43,11 @@ MoveTransposePastFork, RemoveConsecutiveChanFirstAndChanLastTrafos, ) +from qonnx.custom_op.registry import is_custom_op from qonnx.transformation.general import GiveUniqueNodeNames from qonnx.transformation.infer_shapes import InferShapes from qonnx.transformation.make_input_chanlast import MakeInputChannelsLast from qonnx.transformation.quant_constant_folding import FoldTransposeIntoQuantInit -from qonnx.util.basic import is_finn_op from qonnx.util.test import download_model, get_golden_in_and_output, test_model_details from qonnx.util.to_channels_last import to_channels_last @@ -126,7 +126,7 @@ def analysis_test_for_left_transposes(model, test_model, make_input_channels_las def verify_all_nodes(model): result = dict() for n in model.graph.node: - if is_finn_op(n.domain): + if is_custom_op(n.domain): n_instance = getCustomOp(n) verify_result = n_instance.verify_node() result[n.name] = verify_result