Changes from all commits (21 commits)
d6d95c5  Fix channels_last transformation for new registry (tafk7, Jun 16, 2025)
858cf56  Add legacy domain fallback test (tafk7, Jun 16, 2025)
5036a7a  Remove debug output from old domain test (tafk7, Jun 16, 2025)
fdaea24  Merge pull request #1 from tafk7/codex/analyze-and-redesign-customop-… (tafk7, Jun 16, 2025)
dfc4bd8  Add alternative customop registration decorator (tafk7, Jun 16, 2025)
8fa8463  Merge remote-tracking branch 'upstream/main' into custom/brainsmith (auphelia, Jun 20, 2025)
e59e558  Added passthrough Quant class (tafk7, Jun 20, 2025)
5dfb746  Merge pull request #1 from tafk7/custom/brainsmith-registration-fix (auphelia, Jun 23, 2025)
30df133  Bring back lost changes from custom/brainsmith branch (auphelia, Jun 23, 2025)
66b4c68  Merge pull request #195 from auphelia/custom/brainsmith (maltanar, Jun 23, 2025)
dad06c7  Refined domain-based registration (tafk7, Jul 8, 2025)
f6806f6  Refined custom_op registration (tafk7, Jul 16, 2025)
f7ab4b5  Dependency resolution (tafk7, Jul 16, 2025)
d08c33d  help multithreshold handle 3-dim more efficiently (Jul 16, 2025)
d76507a  update extract model config to export config for subgraphs (Jul 17, 2025)
fa3e0a8  Removed decorators in favor of pure domain (tafk7, Jul 18, 2025)
68346e3  Circular import fix (tafk7, Jul 18, 2025)
93fd8d0  Added brainsmith to hide finn ops (tafk7, Jul 18, 2025)
9153395  Move to namespace-based domain registration (tafk7, Jul 24, 2025)
f2c4ccd  refactor: migrate registry to thread-safe, cache-based architecture (tafk7, Oct 19, 2025)
8572cbb  Merge remote-tracking branch 'origin/main' into custom/bransmith_merg… (Oct 30, 2025)
14 changes: 14 additions & 0 deletions docs/overview.rst
@@ -45,6 +45,20 @@ Custom Operations/Nodes

QONNX uses many custom operations (op_type in ONNX NodeProto) that are not defined in the ONNX operator schema. These custom nodes are marked with domain="qonnx.*" in the protobuf to identify them as such. These nodes can represent specific operations that we need for low-bit networks, or operations that are specific to a particular hardware backend. To get more familiar with custom operations and how they are created, please take a look at the Jupyter notebook about CustomOps (see chapter :ref:`tutorials` for details) or directly at the module :py:mod:`qonnx.custom_op`.

Custom ops are discovered automatically through Python module namespaces.
Simply define your CustomOp subclass in the appropriate domain module
(e.g., ``qonnx.custom_op.general`` for general ops) and it becomes
available through ``getCustomOp``.
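
For example, a minimal sketch of such a subclass (a hypothetical pass-through
op; the method set follows the ``CustomOp`` abstract interface)::

    from onnx import helper

    from qonnx.custom_op.base import CustomOp

    class MyIdentityOp(CustomOp):  # hypothetical example op
        def get_nodeattr_types(self):
            return {}  # no node attributes

        def make_shape_compatible_op(self, model):
            # output shape equals input shape, so Identity is shape-compatible
            node = self.onnx_node
            return helper.make_node("Identity", [node.input[0]], [node.output[0]])

        def infer_node_datatype(self, model):
            node = self.onnx_node
            model.set_tensor_datatype(node.output[0], model.get_tensor_datatype(node.input[0]))

        def execute_node(self, context, graph):
            node = self.onnx_node
            context[node.output[0]] = context[node.input[0]]

        def verify_node(self):
            return []

Placed in a module file under ``qonnx/custom_op/general/``, no further
registration call is needed.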

For dynamic registration and querying, use the registry functions (a usage sketch follows the list):

* ``getCustomOp(node)`` - Get a custom op instance from an ONNX node
* ``is_custom_op(domain, op_type=None)`` - Check if a specific op or domain has custom ops
* ``add_op_to_domain(domain, op_class)`` - Register an op at runtime (for testing)
* ``get_ops_in_domain(domain)`` - List all ops available in a domain
* ``add_domain_alias(domain, module_path)`` - Map a domain to a different module path
* ``hasCustomOp(domain, op_type)`` - Check if an op exists in a domain
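
A minimal query sketch using these functions (``Quant`` is one of the ops
shipped in the general domain)::

    from qonnx.custom_op.registry import get_ops_in_domain, is_custom_op

    # list the op names available in the general domain
    print([op[0] for op in get_ops_in_domain("qonnx.custom_op.general")])
    # check whether a specific op resolves there
    print(is_custom_op("qonnx.custom_op.general", "Quant"))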


Custom ONNX Execution Flow
==========================
34 changes: 9 additions & 25 deletions notebooks/3_custom_op.ipynb
@@ -129,35 +129,26 @@
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make sure our custom op is available, it needs to be registered. The best practice for this is to create a submodule under `qonnx.custom_op` which includes a `custom_op` dictionary that maps strings (op names) to classes (op implementations). Since we're in a Jupyter notebook we'll just hijack it at runtime like this:"
]
"source": "To make sure our custom op is available, we need to add it to the domain. For production code, you would place your CustomOp class directly in the appropriate module file (e.g., in a file under `qonnx/custom_op/general/`). For testing and experimentation like in this notebook, we can use the `add_op_to_domain` function:"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import qonnx.custom_op.general as general\n",
"general.custom_op[\"MyPythonPowerOp\"] = MyPythonPowerOp"
]
"source": "from qonnx.custom_op.registry import add_op_to_domain\n\n# Add our custom op to the general domain\nadd_op_to_domain(\"qonnx.custom_op.general\", MyPythonPowerOp)",
"execution_count": null
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see which custom ops are registered under this submodule by looking at the dictionary:"
]
"source": "We can see which custom ops are available in a domain by using the registry function:"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"general.custom_op"
]
"source": "from qonnx.custom_op.registry import get_ops_in_domain, is_custom_op\n\n# See all ops in the general domain\nops = get_ops_in_domain(\"qonnx.custom_op.general\")\nprint(f\"Available ops: {[op[0] for op in ops]}\")\n\n# Check if our op is there\nprint(f\"MyPythonPowerOp available: {is_custom_op('qonnx.custom_op.general', 'MyPythonPowerOp')}\")",
"execution_count": null
},
{
"cell_type": "markdown",
@@ -462,17 +453,10 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# register our new op\n",
"general.custom_op[\"MyMixedPowerOp\"] = MyMixedPowerOp\n",
"\n",
"# make graph with new op\n",
"mixedop_graph = make_graph(input_shape, 2, op_type = \"MyMixedPowerOp\")\n",
"mixedop_graph.graph.node"
]
"source": "# register our new op\nadd_op_to_domain(\"qonnx.custom_op.general\", MyMixedPowerOp)\n\n# make graph with new op\nmixedop_graph = make_graph(input_shape, 2, op_type = \"MyMixedPowerOp\")\nmixedop_graph.graph.node",
"execution_count": null
},
{
"cell_type": "markdown",
@@ -744,4 +728,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}
6 changes: 5 additions & 1 deletion setup.cfg
@@ -49,7 +49,7 @@ install_requires =
importlib-metadata
attrs>=22.2.0
clize>=5.0.1
protobuf==3.20.3
protobuf>=3.20.3
bitstring>=3.1.7
numpy>=1.24.1
onnx>=1.13.0
@@ -106,6 +106,10 @@ console_scripts =
qonnx-tensor-stats = qonnx.analysis.tensor_stats:main
pytest_randomly.random_seeder =
qonnx = qonnx.util.random_reseed:reseed
# entry points for custom op modules
qonnx_custom_ops =
qonnx = qonnx.custom_op.general
qonnx_channels_last = qonnx.custom_op.channels_last
# Add here console scripts like:
# console_scripts =
# script_name = qonnx.module:function
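A sketch of how a downstream package could hook into the same mechanism from
its own setup.cfg (the package and module names here are hypothetical):

    [options.entry_points]
    qonnx_custom_ops =
        mypackage = mypackage.custom_op

Any module registered this way is imported at qonnx import time by the
entry-point loader in the file below.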
25 changes: 25 additions & 0 deletions src/qonnx/__init__.py
@@ -0,0 +1,25 @@
"""QONNX package initialization."""

import warnings
from importlib import metadata


def _load_custom_op_entry_points():
"""Import modules registered under the ``qonnx_custom_ops`` entry point."""

try:
eps = metadata.entry_points()
if hasattr(eps, "select"):
eps = eps.select(group="qonnx_custom_ops")
else:
eps = eps.get("qonnx_custom_ops", [])
for ep in eps:
try:
ep.load()
except Exception as e: # pragma: no cover - import failure warning
warnings.warn(f"Failed to load custom op entry point {ep.name}: {e}")
except Exception as e: # pragma: no cover - metadata failure warning
warnings.warn(f"Failed to query custom op entry points: {e}")


_load_custom_op_entry_points()
30 changes: 30 additions & 0 deletions src/qonnx/custom_op/__init__.py
@@ -0,0 +1,30 @@
# Copyright (c) 2020 Xilinx, Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of Xilinx nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# Domain aliases are automatically handled by the registry
# The onnx.brevitas -> qonnx.custom_op.general mapping is built into the registry
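
A sketch of registering such an alias by hand (this particular mapping is
already built in, so the call below is purely illustrative):

    from qonnx.custom_op.registry import add_domain_alias

    # resolve nodes carrying domain "onnx.brevitas" via the general op module
    add_domain_alias("onnx.brevitas", "qonnx.custom_op.general")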
2 changes: 1 addition & 1 deletion src/qonnx/custom_op/channels_last/__init__.py
@@ -8,4 +8,4 @@
"Conv": Conv,
"MaxPool": MaxPool,
"BatchNormalization": BatchNormalization,
}
}
2 changes: 1 addition & 1 deletion src/qonnx/custom_op/general/__init__.py
@@ -54,4 +54,4 @@
"Trunc": Trunc,
"BipolarQuant": BipolarQuant,
"FloatQuant": FloatQuant,
}
}
12 changes: 6 additions & 6 deletions src/qonnx/custom_op/general/quant.py
@@ -26,12 +26,12 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from qonnx.custom_op.general.intquant import IntQuant as Quant
# Import IntQuant to create alias
from qonnx.custom_op.general.intquant import IntQuant

# Re-export functions from intquant for backward compatibility
from qonnx.custom_op.general.intquant import int_quant as quant
from qonnx.custom_op.general.intquant import max_int, min_int, resolve_rounding_mode

Quant = Quant
quant = quant
max_int = max_int
min_int = min_int
resolve_rounding_mode = resolve_rounding_mode
# Create alias for backward compatibility - Quant is just IntQuant
Quant = IntQuant
3 changes: 2 additions & 1 deletion src/qonnx/transformation/channels_last.py
@@ -44,7 +44,8 @@
from qonnx.util.onnx import is_eltwise_optype

# Standard ONNX nodes which require a ChannelsLast data format to function properly
_channelsLast_node_types = list(channels_last.custom_op.keys())
# use the list of exported op names from the channels_last package
_channelsLast_node_types = list(channels_last.__all__)

# Nodes, which do not modify the shape of the tensor
# And modify all values in the same way.
8 changes: 8 additions & 0 deletions src/qonnx/transformation/extract_quant_scale_zeropt.py
@@ -69,6 +69,8 @@ def apply(self, model: ModelWrapper):
)
graph.value_info.append(inp_scaled)
inp_scale_node = helper.make_node("Div", [running_input, scale_nm], [inp_scaled_nm])
if hasattr(node, "metadata_props"):
inp_scale_node.metadata_props.extend(node.metadata_props)
graph.node.append(inp_scale_node)
# create new Mul node
# remove scale from Quant node
@@ -87,6 +89,8 @@ def apply(self, model: ModelWrapper):
)
graph.value_info.append(inp_zeropt)
inp_zeropt_node = helper.make_node("Add", [running_input, zeropt_nm], [inp_zeropt_nm])
if hasattr(node, "metadata_props"):
inp_zeropt_node.metadata_props.extend(node.metadata_props)
graph.node.append(inp_zeropt_node)
# remove zeropt from Quant node
new_zeropt_nm = model.make_new_valueinfo_name()
@@ -108,6 +112,8 @@ def apply(self, model: ModelWrapper):
)
graph.value_info.append(out_zeropt)
out_zeropt_node = helper.make_node("Sub", [out_zeropt_nm, zeropt_nm], [final_output])
if hasattr(node, "metadata_props"):
out_zeropt_node.metadata_props.extend(node.metadata_props)
last_node.output[0] = out_zeropt_nm
graph.node.append(out_zeropt_node)
# important: when tracking a pointer to newly added nodes,
@@ -127,6 +133,8 @@ def apply(self, model: ModelWrapper):
last_node.output[0] = out_scale_nm
graph.value_info.append(out_scale)
out_scale_node = helper.make_node("Mul", [out_scale_nm, scale_nm], [final_output])
if hasattr(node, "metadata_props"):
out_scale_node.metadata_props.extend(node.metadata_props)
graph.node.append(out_scale_node)

if extract_scale or extract_zeropt:
13 changes: 12 additions & 1 deletion src/qonnx/transformation/gemm_to_matmul.py
@@ -76,6 +76,8 @@ def apply(self, model):
)
graph.value_info.append(inp_trans_out)
inp_trans_node = helper.make_node("Transpose", [n.input[0]], [inp_trans_out.name])
if hasattr(n, "metadata_props"):
inp_trans_node.metadata_props.extend(n.metadata_props)
graph.node.insert(running_node_index, inp_trans_node)
running_node_index += 1
dt = model.get_tensor_datatype(n.input[0])
@@ -98,6 +100,8 @@ def apply(self, model):
)
graph.value_info.append(inp_trans_out)
inp_trans_node = helper.make_node("Transpose", [n.input[1]], [inp_trans_out.name])
if hasattr(n, "metadata_props"):
inp_trans_node.metadata_props.extend(n.metadata_props)
graph.node.insert(running_node_index, inp_trans_node)
running_node_index += 1
# Copy over the datatype
@@ -109,6 +113,8 @@ def apply(self, model):

# Insert MatMul: A * B
matMul_node = helper.make_node("MatMul", [n.input[0], n.input[1]], [n.output[0]])
if hasattr(n, "metadata_props"):
matMul_node.metadata_props.extend(n.metadata_props)
graph.node.insert(running_node_index, matMul_node)
matMul_node = graph.node[running_node_index]
running_node_index += 1
@@ -144,6 +150,8 @@ def apply(self, model):
[act_mul_tensor.name, mul_tensor.name],
[n.output[0]],
)
if hasattr(n, "metadata_props"):
mul_node.metadata_props.extend(n.metadata_props)
graph.node.insert(running_node_index, mul_node)
mul_node_main_branch = graph.node[running_node_index]
running_node_index += 1
@@ -175,6 +183,8 @@ def apply(self, model):
[n.input[2], mul_tensor.name],
[act_mul_tensor.name],
)
if hasattr(n, "metadata_props"):
mul_node.metadata_props.extend(n.metadata_props)
graph.node.insert(running_node_index, mul_node)
running_node_index += 1
dt = model.get_tensor_datatype(n.input[2])
@@ -196,7 +206,8 @@
[act_add_tensor.name, n.input[2]],
[n.output[0]],
)

if hasattr(n, "metadata_props"):
add_node.metadata_props.extend(n.metadata_props)
graph.node.insert(running_node_index, add_node)
running_node_index += 1

32 changes: 22 additions & 10 deletions src/qonnx/util/config.py
@@ -27,26 +27,38 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import json
import onnx

from qonnx.custom_op.registry import getCustomOp


def extract_model_config_to_json(model, json_filename, attr_names_to_extract):
"""Create a json file with layer name -> attribute mappings extracted from the
model. The created json file can be later applied on a model with
# handle export configs from subgraphs, where a subgraph is found
# in a node's attribute as a graph-type attribute
def extract_model_config(model, attr_names_to_extract):
"""Create a dictionary with layer name -> attribute mappings extracted from the
model. The created dictionary can be later applied on a model with
qonnx.transform.general.ApplyConfig."""

cfg = dict()
cfg["Defaults"] = dict()
for n in model.graph.node:
oi = getCustomOp(n)
layer_dict = dict()
for attr in attr_names_to_extract:
try:
layer_dict[attr] = oi.get_nodeattr(attr)
except AttributeError:
pass
for attr in n.attribute:
if attr.type == onnx.AttributeProto.GRAPH: # Graph type
# If the attribute is a graph, we need to extract the attributes from the subgraph
cfg.update(extract_model_config(model.make_subgraph_modelwrapper(attr.g), attr_names_to_extract))
elif attr.name in attr_names_to_extract:
# If the attribute name is in the list, we can add it directly
layer_dict[attr.name] = oi.get_nodeattr(attr.name)
if len(layer_dict) > 0:
cfg[n.name] = layer_dict
return cfg


def extract_model_config_to_json(model, json_filename, attr_names_to_extract):
"""Create a json file with layer name -> attribute mappings extracted from the
model. The created json file can be later applied on a model with
qonnx.transform.general.ApplyConfig."""

with open(json_filename, "w") as f:
json.dump(cfg, f, indent=2)
json.dump(extract_model_config(model, attr_names_to_extract), f, indent=2)
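A usage sketch for the updated export path (the model file and attribute
names below are illustrative):

    from qonnx.core.modelwrapper import ModelWrapper
    from qonnx.util.config import extract_model_config_to_json

    model = ModelWrapper("model.onnx")  # hypothetical model file
    # the named attributes are collected from every node, including nodes
    # found inside subgraph (GraphProto) attributes
    extract_model_config_to_json(model, "config.json", ["SIMD", "PE"])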
9 changes: 6 additions & 3 deletions tests/custom_op/test_attr.py
@@ -29,10 +29,9 @@
import numpy as np
import onnx.parser as oprs

import qonnx.custom_op.general as general
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.custom_op.base import CustomOp
from qonnx.custom_op.registry import getCustomOp
from qonnx.custom_op.registry import add_op_to_domain, getCustomOp


class AttrTestOp(CustomOp):
@@ -60,7 +59,9 @@ def verify_node(self):


def test_attr():
general.custom_op["AttrTestOp"] = AttrTestOp
# Add the test op to the domain
add_op_to_domain("qonnx.custom_op.general", AttrTestOp)

ishp = (1, 10)
wshp = (1, 3)
oshp = wshp
@@ -87,6 +88,8 @@
"""
model = oprs.parse_model(input)
model = ModelWrapper(model)

# getCustomOp should now find the op through the runtime registry
inst = getCustomOp(model.graph.node[0])

w_prod = inst.get_nodeattr("tensor_attr")