diff --git a/README.md b/README.md
index ecf6bd47..108b498c 100644
--- a/README.md
+++ b/README.md
@@ -11,7 +11,7 @@
-QONNX (Quantized ONNX) introduces several custom operators -- [`IntQuant`](docs/qonnx-custom-ops/intquant_op.md), [`FloatQuant`](docs/qonnx-custom-ops/floatquant_op.md), [`BipolarQuant`](docs/qonnx-custom-ops/bipolar_quant_op.md), and [`Trunc`](docs/qonnx-custom-ops/trunc_op.md) -- in order to represent arbitrary-precision integer and minifloat quantization in ONNX. This enables:
+QONNX (Quantized ONNX) introduces several [custom operators](docs/qonnx-custom-ops/overview.md) -- `IntQuant`, `FloatQuant`, `BipolarQuant`, and `Trunc` -- in order to represent arbitrary-precision integer and minifloat quantization in ONNX. This enables:
* Representation of binary, ternary, 3-bit, 4-bit, 6-bit or any other integer/fixed-point quantization.
* Representation of minifloat quantization with configurable exponent and mantissa bits.
* Quantization is an operator itself, and can be applied to any parameter or layer input.
@@ -29,9 +29,7 @@ This repository contains a set of Python utilities to work with QONNX models, in
### Operator definitions
-* [Quant](docs/qonnx-custom-ops/quant_op.md) for 2-to-arbitrary-bit quantization, with scaling and zero-point
-* [BipolarQuant](docs/qonnx-custom-ops/bipolar_quant_op.md) for 1-bit (bipolar) quantization, with scaling and zero-point
-* [Trunc](docs/qonnx-custom-ops/trunc_op.md) for truncating to a specified number of bits, with scaling and zero-point
+Please see the [custom operator overview](docs/qonnx-custom-ops/overview.md) table for more details.
### Installation
diff --git a/docs/qonnx-custom-ops/bipolar_quant_op.md b/docs/qonnx-custom-ops/bipolarquant_v1.md
similarity index 94%
rename from docs/qonnx-custom-ops/bipolar_quant_op.md
rename to docs/qonnx-custom-ops/bipolarquant_v1.md
index 3a70458e..03c0c01e 100644
--- a/docs/qonnx-custom-ops/bipolar_quant_op.md
+++ b/docs/qonnx-custom-ops/bipolarquant_v1.md
@@ -5,7 +5,7 @@ Additionally, takes one float as input, which define the scaling.
#### Version
-This operator is not part of the ONNX standard and is not currently versioned.
+The description of this operator in this document corresponds to `qonnx.custom_op.general` opset version 1.
#### Attributes
diff --git a/docs/qonnx-custom-ops/floatquant_op.md b/docs/qonnx-custom-ops/floatquant_v1.md
similarity index 98%
rename from docs/qonnx-custom-ops/floatquant_op.md
rename to docs/qonnx-custom-ops/floatquant_v1.md
index fc51b75f..4536194b 100644
--- a/docs/qonnx-custom-ops/floatquant_op.md
+++ b/docs/qonnx-custom-ops/floatquant_v1.md
@@ -16,7 +16,7 @@ special (symbolic) values. This makes it nontrivial to infer the maximum represe
#### Version
-This operator is not part of the ONNX standard and is not currently versioned.
+The description of this operator in this document corresponds to `qonnx.custom_op.general` opset version 1.
#### Attributes
diff --git a/docs/qonnx-custom-ops/intquant_op.md b/docs/qonnx-custom-ops/intquant_v1.md
similarity index 97%
rename from docs/qonnx-custom-ops/intquant_op.md
rename to docs/qonnx-custom-ops/intquant_v1.md
index fb627efb..4d15c0ec 100644
--- a/docs/qonnx-custom-ops/intquant_op.md
+++ b/docs/qonnx-custom-ops/intquant_v1.md
@@ -9,11 +9,11 @@ rounding_mode defines how quantized values are rounded.
Notes:
* This operator was previously named `Quant` but is renamed to `IntQuant` to distinguish it from `FloatQuant`. For a transition period, qonnx will transparently handle `Quant` as `IntQuant` for backwards compatibility reasons, but only `IntQuant` should be used for new models.
-* This operator does not work for binary or bipolar quantization, for this purpose the simpler BipolarQuant node exists.
+* This operator does not work for binary or bipolar quantization, for this purpose the simpler `BipolarQuant` node exists.
#### Version
-This operator is not part of the ONNX standard and is not currently versioned.
+The description of this operator in this document corresponds to `qonnx.custom_op.general` opset version 1.
#### Attributes
diff --git a/docs/qonnx-custom-ops/overview.md b/docs/qonnx-custom-ops/overview.md
new file mode 100644
index 00000000..dfb93c38
--- /dev/null
+++ b/docs/qonnx-custom-ops/overview.md
@@ -0,0 +1,13 @@
+## Operator Schemas
+
+This file lists the QONNX custom operators, similar to `Operators.md` for the ONNX standard.
+It is manually updated, since QONNX custom operators are relatively few in number.
+
+### qonnx.custom_op.general
+
+|**Operator**|**Since version**|
+|-|-|
+|BipolarQuant|[1](bipolarquant_v1.md)|
+|FloatQuant|[1](floatquant_v1.md)|
+|IntQuant|[1](intquant_v1.md)|
+|Trunc|[2](trunc_v2.md), [1](trunc_v1.md)|
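+
+A model selects versioned operator behavior through its opset imports. Below is a
+minimal sketch of how the custom opset version can be pinned using `qonnx_make_model`
+from `qonnx.util.basic`, which imports `qonnx.custom_op.general` version 1 by default:
+
+```python
+from onnx import helper
+
+from qonnx.util.basic import qonnx_make_model
+
+graph = helper.make_graph([], "example", [], [])
+# default: imports the standard ONNX domain plus qonnx.custom_op.general version 1
+model = qonnx_make_model(graph)
+# explicit: pin qonnx.custom_op.general to version 2 (e.g. for Trunc v2)
+model = qonnx_make_model(
+    graph,
+    opset_imports=[
+        helper.make_opsetid("", 11),
+        helper.make_opsetid("qonnx.custom_op.general", 2),
+    ],
+)
+```
+
+At execution time, `ModelWrapper.get_customop_wrapper(node)` resolves the CustomOp
+wrapper matching the version imported for the node's domain.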
diff --git a/docs/qonnx-custom-ops/trunc_op.md b/docs/qonnx-custom-ops/trunc_v1.md
similarity index 96%
rename from docs/qonnx-custom-ops/trunc_op.md
rename to docs/qonnx-custom-ops/trunc_v1.md
index 1b5f0d04..04b88443 100644
--- a/docs/qonnx-custom-ops/trunc_op.md
+++ b/docs/qonnx-custom-ops/trunc_v1.md
@@ -6,7 +6,7 @@ The attribute rounding_mode defines how truncated values are rounded.
#### Version
-This operator is not part of the ONNX standard and is not currently versioned.
+The description of this operator in this document corresponds to `qonnx.custom_op.general` opset version 1.
#### Attributes
diff --git a/docs/qonnx-custom-ops/trunc_v2.md b/docs/qonnx-custom-ops/trunc_v2.md
new file mode 100644
index 00000000..d716c6c2
--- /dev/null
+++ b/docs/qonnx-custom-ops/trunc_v2.md
@@ -0,0 +1,144 @@
+### **Trunc**
+
+Truncates the values of one input data (Tensor) at a specified bitwidth and produces one output data (Tensor).
+Additionally, takes five tensors as input, which define the scale, zero-point, input bit-width, output scale and output bit-width of the truncation.
+The attribute rounding_mode defines how truncated values are rounded.
+
+#### Version
+
+This operator is not part of the ONNX standard.
+The description of this operator in this document corresponds to `qonnx.custom_op.general` opset version 2.
+
+#### Attributes
+
+
+- rounding_mode : string (default is "FLOOR")
+  - Defines how rounding should be applied during truncation. Currently available modes are: "ROUND", "CEIL" and "FLOOR". Here "ROUND" implies a round-to-even operation. Lowercase variants for the rounding mode string are also supported: "round", "ceil", "floor".
+- signed : int (default is 1)
+  - Defines if the quantization includes a signed bit. E.g. at 8b unsigned=[0, 255] vs signed=[-128, 127].
+- narrow : int (default is 0)
+  - Defines if the value range should be interpreted as narrow, when signed=1. E.g. at 8b regular=[-128, 127] vs narrow=[-127, 127].
+
+
+#### Inputs
+
+
+- X (differentiable) : tensor(float32)
+  - Input tensor to truncate
+- scale : float32
+  - The scale factor at the input of the truncation
+- zeropt : float32
+  - The zero-point at the input of the truncation
+- in_bitwidth : int32
+  - The number of bits used at the input of the truncation
+- out_scale : float32
+  - The scale factor of the output of the truncation
+- out_bitwidth : int32
+  - The number of bits used at the output of the truncation
+
+
+
+#### Outputs
+
+
+- Y (differentiable) : tensor(float32)
+  - Output tensor
+
+
+
+#### Examples
+
+Trunc
+
+```python
+from onnx import helper
+import numpy as np
+
+# Define node settings and input
+x = np.random.randn(100).astype(np.float32) * 10.
+scale = np.array(1.)
+zeropt = np.array(0.)
+in_bitwidth = np.array(10)
+out_scale = np.array(64.)  # chosen as a power-of-two multiple of the input scale
+out_bitwidth = np.array(4)
+narrow = 0
+signed = 1
+rounding_mode = "ROUND"
+
+# Create node
+node = helper.make_node(
+    'Trunc',
+    domain='qonnx.custom_op.general',
+    inputs=['x', 'scale', 'zeropt', 'in_bitwidth', 'out_scale', 'out_bitwidth'],
+    outputs=['y'],
+    rounding_mode=rounding_mode,
+    narrow=narrow,
+    signed=signed,
+)
+
+# Execute the same settings with the reference implementation (trunc)
+# See the sample implementation for more details on trunc.
+output_ref = trunc(x, scale, zeropt, in_bitwidth, narrow, signed, out_scale, out_bitwidth, rounding_mode)
+
+# Execute node and compare
+expect(node, inputs=[x, scale, zeropt, in_bitwidth, out_scale, out_bitwidth], outputs=[output_ref], name='test_trunc')
+
+```
+
+
+
+
+#### Sample Implementation
+
+
+Trunc
+
+```python
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import numpy as np
+
+
+def min_int(signed, narrow_range, bit_width):
+    """Compute the minimum integer representable with the given range settings."""
+    if signed and narrow_range:
+        value = -(2 ** (bit_width - 1)) + 1
+    elif signed and not narrow_range:
+        value = -(2 ** (bit_width - 1))
+    else:
+        value = 0 * bit_width  # multiply by bit_width to preserve the numpy dtype
+    return value
+
+
+def max_int(signed, narrow_range, bit_width):
+    """Compute the maximum integer representable with the given range settings."""
+    if not signed and not narrow_range:
+        value = (2 ** bit_width) - 1
+    elif not signed and narrow_range:
+        value = (2 ** bit_width) - 2
+    else:
+        value = (2 ** (bit_width - 1)) - 1
+    return value
+
+
+def trunc(inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode):
+
+    # Scaling
+    y = inp_tensor / scale
+    y = y + zeropt
+    # Rounding
+    y = np.round(y)
+    # Rescale
+    trunc_scale = 2 ** np.round(
+        np.log2(output_scale / scale)
+    )  # Trunc scale should be a power-of-two - ensure that is the case
+    y = y / trunc_scale
+
+    # Clamping
+    min_int_val = min_int(signed, narrow, output_bit_width)
+    max_int_val = max_int(signed, narrow, output_bit_width)
+    y = np.where(y > max_int_val, max_int_val.astype(y.dtype), y)
+    y = np.where(y < min_int_val, min_int_val.astype(y.dtype), y)
+    # To int (truncate)
+    rounding_fx = resolve_rounding_mode(rounding_mode)
+    y = rounding_fx(y)
+
+    # Rescale
+    output_zeropt = zeropt / trunc_scale  # Rescale zero-point
+    y = y - output_zeropt
+    y = y * output_scale
+
+    return y
+
+
+def resolve_rounding_mode(mode_string):
+    """Resolve the rounding mode string of Quant and Trunc ops
+    to the corresponding numpy functions."""
+    normalized_mode_string = mode_string.upper()  # lowercase variants are accepted
+    if normalized_mode_string == "ROUND":
+        return np.round
+    elif normalized_mode_string == "CEIL":
+        return np.ceil
+    elif normalized_mode_string == "FLOOR":
+        return np.floor
+    else:
+        raise ValueError(f"Could not resolve rounding mode called: {mode_string}")
+
+```
+
+
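+As a quick sanity check, a small worked example calling the sample `trunc` above
+(values chosen for illustration):
+
+```python
+x = np.array([10.7, -3.2])
+# unit input scale; output_scale/scale = 4 is already a power of two, so trunc_scale = 4
+y = trunc(x, scale=np.array(1.0), zeropt=np.array(0.0),
+          input_bit_width=np.array(8), narrow=0, signed=1,
+          output_scale=np.array(4.0), output_bit_width=np.array(6),
+          rounding_mode="FLOOR")
+# rounds to [11, -3], divides by trunc_scale=4 -> [2.75, -0.75],
+# floors to [2, -1], then rescales by output_scale -> [8., -4.]
+```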
diff --git a/src/qonnx/core/execute_custom_node.py b/src/qonnx/core/execute_custom_node.py
index 7acf3792..cd6bb605 100644
--- a/src/qonnx/core/execute_custom_node.py
+++ b/src/qonnx/core/execute_custom_node.py
@@ -27,10 +27,9 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import qonnx.custom_op.registry as registry
-from qonnx.util.basic import get_preferred_onnx_opset
-def execute_custom_node(node, context, graph, onnx_opset_version=get_preferred_onnx_opset()):
+def execute_custom_node(node, context, graph, onnx_opset_version):
"""Call custom implementation to execute a single custom node.
Input/output provided via context."""
op_type = node.op_type
diff --git a/src/qonnx/core/modelwrapper.py b/src/qonnx/core/modelwrapper.py
index c82a0fee..2ba2984a 100644
--- a/src/qonnx/core/modelwrapper.py
+++ b/src/qonnx/core/modelwrapper.py
@@ -39,6 +39,7 @@
import qonnx.util.basic as util
import qonnx.util.onnx as onnxutil
from qonnx.core.datatype import DataType
+from qonnx.custom_op.registry import getCustomOp
from qonnx.transformation.double_to_single_float import DoubleToSingleFloat
from qonnx.transformation.general import (
RemoveStaticGraphInputs,
@@ -737,3 +738,24 @@ def set_tensor_sparsity(self, tensor_name, sparsity_dict):
qa.tensor_name = tensor_name
qa.quant_parameter_tensor_names.append(dt)
qnt_annotations.append(qa)
+
+ def get_opset_imports(self):
+ """Returns a list of imported opsets as a {domain, version} dictionary."""
+ return {opset.domain: opset.version for opset in self._model_proto.opset_import}
+
+ def get_customop_wrapper(self, node, fallback_customop_version=util.get_preferred_qonnx_opset()):
+ """Return CustomOp instance for given node, respecting the
+ imported opset version in the model protobuf. If the node's domain
+ is not found in the model's opset imports, fallback_customop_version
+ will be used."""
+ opset_imports = self.get_opset_imports()
+ try:
+ opset_import = opset_imports[node.domain]
+ return getCustomOp(node, onnx_opset_version=opset_import)
+ except KeyError:
+ # domain not found in imports, use fallback version
+ warnings.warn(
+ f"Domain {node.domain} not found in model opset imports, "
+ f"using fallback_customop_version={fallback_customop_version}"
+ )
+ return getCustomOp(node, onnx_opset_version=fallback_customop_version)
diff --git a/src/qonnx/core/onnx_exec.py b/src/qonnx/core/onnx_exec.py
index a8f4774c..ecb808be 100644
--- a/src/qonnx/core/onnx_exec.py
+++ b/src/qonnx/core/onnx_exec.py
@@ -36,7 +36,7 @@
import qonnx.analysis.topology as ta
import qonnx.core.execute_custom_node as ex_cu_node
from qonnx.util.basic import (
- get_preferred_onnx_opset,
+ get_preferred_qonnx_opset,
get_sanitize_quant_tensors,
is_finn_op,
qonnx_make_model,
@@ -44,7 +44,7 @@
)
-def execute_node(node, context, graph, return_full_exec_context=False, opset_version=get_preferred_onnx_opset()):
+def execute_node(node, context, graph, opset_version, return_full_exec_context=False):
"""Executes a single node by using onnxruntime or with a custom function.
Input/output provided via context."""
@@ -158,7 +158,7 @@ def execute_onnx(model, input_dict, return_full_exec_context=False, start_node=N
model_exec_mode = model.get_metadata_prop("exec_mode")
if (model_exec_mode is None) or (model_exec_mode == ""):
# extract opset version for node-by-node execution
- opset_version = model.model.opset_import[0].version
+ opset_imports = model.get_opset_imports()
# execute the model node by node
# we can simply walk down the list since the ONNX spec guarantees that it is
# topologically sorted
@@ -176,7 +176,11 @@ def execute_onnx(model, input_dict, return_full_exec_context=False, start_node=N
if get_sanitize_quant_tensors() != 0:
# round input values to match quantization annotation
execution_context = sanitize_quant_values(model, node.input, execution_context)
- execute_node(node, execution_context, graph, return_full_exec_context, opset_version)
+ if node.domain in opset_imports:
+ opset_version = opset_imports[node.domain]
+ else:
+ opset_version = get_preferred_qonnx_opset()
+ execute_node(node, execution_context, graph, opset_version, return_full_exec_context)
if get_sanitize_quant_tensors() != 0:
# round output values to quantization annotation
execution_context = sanitize_quant_values(model, node.output, execution_context)
diff --git a/src/qonnx/custom_op/base.py b/src/qonnx/custom_op/base.py
index 775d9f95..9cb0de11 100644
--- a/src/qonnx/custom_op/base.py
+++ b/src/qonnx/custom_op/base.py
@@ -30,7 +30,7 @@
import onnx.numpy_helper as np_helper
from abc import ABC, abstractmethod
-from qonnx.util.basic import get_by_name, get_preferred_onnx_opset
+from qonnx.util.basic import get_by_name, get_preferred_qonnx_opset
class CustomOp(ABC):
@@ -38,7 +38,7 @@ class CustomOp(ABC):
every custom node should have. Some as abstract methods, these have to be
filled when writing a new custom op node."""
- def __init__(self, onnx_node, onnx_opset_version=get_preferred_onnx_opset()):
+ def __init__(self, onnx_node, onnx_opset_version=get_preferred_qonnx_opset()):
super().__init__()
self.onnx_node = onnx_node
self.onnx_opset_version = onnx_opset_version
diff --git a/src/qonnx/custom_op/channels_last/__init__.py b/src/qonnx/custom_op/channels_last/__init__.py
index f1d7c39b..9ffd4e54 100644
--- a/src/qonnx/custom_op/channels_last/__init__.py
+++ b/src/qonnx/custom_op/channels_last/__init__.py
@@ -2,8 +2,28 @@
from qonnx.custom_op.channels_last.conv import Conv
from qonnx.custom_op.channels_last.max_pool import MaxPool
-custom_op = dict()
+# channels-last ops are defined by the underlying ONNX standard op;
+# thus, we can define them for any version of the original op.
+# so instead of hardcoding which versions are supported, we emulate a
+# custom op dictionary that accepts any {ChannelsLastOp}_vX key
-custom_op["Conv"] = Conv
-custom_op["MaxPool"] = MaxPool
-custom_op["BatchNormalization"] = BatchNormalization
+
+class ChannelsLastCustomOpDict(dict):
+ def __init__(self):
+ self._custom_ops = {"Conv": Conv, "MaxPool": MaxPool, "BatchNormalization": BatchNormalization}
+
+ def __getitem__(self, key):
+ base_key = key.split("_v")[0] # Extract base key (e.g., Conv from Conv_v13)
+ if base_key in self._custom_ops:
+ return self._custom_ops[base_key]
+ raise KeyError(f"Channels-last CustomOp '{key}' not found.")
+
+ def __contains__(self, key):
+ base_key = key.split("_v")[0]
+ return base_key in self._custom_ops
+
+ def keys(self):
+ return self._custom_ops.keys()
+
+
+custom_op = ChannelsLastCustomOpDict()
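+
+# usage sketch: any versioned key resolves to the base op wrapper, e.g.
+# custom_op["Conv_v13"] resolves to Conv and custom_op["MaxPool_v11"] to MaxPool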
diff --git a/src/qonnx/custom_op/general/__init__.py b/src/qonnx/custom_op/general/__init__.py
index 9b14ea8a..c670c1a3 100644
--- a/src/qonnx/custom_op/general/__init__.py
+++ b/src/qonnx/custom_op/general/__init__.py
@@ -35,7 +35,7 @@
from qonnx.custom_op.general.maxpoolnhwc import MaxPoolNHWC
from qonnx.custom_op.general.multithreshold import MultiThreshold
from qonnx.custom_op.general.quantavgpool2d import QuantAvgPool2d
-from qonnx.custom_op.general.trunc import Trunc
+from qonnx.custom_op.general.trunc import Trunc_v1, Trunc_v2
from qonnx.custom_op.general.xnorpopcount import XnorPopcountMatMul
custom_op = dict()
@@ -49,6 +49,21 @@
custom_op["Im2Col"] = Im2Col
custom_op["IntQuant"] = IntQuant
custom_op["Quant"] = IntQuant
-custom_op["Trunc"] = Trunc
+custom_op["Trunc"] = Trunc_v1
custom_op["BipolarQuant"] = BipolarQuant
custom_op["FloatQuant"] = FloatQuant
+
+custom_op["DebugMarker_v1"] = DebugMarker
+custom_op["QuantAvgPool2d_v1"] = QuantAvgPool2d
+custom_op["MaxPoolNHWC_v1"] = MaxPoolNHWC
+custom_op["GenericPartition_v1"] = GenericPartition
+custom_op["MultiThreshold_v1"] = MultiThreshold
+custom_op["XnorPopcountMatMul_v1"] = XnorPopcountMatMul
+custom_op["Im2Col_v1"] = Im2Col
+custom_op["IntQuant_v1"] = IntQuant
+custom_op["Quant_v1"] = IntQuant
+custom_op["Trunc_v1"] = Trunc_v1
+custom_op["BipolarQuant_v1"] = BipolarQuant
+custom_op["FloatQuant_v1"] = FloatQuant
+
+custom_op["Trunc_v2"] = Trunc_v2
diff --git a/src/qonnx/custom_op/general/quantavgpool2d.py b/src/qonnx/custom_op/general/quantavgpool2d.py
index c0e24071..00617dcf 100644
--- a/src/qonnx/custom_op/general/quantavgpool2d.py
+++ b/src/qonnx/custom_op/general/quantavgpool2d.py
@@ -33,7 +33,7 @@
from qonnx.core.datatype import DataType
from qonnx.custom_op.base import CustomOp
from qonnx.custom_op.general.maxpoolnhwc import compute_pool_output_dim
-from qonnx.util.basic import qonnx_make_model
+from qonnx.util.basic import get_preferred_onnx_opset, qonnx_make_model
class QuantAvgPool2d(CustomOp):
@@ -132,7 +132,7 @@ def execute_node(self, context, graph):
outputs=[outp],
)
- opset_version = self.onnx_opset_version
+ opset_version = get_preferred_onnx_opset()
opset_imports = [helper.make_opsetid("", opset_version)]
onnx_kwargs = {"opset_imports": opset_imports}
model_avgpool = qonnx_make_model(graph_avgpool, **onnx_kwargs)
diff --git a/src/qonnx/custom_op/general/trunc.py b/src/qonnx/custom_op/general/trunc.py
index 8e2eaa19..10c7e992 100644
--- a/src/qonnx/custom_op/general/trunc.py
+++ b/src/qonnx/custom_op/general/trunc.py
@@ -31,10 +31,99 @@
from qonnx.core.datatype import DataType
from qonnx.custom_op.base import CustomOp
-from qonnx.custom_op.general.quant import resolve_rounding_mode
+from qonnx.custom_op.general.quant import max_int, min_int, resolve_rounding_mode
+from qonnx.util.basic import get_preferred_qonnx_opset
-def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode):
+def trunc_v2(inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode):
+ # Port of TruncIntQuant class from Brevitas: https://bit.ly/3wzIpTR
+
+ # Scaling
+ y = inp_tensor / scale
+ y = y + zeropt
+ # Rounding
+ y = np.round(y)
+ # Rescale
+ trunc_scale = 2 ** np.round(
+ np.log2(output_scale / scale)
+ ) # Trunc scale should be a power-of-two - ensure that is the case
+ y = y / trunc_scale
+
+ # Clamping
+ min_int_val = min_int(signed, narrow, output_bit_width)
+ max_int_val = max_int(signed, narrow, output_bit_width)
+ y = np.where(y > max_int_val, max_int_val.astype(y.dtype), y)
+ y = np.where(y < min_int_val, min_int_val.astype(y.dtype), y)
+ # To int (truncate)
+ rounding_fx = resolve_rounding_mode(rounding_mode)
+ y = rounding_fx(y)
+
+ # Rescale
+ output_zeropt = zeropt / trunc_scale # Rescale zero-point
+ y = y - output_zeropt
+ y = y * output_scale
+
+ return y
+
+
+class Trunc_v2(CustomOp):
+ """Generic truncation operation for QONNX. Takes four inputs:
+ - input tensor to truncate
+ - the scale
+ - the zero-point
+ - the truncation scale
+ - the truncation bit-width
+
+ The output is a tensor of the same shape as the input tensor, with truncated
+ values.
+ """
+
+ def __init__(self, onnx_node, onnx_opset_version=get_preferred_qonnx_opset()):
+ super().__init__(onnx_node, onnx_opset_version)
+ # override any specified opset version, this instance is v2
+ self.onnx_opset_version = 2
+
+ def get_nodeattr_types(self):
+ return {
+ # The rounding mode, which is used for the trunc function
+ "rounding_mode": ("s", True, "FLOOR"),
+ "narrow": ("i", False, 0, {0, 1}),
+ "signed": ("i", False, 1, {0, 1}),
+ }
+
+ def make_shape_compatible_op(self, model):
+ node = self.onnx_node
+ return helper.make_node("Identity", [node.input[0]], [node.output[0]])
+
+ def infer_node_datatype(self, model):
+ node = self.onnx_node
+ model.set_tensor_datatype(node.output[0], DataType["FLOAT32"])
+
+ def execute_node(self, context, graph):
+ node = self.onnx_node
+ # save inputs
+ inp_tensor = context[node.input[0]]
+ scale = context[node.input[1]]
+ zeropt = context[node.input[2]]
+ input_bit_width = context[node.input[3]]
+ output_scale = context[node.input[4]]
+ output_bit_width = context[node.input[5]]
+ # save attributes
+ rounding_mode = self.get_nodeattr("rounding_mode")
+ narrow = self.get_nodeattr("narrow")
+ signed = self.get_nodeattr("signed")
+ # calculate output
+ ret = trunc_v2(
+ inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode
+ )
+ # set context according to output name
+ context[node.output[0]] = ret
+
+ def verify_node(self):
+ pass
+
+
+def trunc_v1(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode):
# Port of TruncIntQuant class from Brevitas: https://bit.ly/3wzIpTR
# Scaling
@@ -58,7 +147,7 @@ def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding
return y
-class Trunc(CustomOp):
+class Trunc_v1(CustomOp):
"""Generic truncation operation for QONNX. Takes four inputs:
- input tensor to truncate
- the scale
@@ -69,6 +158,11 @@ class Trunc(CustomOp):
values.
"""
+ def __init__(self, onnx_node, onnx_opset_version=get_preferred_qonnx_opset()):
+ super().__init__(onnx_node, onnx_opset_version)
+ # override any specified opset version, this instance is v1
+ self.onnx_opset_version = 1
+
def get_nodeattr_types(self):
return {
# The rounding mode, which is used for the trunc function
@@ -94,7 +188,7 @@ def execute_node(self, context, graph):
# save attributes
rounding_mode = self.get_nodeattr("rounding_mode")
# calculate output
- ret = trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode)
+ ret = trunc_v1(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode)
# set context according to output name
context[node.output[0]] = ret
diff --git a/src/qonnx/custom_op/registry.py b/src/qonnx/custom_op/registry.py
index 3540bb5a..258e9ab0 100644
--- a/src/qonnx/custom_op/registry.py
+++ b/src/qonnx/custom_op/registry.py
@@ -28,11 +28,12 @@
import importlib
-from qonnx.util.basic import get_preferred_onnx_opset
-
-def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset(), brevitas_exception=True):
- "Return a QONNX CustomOp instance for the given ONNX node, if it exists."
+def getCustomOp(node, onnx_opset_version=None, brevitas_exception=True):
+    """Return a QONNX CustomOp wrapper for the given ONNX node and opset version,
+    if it exists. If onnx_opset_version is None, the default handler for the op type
+    will be used. If a version is specified but no exact match is available, the
+    highest available version smaller than the requested version will be used."""
op_type = node.op_type
domain = node.domain
if brevitas_exception:
@@ -40,11 +41,34 @@ def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset(), brevitas_ex
domain = domain.replace("onnx.brevitas", "qonnx.custom_op.general")
try:
opset_module = importlib.import_module(domain)
- assert type(opset_module.custom_op) is dict, "custom_op dict not found in Python module %s" % domain
- inst_wrapper = opset_module.custom_op[op_type]
- inst = inst_wrapper(node, onnx_opset_version=onnx_opset_version)
+ assert isinstance(opset_module.custom_op, dict), "custom_op dict not found in Python module %s" % domain
+ found_opset_version = None
+ if onnx_opset_version is None:
+ inst_wrapper = opset_module.custom_op[op_type]
+ else:
+ op_type_with_version = op_type + "_v" + str(onnx_opset_version)
+ if op_type_with_version in opset_module.custom_op:
+ # priority: if it exists, load the versioned CustomOp wrapper
+ inst_wrapper = opset_module.custom_op[op_type_with_version]
+ found_opset_version = onnx_opset_version
+ else:
+ # when the exact version match is not found
+ # version handling: use highest available version smaller than requested version
+ available_versions = [
+ int(k.split("_v")[-1]) for k in opset_module.custom_op.keys() if k.startswith(op_type + "_v")
+ ]
+ suitable_versions = [v for v in available_versions if v <= onnx_opset_version]
+ if suitable_versions:
+ highest_version = max(suitable_versions)
+ inst_wrapper = opset_module.custom_op[f"{op_type}_v{highest_version}"]
+ found_opset_version = highest_version
+ else:
+ raise Exception(
+ "Op %s version %s not found in custom opset %s" % (op_type, str(onnx_opset_version), domain)
+ )
+ inst = inst_wrapper(node, onnx_opset_version=found_opset_version)
return inst
except ModuleNotFoundError:
raise Exception("Could not load custom opset %s, check your PYTHONPATH" % domain)
except KeyError:
- raise Exception("Op %s not found in custom opset %s" % (op_type, domain))
+ raise Exception("Op %s version %s not found in custom opset %s" % (op_type, str(onnx_opset_version), domain))
diff --git a/src/qonnx/transformation/channels_last.py b/src/qonnx/transformation/channels_last.py
index 175af058..c352238c 100644
--- a/src/qonnx/transformation/channels_last.py
+++ b/src/qonnx/transformation/channels_last.py
@@ -270,8 +270,13 @@ def apply(self, model):
# Attach to original node
n.output[i] = outp_trans_in
- # Modify domain
+ # Modify node domain
n.domain = "qonnx.custom_op.channels_last"
+ opset_imports = model.get_opset_imports()
+ # Ensure channels_last domain is imported in model
+ if "qonnx.custom_op.channels_last" not in opset_imports:
+ onnx_opset = opset_imports[""]
+ model.model.opset_import.append(helper.make_opsetid("qonnx.custom_op.channels_last", onnx_opset))
# Set modified flag
graph_modified = True
diff --git a/src/qonnx/transformation/fixedpt_quantize.py b/src/qonnx/transformation/fixedpt_quantize.py
index 894d7ea6..ff0c11db 100644
--- a/src/qonnx/transformation/fixedpt_quantize.py
+++ b/src/qonnx/transformation/fixedpt_quantize.py
@@ -48,7 +48,8 @@ class FixedPointQuantizeParamsFromDict(Transformation):
data type or its canonical name
rounding_mode: Rounding mode used for conversion into fixed point.
Default is "ROUND",
- possible values: ["ROUND", "HALF_EVEN", "CEIL", "FLOOR", "UP", "DOWN", "HALF_UP", "HALF_DOWN"]
+ possible values: ["ROUND", "HALF_EVEN", "CEIL", "FLOOR", "UP", "DOWN",
+ "HALF_UP", "HALF_DOWN"]
"""
def __init__(self, fixedpt_dict, rounding_mode="ROUND"):
diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py
index 3a3ce2af..e756366d 100644
--- a/src/qonnx/util/basic.py
+++ b/src/qonnx/util/basic.py
@@ -51,11 +51,19 @@ def get_preferred_onnx_opset():
return 11
+def get_preferred_qonnx_opset():
+    "Return preferred opset version for the QONNX custom op domains"
+    return 1
+
+
def qonnx_make_model(graph_proto, **kwargs):
"Wrapper around ONNX make_model with preferred qonnx opset version"
opset_imports = kwargs.pop("opset_imports", None)
if opset_imports is None:
- opset_imports = [make_opsetid("", get_preferred_onnx_opset())]
+ opset_imports = [
+ make_opsetid("", get_preferred_onnx_opset()),
+ make_opsetid("qonnx.custom_op.general", get_preferred_qonnx_opset()),
+ ]
kwargs["opset_imports"] = opset_imports
else:
kwargs["opset_imports"] = opset_imports
diff --git a/tests/core/test_custom_onnx_exec.py b/tests/core/test_custom_onnx_exec.py
index 8eec7156..54b71754 100644
--- a/tests/core/test_custom_onnx_exec.py
+++ b/tests/core/test_custom_onnx_exec.py
@@ -32,6 +32,8 @@
import qonnx.core.execute_custom_node as ex_cu_node
from qonnx.custom_op.registry import getCustomOp
+mt_node_version = 1
+
def test_execute_custom_node_multithreshold():
inputs = np.ndarray(
@@ -155,7 +157,7 @@ def test_execute_custom_node_multithreshold():
execution_context["v"] = inputs
execution_context["thresholds"] = threshold_values
- ex_cu_node.execute_custom_node(node_def, execution_context, graph_def)
+ ex_cu_node.execute_custom_node(node_def, execution_context, graph_def, mt_node_version)
outputs = np.ndarray(
shape=(6, 3, 2, 2),
@@ -250,7 +252,7 @@ def test_execute_custom_node_multithreshold():
)
graph_def = helper.make_graph([node_def], "test_model", [v, thresholds], [out])
- ex_cu_node.execute_custom_node(node_def, execution_context, graph_def)
+ ex_cu_node.execute_custom_node(node_def, execution_context, graph_def, mt_node_version)
outputs_scaled = 2.0 * outputs - 1.0
assert (execution_context["out"] == outputs_scaled).all()
@@ -270,7 +272,7 @@ def test_execute_custom_node_multithreshold():
execution_context["v"] = inputs_nhwc
graph_def = helper.make_graph([node_def], "test_model", [v_nhwc, thresholds], [out_nhwc])
- ex_cu_node.execute_custom_node(node_def, execution_context, graph_def)
+ ex_cu_node.execute_custom_node(node_def, execution_context, graph_def, mt_node_version)
assert (execution_context["out"] == outputs_nhwc).all()
# check the set of allowed values
op_inst = getCustomOp(node_def)
diff --git a/tests/core/test_modelwrapper.py b/tests/core/test_modelwrapper.py
index 722f0fb1..995bcb17 100644
--- a/tests/core/test_modelwrapper.py
+++ b/tests/core/test_modelwrapper.py
@@ -68,6 +68,7 @@ def test_modelwrapper():
inp_sparsity = {"dw": {"kernel_shape": [3, 3]}}
model.set_tensor_sparsity(first_conv_iname, inp_sparsity)
assert model.get_tensor_sparsity(first_conv_iname) == inp_sparsity
+ assert model.get_opset_imports() == {"": 8}
def test_modelwrapper_set_get_rm_initializer():
diff --git a/tests/custom_op/test_customop_version.py b/tests/custom_op/test_customop_version.py
new file mode 100644
index 00000000..3efbde24
--- /dev/null
+++ b/tests/custom_op/test_customop_version.py
@@ -0,0 +1,137 @@
+# Copyright (c) 2025 Advanced Micro Devices, Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of qonnx nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import onnx.parser as oprs
+
+import qonnx.custom_op.general as general
+from qonnx.core.modelwrapper import ModelWrapper
+from qonnx.custom_op.base import CustomOp
+from qonnx.custom_op.registry import getCustomOp
+
+
+class VerTestOp_v1(CustomOp):
+ def get_nodeattr_types(self):
+ my_attrs = {"v1_attr": ("i", True, 0)}
+ return my_attrs
+
+ def make_shape_compatible_op(self, model):
+ ishape = model.get_tensor_shape(self.onnx_node.input[0])
+ return super().make_const_shape_op(ishape)
+
+ def infer_node_datatype(self, model):
+ node = self.onnx_node
+ # data type stays the same
+ dtype = model.get_tensor_datatype(node.input[0])
+ model.set_tensor_datatype(node.output[0], dtype)
+
+ def execute_node(self, context, graph):
+ node = self.onnx_node
+ context[node.output[0]] = context[node.input[0]]
+
+ def verify_node(self):
+ pass
+
+
+class VerTestOp_v2(VerTestOp_v1):
+ def get_nodeattr_types(self):
+ my_attrs = {"v2_attr": ("i", True, 0)}
+ return my_attrs
+
+
+class VerTestOp_v3(VerTestOp_v2):
+ def get_nodeattr_types(self):
+ my_attrs = {"v3_attr": ("i", True, 0)}
+ return my_attrs
+
+
+def make_vertest_model(vertest_ver, no_opset_import):
+ ishp = (1, 10)
+ oshp = ishp
+ ishp_str = str(list(ishp))
+ oshp_str = str(list(oshp))
+ if no_opset_import:
+ opset_import = ""
+ else:
+ opset_import = f', "qonnx.custom_op.general" : {vertest_ver}'
+ input = f"""
+ <
+ ir_version: 7,
+ opset_import: ["" : 9{opset_import}]
+ >
+ agraph (float{ishp_str} in0) => (float{oshp_str} out0)
+ {{
+ out0 = qonnx.custom_op.general.VerTestOp<
+ v{vertest_ver}_attr={vertest_ver}
+ >(in0)
+ }}
+ """
+ model = oprs.parse_model(input)
+ model = ModelWrapper(model)
+ return model
+
+
+def test_customop_version():
+ # unspecified version defaults to v1 implementation
+ general.custom_op["VerTestOp"] = VerTestOp_v1
+ # v1 version is also explicitly registered
+ general.custom_op["VerTestOp_v1"] = VerTestOp_v1
+ general.custom_op["VerTestOp_v2"] = VerTestOp_v2
+ general.custom_op["VerTestOp_v3"] = VerTestOp_v3
+
+ # if onnx is lacking the opset import, should default to v1 handler
+ # (since we set custom_op["VerTestOp"] = VerTestOp_v1)
+ model = make_vertest_model(1, True)
+ inst = getCustomOp(model.graph.node[0])
+ assert isinstance(inst, VerTestOp_v1)
+ # alternatively, when using ModelWrapper.get_customop_wrapper and onnx is
+ # lacking the opset import, should fall back to the specified version
+ inst = model.get_customop_wrapper(model.graph.node[0], fallback_customop_version=2)
+ assert isinstance(inst, VerTestOp_v2)
+
+ for ver in [1, 2, 3]:
+ model = make_vertest_model(ver, False)
+ # use ModelWrapper.get_customop_wrapper for implicit
+ # fetching of op version
+ inst = model.get_customop_wrapper(model.graph.node[0])
+ assert inst.get_nodeattr(f"v{ver}_attr") == ver
+ assert inst.onnx_opset_version == ver
+ # explicitly specify onnx_opset_version in getCustomOp
+ # note: new code should avoid calling getCustomOp directly like this
+ # and instead use ModelWrapper.get_customop_wrapper
+ inst = getCustomOp(model.graph.node[0], onnx_opset_version=ver)
+ assert inst.get_nodeattr(f"v{ver}_attr") == ver
+ assert inst.onnx_opset_version == ver
+ # unspecified version getCustomOp should default to v1 handler
+ model = make_vertest_model(1, False)
+ inst = getCustomOp(model.graph.node[0])
+ assert isinstance(inst, VerTestOp_v1)
+ # requesting v4 should return largest available version (v3 in this case)
+ model = make_vertest_model(3, False)
+ inst = getCustomOp(model.graph.node[0], onnx_opset_version=4)
+ assert isinstance(inst, VerTestOp_v3)
+ assert inst.onnx_opset_version == 3