diff --git a/README.md b/README.md
index ecf6bd47..108b498c 100644
--- a/README.md
+++ b/README.md
@@ -11,7 +11,7 @@ QONNX example
-QONNX (Quantized ONNX) introduces several custom operators -- [`IntQuant`](docs/qonnx-custom-ops/intquant_op.md), [`FloatQuant`](docs/qonnx-custom-ops/floatquant_op.md), [`BipolarQuant`](docs/qonnx-custom-ops/bipolar_quant_op.md), and [`Trunc`](docs/qonnx-custom-ops/trunc_op.md) -- in order to represent arbitrary-precision integer and minifloat quantization in ONNX. This enables:
+QONNX (Quantized ONNX) introduces several [custom operators](docs/qonnx-custom-ops/overview.md) -- `IntQuant`, `FloatQuant`, `BipolarQuant`, and `Trunc` -- in order to represent arbitrary-precision integer and minifloat quantization in ONNX. This enables:
 * Representation of binary, ternary, 3-bit, 4-bit, 6-bit or any other integer/fixed-point quantization.
 * Representation of minifloat quantization with configurable exponent and mantissa bits.
 * Quantization is an operator itself, and can be applied to any parameter or layer input.
@@ -29,9 +29,7 @@ This repository contains a set of Python utilities to work with QONNX models, in
 ### Operator definitions
 
-* [Quant](docs/qonnx-custom-ops/quant_op.md) for 2-to-arbitrary-bit quantization, with scaling and zero-point
-* [BipolarQuant](docs/qonnx-custom-ops/bipolar_quant_op.md) for 1-bit (bipolar) quantization, with scaling and zero-point
-* [Trunc](docs/qonnx-custom-ops/trunc_op.md) for truncating to a specified number of bits, with scaling and zero-point
+Please see the [custom operator overview](docs/qonnx-custom-ops/overview.md) table for more details.
 
 ### Installation
diff --git a/docs/qonnx-custom-ops/bipolar_quant_op.md b/docs/qonnx-custom-ops/bipolarquant_v1.md
similarity index 94%
rename from docs/qonnx-custom-ops/bipolar_quant_op.md
rename to docs/qonnx-custom-ops/bipolarquant_v1.md
index 3a70458e..03c0c01e 100644
--- a/docs/qonnx-custom-ops/bipolar_quant_op.md
+++ b/docs/qonnx-custom-ops/bipolarquant_v1.md
@@ -5,7 +5,7 @@ Additionally, takes one float as input, which define the scaling.
 
 #### Version
 
-This operator is not part of the ONNX standard and is not currently versioned.
+The description of this operator in this document corresponds to `qonnx.custom_op.general` opset version 1.
 
 #### Attributes
diff --git a/docs/qonnx-custom-ops/floatquant_op.md b/docs/qonnx-custom-ops/floatquant_v1.md
similarity index 98%
rename from docs/qonnx-custom-ops/floatquant_op.md
rename to docs/qonnx-custom-ops/floatquant_v1.md
index fc51b75f..4536194b 100644
--- a/docs/qonnx-custom-ops/floatquant_op.md
+++ b/docs/qonnx-custom-ops/floatquant_v1.md
@@ -16,7 +16,7 @@ special (symbolic) values. This makes it nontrivial to infer the maximum represe
 
 #### Version
 
-This operator is not part of the ONNX standard and is not currently versioned.
+The description of this operator in this document corresponds to `qonnx.custom_op.general` opset version 1.
 
 #### Attributes
diff --git a/docs/qonnx-custom-ops/intquant_op.md b/docs/qonnx-custom-ops/intquant_v1.md
similarity index 97%
rename from docs/qonnx-custom-ops/intquant_op.md
rename to docs/qonnx-custom-ops/intquant_v1.md
index fb627efb..4d15c0ec 100644
--- a/docs/qonnx-custom-ops/intquant_op.md
+++ b/docs/qonnx-custom-ops/intquant_v1.md
@@ -9,11 +9,11 @@ rounding_mode defines how quantized values are rounded.
 Notes:
 * This operator was previously named `Quant` but is renamed to `IntQuant` to distinguish it from `FloatQuant`. For a transition period, qonnx will transparently handle `Quant` as `IntQuant` for backwards compatibility, but only `IntQuant` should be used for new models.
-* This operator does not work for binary or bipolar quantization, for this purpose the simpler BipolarQuant node exists.
+* This operator does not work for binary or bipolar quantization; for that purpose, the simpler `BipolarQuant` node exists.
 
 #### Version
 
-This operator is not part of the ONNX standard and is not currently versioned.
+The description of this operator in this document corresponds to `qonnx.custom_op.general` opset version 1.
 
 #### Attributes
diff --git a/docs/qonnx-custom-ops/overview.md b/docs/qonnx-custom-ops/overview.md
new file mode 100644
index 00000000..dfb93c38
--- /dev/null
+++ b/docs/qonnx-custom-ops/overview.md
@@ -0,0 +1,13 @@
+## Operator Schemas
+
+This file lists the QONNX custom operators, similar to `Operators.md` for the ONNX standard.
+It is manually updated, since QONNX custom operators are relatively few in number.
+
+### qonnx.custom_op.general
+
+|**Operator**|**Since version**|
+|-|-|
+|BipolarQuant|1|
+|FloatQuant|1|
+|IntQuant|1|
+|Trunc|2, 1|
diff --git a/docs/qonnx-custom-ops/trunc_op.md b/docs/qonnx-custom-ops/trunc_v1.md
similarity index 96%
rename from docs/qonnx-custom-ops/trunc_op.md
rename to docs/qonnx-custom-ops/trunc_v1.md
index 1b5f0d04..04b88443 100644
--- a/docs/qonnx-custom-ops/trunc_op.md
+++ b/docs/qonnx-custom-ops/trunc_v1.md
@@ -6,7 +6,7 @@ The attribute rounding_mode defines how truncated values are rounded.
 
 #### Version
 
-This operator is not part of the ONNX standard and is not currently versioned.
+The description of this operator in this document corresponds to `qonnx.custom_op.general` opset version 1.
 
 #### Attributes
diff --git a/docs/qonnx-custom-ops/trunc_v2.md b/docs/qonnx-custom-ops/trunc_v2.md
new file mode 100644
index 00000000..d716c6c2
--- /dev/null
+++ b/docs/qonnx-custom-ops/trunc_v2.md
@@ -0,0 +1,144 @@
+### **Trunc**
+
+Truncates the values of one input data (Tensor) at a specified bitwidth and produces one output data (Tensor).
+Additionally, takes five tensors as input, which define the scale, zero-point, input bit-width, output scale and output bit-width of the truncation.
+The attribute rounding_mode defines how truncated values are rounded.
+
+#### Version
+
+This operator is not part of the ONNX standard.
+The description of this operator in this document corresponds to `qonnx.custom_op.general` opset version 2.
+
+#### Attributes
+
+
rounding_mode : string (default is "FLOOR")
+
Defines how rounding should be applied during truncation. Currently available modes are: "ROUND", "CEIL" and "FLOOR". Here "ROUND" implies a round-to-even operation. Lowercase variants for the rounding mode string are also supported: "round", "ceil", "floor".
+
signed : int (default is 1)
+
Defines if the quantization includes a signed bit. E.g. at 8b unsigned=[0, 255] vs signed=[-128, 127].
+
narrow : int (default is 0)
+
Defines if the value range should be interpreted as narrow, when signed=1. E.g. at 8b regular=[-128, 127] vs narrow=[-127, 127].
+
+ +#### Inputs + +
+
X (differentiable) : tensor(float32)
+
input tensor to truncate
+
scale : float32
+
The scale factor at the input of the truncation
+
zeropt : float32
+
The zero-point at the input of the truncation
+
in_bitwidth : int32
+
The number of bits used at the input of the truncation
+
out_scale : float32
+
The scale factor of the output of the truncation
+
out_bitwidth : int32
+
The number of bits used at the output of the truncation
+
+ + +#### Outputs + +
+
Y (differentiable) : tensor(float32)
+
Output tensor
+
+ + +#### Examples +
+<details>
+<summary>Trunc</summary>
+
+```python
+from onnx import helper
+import numpy as np
+
+# Define node settings and input
+x = np.random.randn(100).astype(np.float32) * 10.0
+scale = np.array(1.0)
+zeropt = np.array(0.0)
+in_bitwidth = np.array(10)
+out_scale = np.array(1.0)
+out_bitwidth = np.array(4)
+rounding_mode = "ROUND"
+
+# Create node; narrow and signed are optional attributes,
+# shown here with their default values
+node = helper.make_node(
+    'Trunc',
+    domain='qonnx.custom_op.general',
+    inputs=['x', 'scale', 'zeropt', 'in_bitwidth', 'out_scale', 'out_bitwidth'],
+    outputs=['y'],
+    rounding_mode=rounding_mode,
+    narrow=0,
+    signed=1,
+)
+
+# Execute the same settings with the reference implementation (trunc)
+# See the sample implementation for more details on trunc.
+output_ref = trunc(x, scale, zeropt, in_bitwidth, narrow=0, signed=1,
+                   output_scale=out_scale, output_bit_width=out_bitwidth,
+                   rounding_mode=rounding_mode)
+
+# Execute node and compare
+expect(node, inputs=[x, scale, zeropt, in_bitwidth, out_scale, out_bitwidth], outputs=[output_ref], name='test_trunc')
+```
+
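+As an illustrative walk-through of these settings (since `scale` and `out_scale` are both 1.0, `trunc_scale` is 1 and no magnitude bits are dropped): an input value of 25.7 is scaled to 25.7, rounded to 26, clamped to the 4-bit signed maximum of 7, and rescaled by `out_scale` back to 7.0.
+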
+ + +#### Sample Implementation + +
+<details>
+<summary>Trunc</summary>
+
+```python
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import numpy as np
+
+
+# min_int and max_int mirror the helpers in qonnx.custom_op.general.quant,
+# included here so that the sample is self-contained; bit widths are expected
+# as numpy values so that the clamping below can use .astype
+def min_int(signed, narrow_range, bit_width):
+    if signed and narrow_range:
+        return -(2 ** (bit_width - 1)) + 1
+    elif signed and not narrow_range:
+        return -(2 ** (bit_width - 1))
+    else:
+        return 0 * bit_width
+
+
+def max_int(signed, narrow_range, bit_width):
+    if not signed and not narrow_range:
+        return (2**bit_width) - 1
+    elif not signed and narrow_range:
+        return (2**bit_width) - 2
+    else:
+        return (2 ** (bit_width - 1)) - 1
+
+
+def trunc(inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode):
+
+    # Scaling
+    y = inp_tensor / scale
+    y = y + zeropt
+    # Rounding
+    y = np.round(y)
+    # Rescale
+    trunc_scale = 2 ** np.round(
+        np.log2(output_scale / scale)
+    )  # Trunc scale should be a power-of-two - ensure that is the case
+    y = y / trunc_scale
+
+    # Clamping
+    min_int_val = min_int(signed, narrow, output_bit_width)
+    max_int_val = max_int(signed, narrow, output_bit_width)
+    y = np.where(y > max_int_val, max_int_val.astype(y.dtype), y)
+    y = np.where(y < min_int_val, min_int_val.astype(y.dtype), y)
+    # To int (truncate)
+    rounding_fx = resolve_rounding_mode(rounding_mode)
+    y = rounding_fx(y)
+
+    # Rescale
+    output_zeropt = zeropt / trunc_scale  # Rescale zero-point
+    y = y - output_zeropt
+    y = y * output_scale
+
+    return y
+
+
+def resolve_rounding_mode(mode_string):
+    """Resolve the rounding mode string of Quant and Trunc ops
+    to the corresponding numpy functions."""
+    if mode_string == "ROUND":
+        return np.round
+    elif mode_string == "CEIL":
+        return np.ceil
+    elif mode_string == "FLOOR":
+        return np.floor
+    else:
+        raise ValueError(f"Could not resolve rounding mode called: {mode_string}")
+```
+
+</details>
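+
+As an informal check of the sample implementation above (the values are chosen for illustration and are not part of the spec), truncating a 10-bit value down to 4 bits with `out_scale/scale = 64` drops the six least-significant bits:
+
+```python
+import numpy as np
+
+# 350 fits in the 10-bit signed input range; trunc_scale = 2**round(log2(64)) = 64
+y = trunc(np.array([350.0], dtype=np.float32), scale=np.array(1.0), zeropt=np.array(0.0),
+          input_bit_width=np.array(10), narrow=0, signed=1, output_scale=np.array(64.0),
+          output_bit_width=np.array(4), rounding_mode="FLOOR")
+print(y)  # [320.] since floor(350 / 64) = 5 and 5 * 64 = 320
+```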
diff --git a/src/qonnx/core/execute_custom_node.py b/src/qonnx/core/execute_custom_node.py
index 7acf3792..cd6bb605 100644
--- a/src/qonnx/core/execute_custom_node.py
+++ b/src/qonnx/core/execute_custom_node.py
@@ -27,10 +27,9 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import qonnx.custom_op.registry as registry
-from qonnx.util.basic import get_preferred_onnx_opset
 
 
-def execute_custom_node(node, context, graph, onnx_opset_version=get_preferred_onnx_opset()):
+def execute_custom_node(node, context, graph, onnx_opset_version):
     """Call custom implementation to execute a single custom node.
     Input/output provided via context."""
     op_type = node.op_type
diff --git a/src/qonnx/core/modelwrapper.py b/src/qonnx/core/modelwrapper.py
index c82a0fee..2ba2984a 100644
--- a/src/qonnx/core/modelwrapper.py
+++ b/src/qonnx/core/modelwrapper.py
@@ -39,6 +39,7 @@
 import qonnx.util.basic as util
 import qonnx.util.onnx as onnxutil
 from qonnx.core.datatype import DataType
+from qonnx.custom_op.registry import getCustomOp
 from qonnx.transformation.double_to_single_float import DoubleToSingleFloat
 from qonnx.transformation.general import (
     RemoveStaticGraphInputs,
@@ -737,3 +738,24 @@ def set_tensor_sparsity(self, tensor_name, sparsity_dict):
         qa.tensor_name = tensor_name
         qa.quant_parameter_tensor_names.append(dt)
         qnt_annotations.append(qa)
+
+    def get_opset_imports(self):
+        """Return the opsets imported by the model as a {domain: version} dictionary."""
+        return {opset.domain: opset.version for opset in self._model_proto.opset_import}
+
+    def get_customop_wrapper(self, node, fallback_customop_version=util.get_preferred_qonnx_opset()):
+        """Return a CustomOp instance for the given node, respecting the
+        imported opset version in the model protobuf. If the node's domain
+        is not found in the model's opset imports, fallback_customop_version
+        will be used."""
+        opset_imports = self.get_opset_imports()
+        try:
+            opset_import = opset_imports[node.domain]
+            return getCustomOp(node, onnx_opset_version=opset_import)
+        except KeyError:
+            # domain not found in imports, use fallback version
+            warnings.warn(
+                f"Domain {node.domain} not found in model opset imports, "
+                f"using fallback_customop_version={fallback_customop_version}"
+            )
+            return getCustomOp(node, onnx_opset_version=fallback_customop_version)
diff --git a/src/qonnx/core/onnx_exec.py b/src/qonnx/core/onnx_exec.py
index a8f4774c..ecb808be 100644
--- a/src/qonnx/core/onnx_exec.py
+++ b/src/qonnx/core/onnx_exec.py
@@ -36,7 +36,7 @@
 import qonnx.analysis.topology as ta
 import qonnx.core.execute_custom_node as ex_cu_node
 from qonnx.util.basic import (
-    get_preferred_onnx_opset,
+    get_preferred_qonnx_opset,
     get_sanitize_quant_tensors,
     is_finn_op,
     qonnx_make_model,
@@ -44,7 +44,7 @@
 )
 
 
-def execute_node(node, context, graph, return_full_exec_context=False, opset_version=get_preferred_onnx_opset()):
+def execute_node(node, context, graph, opset_version, return_full_exec_context=False):
     """Executes a single node by using onnxruntime or with a custom function.
Input/output provided via context.""" @@ -158,7 +158,7 @@ def execute_onnx(model, input_dict, return_full_exec_context=False, start_node=N model_exec_mode = model.get_metadata_prop("exec_mode") if (model_exec_mode is None) or (model_exec_mode == ""): # extract opset version for node-by-node execution - opset_version = model.model.opset_import[0].version + opset_imports = model.get_opset_imports() # execute the model node by node # we can simply walk down the list since the ONNX spec guarantees that it is # topologically sorted @@ -176,7 +176,11 @@ def execute_onnx(model, input_dict, return_full_exec_context=False, start_node=N if get_sanitize_quant_tensors() != 0: # round input values to match quantization annotation execution_context = sanitize_quant_values(model, node.input, execution_context) - execute_node(node, execution_context, graph, return_full_exec_context, opset_version) + if node.domain in opset_imports: + opset_version = opset_imports[node.domain] + else: + opset_version = get_preferred_qonnx_opset() + execute_node(node, execution_context, graph, opset_version, return_full_exec_context) if get_sanitize_quant_tensors() != 0: # round output values to quantization annotation execution_context = sanitize_quant_values(model, node.output, execution_context) diff --git a/src/qonnx/custom_op/base.py b/src/qonnx/custom_op/base.py index 775d9f95..9cb0de11 100644 --- a/src/qonnx/custom_op/base.py +++ b/src/qonnx/custom_op/base.py @@ -30,7 +30,7 @@ import onnx.numpy_helper as np_helper from abc import ABC, abstractmethod -from qonnx.util.basic import get_by_name, get_preferred_onnx_opset +from qonnx.util.basic import get_by_name, get_preferred_qonnx_opset class CustomOp(ABC): @@ -38,7 +38,7 @@ class CustomOp(ABC): every custom node should have. Some as abstract methods, these have to be filled when writing a new custom op node.""" - def __init__(self, onnx_node, onnx_opset_version=get_preferred_onnx_opset()): + def __init__(self, onnx_node, onnx_opset_version=get_preferred_qonnx_opset()): super().__init__() self.onnx_node = onnx_node self.onnx_opset_version = onnx_opset_version diff --git a/src/qonnx/custom_op/channels_last/__init__.py b/src/qonnx/custom_op/channels_last/__init__.py index f1d7c39b..9ffd4e54 100644 --- a/src/qonnx/custom_op/channels_last/__init__.py +++ b/src/qonnx/custom_op/channels_last/__init__.py @@ -2,8 +2,28 @@ from qonnx.custom_op.channels_last.conv import Conv from qonnx.custom_op.channels_last.max_pool import MaxPool -custom_op = dict() +# channels-last ops are defined by the underlying ONNX standard op +# thus, we can define them for any version of the original op +# so we emulate a custom op dictionary that mimics the support for any +# {ChannelsLastOp}_vX instead of hardcoding what versions are supported -custom_op["Conv"] = Conv -custom_op["MaxPool"] = MaxPool -custom_op["BatchNormalization"] = BatchNormalization + +class ChannelsLastCustomOpDict(dict): + def __init__(self): + self._custom_ops = {"Conv": Conv, "MaxPool": MaxPool, "BatchNormalization": BatchNormalization} + + def __getitem__(self, key): + base_key = key.split("_v")[0] # Extract base key (e.g., Conv from Conv_v13) + if base_key in self._custom_ops: + return self._custom_ops[base_key] + raise KeyError(f"Channels-last CustomOp '{key}' not found.") + + def __contains__(self, key): + base_key = key.split("_v")[0] + return base_key in self._custom_ops + + def keys(self): + return self._custom_ops.keys() + + +custom_op = ChannelsLastCustomOpDict() diff --git a/src/qonnx/custom_op/general/__init__.py 
b/src/qonnx/custom_op/general/__init__.py index 9b14ea8a..c670c1a3 100644 --- a/src/qonnx/custom_op/general/__init__.py +++ b/src/qonnx/custom_op/general/__init__.py @@ -35,7 +35,7 @@ from qonnx.custom_op.general.maxpoolnhwc import MaxPoolNHWC from qonnx.custom_op.general.multithreshold import MultiThreshold from qonnx.custom_op.general.quantavgpool2d import QuantAvgPool2d -from qonnx.custom_op.general.trunc import Trunc +from qonnx.custom_op.general.trunc import Trunc_v1, Trunc_v2 from qonnx.custom_op.general.xnorpopcount import XnorPopcountMatMul custom_op = dict() @@ -49,6 +49,21 @@ custom_op["Im2Col"] = Im2Col custom_op["IntQuant"] = IntQuant custom_op["Quant"] = IntQuant -custom_op["Trunc"] = Trunc +custom_op["Trunc"] = Trunc_v1 custom_op["BipolarQuant"] = BipolarQuant custom_op["FloatQuant"] = FloatQuant + +custom_op["DebugMarker_v1"] = DebugMarker +custom_op["QuantAvgPool2d_v1"] = QuantAvgPool2d +custom_op["MaxPoolNHWC_v1"] = MaxPoolNHWC +custom_op["GenericPartition_v1"] = GenericPartition +custom_op["MultiThreshold_v1"] = MultiThreshold +custom_op["XnorPopcountMatMul_v1"] = XnorPopcountMatMul +custom_op["Im2Col_v1"] = Im2Col +custom_op["IntQuant_v1"] = IntQuant +custom_op["Quant_v1"] = IntQuant +custom_op["Trunc_v1"] = Trunc_v1 +custom_op["BipolarQuant_v1"] = BipolarQuant +custom_op["FloatQuant_v1"] = FloatQuant + +custom_op["Trunc_v2"] = Trunc_v2 diff --git a/src/qonnx/custom_op/general/quantavgpool2d.py b/src/qonnx/custom_op/general/quantavgpool2d.py index c0e24071..00617dcf 100644 --- a/src/qonnx/custom_op/general/quantavgpool2d.py +++ b/src/qonnx/custom_op/general/quantavgpool2d.py @@ -33,7 +33,7 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp from qonnx.custom_op.general.maxpoolnhwc import compute_pool_output_dim -from qonnx.util.basic import qonnx_make_model +from qonnx.util.basic import get_preferred_onnx_opset, qonnx_make_model class QuantAvgPool2d(CustomOp): @@ -132,7 +132,7 @@ def execute_node(self, context, graph): outputs=[outp], ) - opset_version = self.onnx_opset_version + opset_version = get_preferred_onnx_opset() opset_imports = [helper.make_opsetid("", opset_version)] onnx_kwargs = {"opset_imports": opset_imports} model_avgpool = qonnx_make_model(graph_avgpool, **onnx_kwargs) diff --git a/src/qonnx/custom_op/general/trunc.py b/src/qonnx/custom_op/general/trunc.py index 8e2eaa19..10c7e992 100644 --- a/src/qonnx/custom_op/general/trunc.py +++ b/src/qonnx/custom_op/general/trunc.py @@ -31,10 +31,99 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.general.quant import resolve_rounding_mode +from qonnx.custom_op.general.quant import max_int, min_int, resolve_rounding_mode +from qonnx.util.basic import get_preferred_qonnx_opset -def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode): +def trunc_v2(inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode): + # Port of TruncIntQuant class from Brevitas: https://bit.ly/3wzIpTR + + # Scaling + y = inp_tensor / scale + y = y + zeropt + # Rounding + y = np.round(y) + # Rescale + trunc_scale = 2 ** np.round( + np.log2(output_scale / scale) + ) # Trunc scale should be a power-of-two - ensure that is the case + y = y / trunc_scale + + # Clamping + min_int_val = min_int(signed, narrow, output_bit_width) + max_int_val = max_int(signed, narrow, output_bit_width) + y = np.where(y > max_int_val, max_int_val.astype(y.dtype), y) + y = np.where(y < min_int_val, 
min_int_val.astype(y.dtype), y)
+    # To int (truncate)
+    rounding_fx = resolve_rounding_mode(rounding_mode)
+    y = rounding_fx(y)
+
+    # Rescale
+    output_zeropt = zeropt / trunc_scale  # Rescale zero-point
+    y = y - output_zeropt
+    y = y * output_scale
+
+    return y
+
+
+class Trunc_v2(CustomOp):
+    """Generic truncation operation for QONNX. Takes six inputs:
+    - the input tensor to truncate
+    - the scale
+    - the zero-point
+    - the input bit-width
+    - the output (truncation) scale
+    - the output bit-width
+
+    The output is a tensor of the same shape as the input tensor, with truncated
+    values.
+    """
+
+    def __init__(self, onnx_node, onnx_opset_version=get_preferred_qonnx_opset()):
+        super().__init__(onnx_node, onnx_opset_version)
+        # override any specified opset version, this instance is v2
+        self.onnx_opset_version = 2
+
+    def get_nodeattr_types(self):
+        return {
+            # The rounding mode, which is used for the trunc function
+            "rounding_mode": ("s", True, "FLOOR"),
+            "narrow": ("i", False, 0, {0, 1}),
+            "signed": ("i", False, 1, {0, 1}),
+        }
+
+    def make_shape_compatible_op(self, model):
+        node = self.onnx_node
+        return helper.make_node("Identity", [node.input[0]], [node.output[0]])
+
+    def infer_node_datatype(self, model):
+        node = self.onnx_node
+        model.set_tensor_datatype(node.output[0], DataType["FLOAT32"])
+
+    def execute_node(self, context, graph):
+        node = self.onnx_node
+        # save inputs
+        inp_tensor = context[node.input[0]]
+        scale = context[node.input[1]]
+        zeropt = context[node.input[2]]
+        input_bit_width = context[node.input[3]]
+        output_scale = context[node.input[4]]
+        output_bit_width = context[node.input[5]]
+        # save attributes
+        rounding_mode = self.get_nodeattr("rounding_mode")
+        narrow = self.get_nodeattr("narrow")
+        signed = self.get_nodeattr("signed")
+        # calculate output
+        ret = trunc_v2(
+            inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode
+        )
+        # set context according to output name
+        context[node.output[0]] = ret
+
+    def verify_node(self):
+        pass
+
+
+def trunc_v1(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode):
     # Port of TruncIntQuant class from Brevitas: https://bit.ly/3wzIpTR
 
     # Scaling
@@ -58,7 +147,7 @@ def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding
     return y
 
 
-class Trunc(CustomOp):
+class Trunc_v1(CustomOp):
     """Generic truncation operation for QONNX. Takes four inputs:
     - input tensor to truncate
     - the scale
     - the zero-point
@@ -69,6 +158,11 @@ class Trunc(CustomOp):
     values.
""" + def __init__(self, onnx_node, onnx_opset_version=get_preferred_qonnx_opset()): + super().__init__(onnx_node, onnx_opset_version) + # override any specified opset version, this instance is v1 + self.onnx_opset_version = 1 + def get_nodeattr_types(self): return { # The rounding mode, which is used for the trunc function @@ -94,7 +188,7 @@ def execute_node(self, context, graph): # save attributes rounding_mode = self.get_nodeattr("rounding_mode") # calculate output - ret = trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode) + ret = trunc_v1(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode) # set context according to output name context[node.output[0]] = ret diff --git a/src/qonnx/custom_op/registry.py b/src/qonnx/custom_op/registry.py index 3540bb5a..258e9ab0 100644 --- a/src/qonnx/custom_op/registry.py +++ b/src/qonnx/custom_op/registry.py @@ -28,11 +28,12 @@ import importlib -from qonnx.util.basic import get_preferred_onnx_opset - -def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset(), brevitas_exception=True): - "Return a QONNX CustomOp instance for the given ONNX node, if it exists." +def getCustomOp(node, onnx_opset_version=None, brevitas_exception=True): + "Return a QONNX CustomOp wrapper for the given ONNX node and given opset version," + "if it exists. If opset version is None, the default handler for the op type will be used. " + "If version is specified but the exact version match isn't available, the highest available version " + "smaller than the requested version will be used." op_type = node.op_type domain = node.domain if brevitas_exception: @@ -40,11 +41,34 @@ def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset(), brevitas_ex domain = domain.replace("onnx.brevitas", "qonnx.custom_op.general") try: opset_module = importlib.import_module(domain) - assert type(opset_module.custom_op) is dict, "custom_op dict not found in Python module %s" % domain - inst_wrapper = opset_module.custom_op[op_type] - inst = inst_wrapper(node, onnx_opset_version=onnx_opset_version) + assert isinstance(opset_module.custom_op, dict), "custom_op dict not found in Python module %s" % domain + found_opset_version = None + if onnx_opset_version is None: + inst_wrapper = opset_module.custom_op[op_type] + else: + op_type_with_version = op_type + "_v" + str(onnx_opset_version) + if op_type_with_version in opset_module.custom_op: + # priority: if it exists, load the versioned CustomOp wrapper + inst_wrapper = opset_module.custom_op[op_type_with_version] + found_opset_version = onnx_opset_version + else: + # when the exact version match is not found + # version handling: use highest available version smaller than requested version + available_versions = [ + int(k.split("_v")[-1]) for k in opset_module.custom_op.keys() if k.startswith(op_type + "_v") + ] + suitable_versions = [v for v in available_versions if v <= onnx_opset_version] + if suitable_versions: + highest_version = max(suitable_versions) + inst_wrapper = opset_module.custom_op[f"{op_type}_v{highest_version}"] + found_opset_version = highest_version + else: + raise Exception( + "Op %s version %s not found in custom opset %s" % (op_type, str(onnx_opset_version), domain) + ) + inst = inst_wrapper(node, onnx_opset_version=found_opset_version) return inst except ModuleNotFoundError: raise Exception("Could not load custom opset %s, check your PYTHONPATH" % domain) except KeyError: - raise Exception("Op %s not found in custom opset %s" % (op_type, domain)) + raise 
diff --git a/src/qonnx/transformation/channels_last.py b/src/qonnx/transformation/channels_last.py
index 175af058..c352238c 100644
--- a/src/qonnx/transformation/channels_last.py
+++ b/src/qonnx/transformation/channels_last.py
@@ -270,8 +270,13 @@ def apply(self, model):
                     # Attach to original node
                     n.output[i] = outp_trans_in
 
-                # Modify domain
+                # Modify node domain
                 n.domain = "qonnx.custom_op.channels_last"
+                opset_imports = model.get_opset_imports()
+                # Ensure channels_last domain is imported in model (versioned like the standard ONNX opset)
+                if "qonnx.custom_op.channels_last" not in opset_imports:
+                    onnx_opset = opset_imports[""]
+                    model.model.opset_import.append(helper.make_opsetid("qonnx.custom_op.channels_last", onnx_opset))
                 # Set modified flag
                 graph_modified = True
diff --git a/src/qonnx/transformation/fixedpt_quantize.py b/src/qonnx/transformation/fixedpt_quantize.py
index 894d7ea6..ff0c11db 100644
--- a/src/qonnx/transformation/fixedpt_quantize.py
+++ b/src/qonnx/transformation/fixedpt_quantize.py
@@ -48,7 +48,8 @@ class FixedPointQuantizeParamsFromDict(Transformation):
             data type or its canonical name
 
         rounding_mode: Rounding mode used for conversion into fixed point. Default is "ROUND",
-            possible values: ["ROUND", "HALF_EVEN", "CEIL", "FLOOR", "UP", "DOWN", "HALF_UP", "HALF_DOWN"]
+            possible values: ["ROUND", "HALF_EVEN", "CEIL", "FLOOR", "UP", "DOWN",
+            "HALF_UP", "HALF_DOWN"]
     """
 
     def __init__(self, fixedpt_dict, rounding_mode="ROUND"):
diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py
index 3a3ce2af..e756366d 100644
--- a/src/qonnx/util/basic.py
+++ b/src/qonnx/util/basic.py
@@ -51,11 +51,19 @@ def get_preferred_onnx_opset():
     return 11
 
 
+def get_preferred_qonnx_opset():
+    "Return the preferred opset version for the QONNX custom op domains"
+    return 1
+
+
 def qonnx_make_model(graph_proto, **kwargs):
     "Wrapper around ONNX make_model with preferred qonnx opset version"
     opset_imports = kwargs.pop("opset_imports", None)
     if opset_imports is None:
-        opset_imports = [make_opsetid("", get_preferred_onnx_opset())]
+        opset_imports = [
+            make_opsetid("", get_preferred_onnx_opset()),
+            make_opsetid("qonnx.custom_op.general", get_preferred_qonnx_opset()),
+        ]
         kwargs["opset_imports"] = opset_imports
     else:
         kwargs["opset_imports"] = opset_imports
diff --git a/tests/core/test_custom_onnx_exec.py b/tests/core/test_custom_onnx_exec.py
index 8eec7156..54b71754 100644
--- a/tests/core/test_custom_onnx_exec.py
+++ b/tests/core/test_custom_onnx_exec.py
@@ -32,6 +32,8 @@
 import qonnx.core.execute_custom_node as ex_cu_node
 from qonnx.custom_op.registry import getCustomOp
 
+mt_node_version = 1
+
 
 def 
test_execute_custom_node_multithreshold(): execution_context["v"] = inputs_nhwc graph_def = helper.make_graph([node_def], "test_model", [v_nhwc, thresholds], [out_nhwc]) - ex_cu_node.execute_custom_node(node_def, execution_context, graph_def) + ex_cu_node.execute_custom_node(node_def, execution_context, graph_def, mt_node_version) assert (execution_context["out"] == outputs_nhwc).all() # check the set of allowed values op_inst = getCustomOp(node_def) diff --git a/tests/core/test_modelwrapper.py b/tests/core/test_modelwrapper.py index 722f0fb1..995bcb17 100644 --- a/tests/core/test_modelwrapper.py +++ b/tests/core/test_modelwrapper.py @@ -68,6 +68,7 @@ def test_modelwrapper(): inp_sparsity = {"dw": {"kernel_shape": [3, 3]}} model.set_tensor_sparsity(first_conv_iname, inp_sparsity) assert model.get_tensor_sparsity(first_conv_iname) == inp_sparsity + assert model.get_opset_imports() == {"": 8} def test_modelwrapper_set_get_rm_initializer(): diff --git a/tests/custom_op/test_customop_version.py b/tests/custom_op/test_customop_version.py new file mode 100644 index 00000000..3efbde24 --- /dev/null +++ b/tests/custom_op/test_customop_version.py @@ -0,0 +1,137 @@ +# Copyright (c) 2025 Advanced Micro Devices, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of qonnx nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import onnx.parser as oprs + +import qonnx.custom_op.general as general +from qonnx.core.modelwrapper import ModelWrapper +from qonnx.custom_op.base import CustomOp +from qonnx.custom_op.registry import getCustomOp + + +class VerTestOp_v1(CustomOp): + def get_nodeattr_types(self): + my_attrs = {"v1_attr": ("i", True, 0)} + return my_attrs + + def make_shape_compatible_op(self, model): + ishape = model.get_tensor_shape(self.onnx_node.input[0]) + return super().make_const_shape_op(ishape) + + def infer_node_datatype(self, model): + node = self.onnx_node + # data type stays the same + dtype = model.get_tensor_datatype(node.input[0]) + model.set_tensor_datatype(node.output[0], dtype) + + def execute_node(self, context, graph): + node = self.onnx_node + context[node.output[0]] = context[node.input[0]] + + def verify_node(self): + pass + + +class VerTestOp_v2(VerTestOp_v1): + def get_nodeattr_types(self): + my_attrs = {"v2_attr": ("i", True, 0)} + return my_attrs + + +class VerTestOp_v3(VerTestOp_v2): + def get_nodeattr_types(self): + my_attrs = {"v3_attr": ("i", True, 0)} + return my_attrs + + +def make_vertest_model(vertest_ver, no_opset_import): + ishp = (1, 10) + oshp = ishp + ishp_str = str(list(ishp)) + oshp_str = str(list(oshp)) + if no_opset_import: + opset_import = "" + else: + opset_import = f', "qonnx.custom_op.general" : {vertest_ver}' + input = f""" + < + ir_version: 7, + opset_import: ["" : 9{opset_import}] + > + agraph (float{ishp_str} in0) => (float{oshp_str} out0) + {{ + out0 = qonnx.custom_op.general.VerTestOp< + v{vertest_ver}_attr={vertest_ver} + >(in0) + }} + """ + model = oprs.parse_model(input) + model = ModelWrapper(model) + return model + + +def test_customop_version(): + # unspecified version defaults to v1 implementation + general.custom_op["VerTestOp"] = VerTestOp_v1 + # v1 version is also explicitly registered + general.custom_op["VerTestOp_v1"] = VerTestOp_v1 + general.custom_op["VerTestOp_v2"] = VerTestOp_v2 + general.custom_op["VerTestOp_v3"] = VerTestOp_v3 + + # if onnx is lacking the opset import, should default to v1 handler + # (since we set custom_op["VerTestOp"] = VerTestOp_v1) + model = make_vertest_model(1, True) + inst = getCustomOp(model.graph.node[0]) + assert isinstance(inst, VerTestOp_v1) + # alternatively, when using ModelWrapper.get_customop_wrapper and onnx is + # lacking the opset import, should fall back to the specified version + inst = model.get_customop_wrapper(model.graph.node[0], fallback_customop_version=2) + assert isinstance(inst, VerTestOp_v2) + + for ver in [1, 2, 3]: + model = make_vertest_model(ver, False) + # use ModelWrapper.get_customop_wrapper for implicit + # fetching of op version + inst = model.get_customop_wrapper(model.graph.node[0]) + assert inst.get_nodeattr(f"v{ver}_attr") == ver + assert inst.onnx_opset_version == ver + # explicitly specify onnx_opset_version in getCustomOp + # note: new code should avoid calling getCustomOp directly like this + # and instead use ModelWrapper.get_customop_wrapper + inst = getCustomOp(model.graph.node[0], onnx_opset_version=ver) + assert inst.get_nodeattr(f"v{ver}_attr") == ver + assert inst.onnx_opset_version == ver + # unspecified version getCustomOp should default to v1 handler + model = make_vertest_model(1, False) + inst = getCustomOp(model.graph.node[0]) + assert isinstance(inst, VerTestOp_v1) + # requesting v4 should return largest available version (v3 in this case) + model = make_vertest_model(3, False) + inst = getCustomOp(model.graph.node[0], onnx_opset_version=4) + assert 
isinstance(inst, VerTestOp_v3) + assert inst.onnx_opset_version == 3
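
Taken together, the changes above give model-level opset awareness. A minimal usage sketch of the new API (the model filename is a placeholder, and the printed import dict is illustrative):

```python
from qonnx.core.modelwrapper import ModelWrapper

# "model.onnx" is a placeholder path for illustration
model = ModelWrapper("model.onnx")
print(model.get_opset_imports())  # e.g. {"": 11, "qonnx.custom_op.general": 2}

for node in model.graph.node:
    if node.domain == "qonnx.custom_op.general":
        # resolves the CustomOp wrapper that matches the model's imported
        # opset version, falling back to get_preferred_qonnx_opset() (with a
        # warning) when the domain is missing from the imports
        inst = model.get_customop_wrapper(node)
        print(node.op_type, inst.onnx_opset_version)
```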