From e2f11b0b94381db5f8669d04a04940db51cea5b8 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Tue, 11 Feb 2025 16:54:07 +0000 Subject: [PATCH 01/35] [trunc] Updated Trunc to match the new numerics / export from Brevitas --- src/qonnx/custom_op/general/trunc.py | 38 +++++++++++++++++++++------- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/src/qonnx/custom_op/general/trunc.py b/src/qonnx/custom_op/general/trunc.py index 8e2eaa19..9d750dcf 100644 --- a/src/qonnx/custom_op/general/trunc.py +++ b/src/qonnx/custom_op/general/trunc.py @@ -31,10 +31,10 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.general.quant import resolve_rounding_mode +from qonnx.custom_op.general.quant import max_int, min_int, resolve_rounding_mode -def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode): +def trunc(inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode): # Port of TruncIntQuant class from Brevitas: https://bit.ly/3wzIpTR # Scaling @@ -42,18 +42,25 @@ def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding y = y + zeropt # Rounding y = np.round(y) - # Truncate - trunc_bit_width = input_bit_width - output_bit_width - trunc_scale = 2.0**trunc_bit_width + # Rescale + trunc_scale = 2 ** np.round( + np.log2(output_scale / scale) + ) # Trunc scale should be a power-of-two - ensure that is the case y = y / trunc_scale - # To int + # Clamping + min_int_val = min_int(signed, narrow, output_bit_width) + max_int_val = max_int(signed, narrow, output_bit_width) + y = np.where(y > max_int_val, max_int_val.astype(y.dtype), y) + y = np.where(y < min_int_val, min_int_val.astype(y.dtype), y) + # To int (truncate) rounding_fx = resolve_rounding_mode(rounding_mode) y = rounding_fx(y) # Rescale - y = y - zeropt - y = y * scale + output_zeropt = zeropt / trunc_scale # Rescale zero-point + y = y - output_zeropt + y = y * 
output_scale return y @@ -73,6 +80,13 @@ def get_nodeattr_types(self): return { # The rounding mode, which is used for the trunc function "rounding_mode": ("s", True, "FLOOR"), + "narrow": ("i", False, 0, {0, 1}), + "signed": ("i", False, 1, {0, 1}), + "output_scale": ( + "f", + False, + -1.0, + ), # Invalid scale signifies that it needs to be computed from input/output bit_width } def make_shape_compatible_op(self, model): @@ -93,8 +107,14 @@ def execute_node(self, context, graph): output_bit_width = context[node.input[4]] # save attributes rounding_mode = self.get_nodeattr("rounding_mode") + narrow = self.get_nodeattr("narrow") + signed = self.get_nodeattr("signed") + output_scale = self.get_nodeattr("output_scale") + output_scale = 2 ** (input_bit_width - output_bit_width) if output_scale <= 0.0 else output_scale # calculate output - ret = trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode) + ret = trunc( + inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode + ) # set context according to output name context[node.output[0]] = ret From e59177fac6b3f1756b12d8c4ef9e350f36cbf290 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Thu, 13 Mar 2025 15:54:05 +0000 Subject: [PATCH 02/35] Update trunc_op description. --- docs/qonnx-custom-ops/trunc_op.md | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/docs/qonnx-custom-ops/trunc_op.md b/docs/qonnx-custom-ops/trunc_op.md index 1b5f0d04..642760a5 100644 --- a/docs/qonnx-custom-ops/trunc_op.md +++ b/docs/qonnx-custom-ops/trunc_op.md @@ -6,13 +6,20 @@ The attribute rounding_mode defines how truncated values are rounded. #### Version -This operator is not part of the ONNX standard and is not currently versioned. +This operator is not part of the ONNX standard. +The description of this operator in this document corresponds to `qonnx.custom_ops.general` opset version 2. #### Attributes
rounding_mode : string (default is "FLOOR")
Defines how rounding should be applied during truncation. Currently available modes are: "ROUND", "CEIL" and "FLOOR". Here "ROUND" implies a round-to-even operation. Lowercase variants for the rounding mode string are also supported: "round", "ceil", "floor".
+
signed : int (default is 1)
+
Defines if the quantization includes a signed bit. E.g. at 8b unsigned=[0, 255] vs signed=[-128, 127].
+
narrow : int (default is 0)
+
Defines if the value range should be interpreted as narrow, when signed=1. E.g. at 8b regular=[-128, 127] vs narrow=[-127, 127].
+
output_scale : float32, tensor(float32) (default is -1.0)
+
The scale factor of the output, either as a global scalar or with a shape matching the number of dimensions of the X tensor. The output scale must represent a shift W.R.T. the input scale (i.e., scale) and therefore must be the input scale multiplied by a power-of-2. If output_scale is less-than-or-equal to 0, it is calculated as 2 ** (in_bitwidth - out_bitwidth) to approximately match the behaviour in qonnx.custom_ops.general opset version 1.
#### Inputs @@ -91,26 +98,32 @@ from __future__ import unicode_literals import numpy as np -def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode): - # Port of TruncIntQuant class from Brevitas: https://bit.ly/3wzIpTR +def trunc(inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode): # Scaling y = inp_tensor / scale y = y + zeropt # Rounding y = np.round(y) - # Truncate - trunc_bit_width = input_bit_width - output_bit_width - trunc_scale = 2.0 ** trunc_bit_width + # Rescale + trunc_scale = 2 ** np.round( + np.log2(output_scale / scale) + ) # Trunc scale should be a power-of-two - ensure that is the case y = y / trunc_scale - # To int + # Clamping + min_int_val = min_int(signed, narrow, output_bit_width) + max_int_val = max_int(signed, narrow, output_bit_width) + y = np.where(y > max_int_val, max_int_val.astype(y.dtype), y) + y = np.where(y < min_int_val, min_int_val.astype(y.dtype), y) + # To int (truncate) rounding_fx = resolve_rounding_mode(rounding_mode) y = rounding_fx(y) # Rescale - y = y - zeropt - y = y * scale + output_zeropt = zeropt / trunc_scale # Rescale zero-point + y = y - output_zeropt + y = y * output_scale return y From b791c7bad1afbd37a6751228428056c739c6bc5f Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Thu, 13 Mar 2025 17:30:59 +0000 Subject: [PATCH 03/35] Minor fixes. --- docs/qonnx-custom-ops/trunc_op.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/qonnx-custom-ops/trunc_op.md b/docs/qonnx-custom-ops/trunc_op.md index 642760a5..51b5e3a4 100644 --- a/docs/qonnx-custom-ops/trunc_op.md +++ b/docs/qonnx-custom-ops/trunc_op.md @@ -18,8 +18,8 @@ The description of this operator in this document corresponds to `qonnx.custom_o
Defines if the quantization includes a signed bit. E.g. at 8b unsigned=[0, 255] vs signed=[-128, 127].
narrow : int (default is 0)
Defines if the value range should be interpreted as narrow, when signed=1. E.g. at 8b regular=[-128, 127] vs narrow=[-127, 127].
-
output_scale : float32, tensor(float32) (default is -1.0)
-
The scale factor of the output, either as a global scalar or with a shape matching the number of dimensions of the X tensor. The output scale must represent a shift W.R.T. the input scale (i.e., scale) and therefore must be the input scale multiplied by a power-of-2. If output_scale is less-than-or-equal to 0, it is calculated as 2 ** (in_bitwidth - out_bitwidth) to approximately match the behaviour in qonnx.custom_ops.general opset version 1.
+
output_scale : float32 (default is -1.0)
+
The scale factor of the output as a scalar. The output scale must represent a shift W.R.T. the input scale (i.e., scale) and therefore must be the input scale multiplied by a power-of-2. If output_scale is less-than-or-equal to 0, it is calculated as 2 ** (in_bitwidth - out_bitwidth) to approximately match the behaviour of qonnx.custom_ops.general opset version 1.
#### Inputs From c611ae1941b4571751306a143bba21286e9fd7ef Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Thu, 13 Mar 2025 17:35:56 +0000 Subject: [PATCH 04/35] Improved formatting in RTD --- src/qonnx/custom_op/general/trunc.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/qonnx/custom_op/general/trunc.py b/src/qonnx/custom_op/general/trunc.py index 9d750dcf..85cd1db6 100644 --- a/src/qonnx/custom_op/general/trunc.py +++ b/src/qonnx/custom_op/general/trunc.py @@ -66,10 +66,10 @@ def trunc(inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_sca class Trunc(CustomOp): - """Generic truncation operation for QONNX. Takes four inputs: - - input tensor to truncate - - the scale - - the zero-point + """Generic truncation operation for QONNX. Takes four inputs: + - input tensor to truncate + - the scale + - the zero-point - the truncation bit-width The output is a tensor of the same shape as the input tensor, with truncated From 05cd37f1354ff5143e60e360812e28eef125db34 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Mon, 24 Mar 2025 17:19:35 +0000 Subject: [PATCH 05/35] Feat (trunc): Switch output_scale, output_zero_point to be inputs instead of attributes --- src/qonnx/custom_op/general/trunc.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/qonnx/custom_op/general/trunc.py b/src/qonnx/custom_op/general/trunc.py index 85cd1db6..28bfd23f 100644 --- a/src/qonnx/custom_op/general/trunc.py +++ b/src/qonnx/custom_op/general/trunc.py @@ -82,11 +82,6 @@ def get_nodeattr_types(self): "rounding_mode": ("s", True, "FLOOR"), "narrow": ("i", False, 0, {0, 1}), "signed": ("i", False, 1, {0, 1}), - "output_scale": ( - "f", - False, - -1.0, - ), # Invalid scale signifies that it needs to be computed from input/output bit_width } def make_shape_compatible_op(self, model): @@ -104,13 +99,13 @@ def execute_node(self, context, graph): scale = context[node.input[1]] zeropt = context[node.input[2]] 
input_bit_width = context[node.input[3]] - output_bit_width = context[node.input[4]] + output_scale = context[node.input[4]] + output_zeropt = context[node.input[5]] + output_bit_width = context[node.input[6]] # save attributes rounding_mode = self.get_nodeattr("rounding_mode") narrow = self.get_nodeattr("narrow") signed = self.get_nodeattr("signed") - output_scale = self.get_nodeattr("output_scale") - output_scale = 2 ** (input_bit_width - output_bit_width) if output_scale <= 0.0 else output_scale # calculate output ret = trunc( inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode From a30aaf1cb9bf05dd7b8caad49e241daf40f484e7 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Thu, 27 Mar 2025 10:46:17 +0000 Subject: [PATCH 06/35] [trunc] Removed redundant output zero-point input --- src/qonnx/custom_op/general/trunc.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/qonnx/custom_op/general/trunc.py b/src/qonnx/custom_op/general/trunc.py index 28bfd23f..df51bbed 100644 --- a/src/qonnx/custom_op/general/trunc.py +++ b/src/qonnx/custom_op/general/trunc.py @@ -100,8 +100,7 @@ def execute_node(self, context, graph): zeropt = context[node.input[2]] input_bit_width = context[node.input[3]] output_scale = context[node.input[4]] - output_zeropt = context[node.input[5]] - output_bit_width = context[node.input[6]] + output_bit_width = context[node.input[5]] # save attributes rounding_mode = self.get_nodeattr("rounding_mode") narrow = self.get_nodeattr("narrow") From 5ecc34962f8f5ebb3da875e114b3d0e4c0e3cac7 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Thu, 27 Mar 2025 10:46:49 +0000 Subject: [PATCH 07/35] [trunc] Update docstring --- src/qonnx/custom_op/general/trunc.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/qonnx/custom_op/general/trunc.py b/src/qonnx/custom_op/general/trunc.py index df51bbed..36e60b84 100644 --- a/src/qonnx/custom_op/general/trunc.py +++ 
b/src/qonnx/custom_op/general/trunc.py @@ -70,6 +70,7 @@ class Trunc(CustomOp): - input tensor to truncate - the scale - the zero-point + - the truncation scale - the truncation bit-width The output is a tensor of the same shape as the input tensor, with truncated From 7e9f49dac31697a64f207563cb511fe880ac4de4 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Thu, 27 Mar 2025 10:49:52 +0000 Subject: [PATCH 08/35] [docs] Updated definition of the trunc operator --- docs/qonnx-custom-ops/trunc_op.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/qonnx-custom-ops/trunc_op.md b/docs/qonnx-custom-ops/trunc_op.md index 51b5e3a4..d716c6c2 100644 --- a/docs/qonnx-custom-ops/trunc_op.md +++ b/docs/qonnx-custom-ops/trunc_op.md @@ -18,8 +18,6 @@ The description of this operator in this document corresponds to `qonnx.custom_o
Defines if the quantization includes a signed bit. E.g. at 8b unsigned=[0, 255] vs signed=[-128, 127].
narrow : int (default is 0)
Defines if the value range should be interpreted as narrow, when signed=1. E.g. at 8b regular=[-128, 127] vs narrow=[-127, 127].
-
output_scale : float32 (default is -1.0)
-
The scale factor of the output as a scalar. The output scale must represent a shift W.R.T. the input scale (i.e., scale) and therefore must be the input scale multiplied by a power-of-2. If output_scale is less-than-or-equal to 0, it is calculated as 2 ** (in_bitwidth - out_bitwidth) to approximately match the behaviour of qonnx.custom_ops.general opset version 1.
#### Inputs @@ -28,11 +26,13 @@ The description of this operator in this document corresponds to `qonnx.custom_o
X (differentiable) : tensor(float32)
input tensor to truncate
scale : float32
-
The scale factor
+
The scale factor at the input of the truncation
zeropt : float32
-
The zero-point
+
The zero-point at the input of the truncation
in_bitwidth : int32
The number of bits used at the input of the truncation
+
out_scale : float32
+
The scale factor of the output of the truncation
out_bitwidth : int32
The number of bits used at the output of the truncation
From 7dfc4b8676b1fc7a7af122be7f97e108dd07e98f Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 25 Sep 2025 11:31:20 +0200 Subject: [PATCH 09/35] [Lint] rerun linter, fix errors --- src/qonnx/core/datatype.py | 2 +- src/qonnx/core/modelwrapper.py | 58 +-- src/qonnx/transformation/fixedpt_quantize.py | 20 +- src/qonnx/transformation/general.py | 3 +- tests/core/test_datatypes.py | 378 ++++++++---------- tests/core/test_subgraph_traversal.py | 62 ++- tests/transformation/test_fixedpt_quantize.py | 69 +--- tests/transformation/test_sort_graph.py | 1 + 8 files changed, 292 insertions(+), 301 deletions(-) diff --git a/src/qonnx/core/datatype.py b/src/qonnx/core/datatype.py index 5b8d0459..e32d30c0 100644 --- a/src/qonnx/core/datatype.py +++ b/src/qonnx/core/datatype.py @@ -288,7 +288,7 @@ def max(self): return signed_max if self._signed else unsigned_max def allowed(self, value): - value_is_integer = (np.round(value) == value) + value_is_integer = np.round(value) == value value_is_bounded = np.logical_and(self.min() <= value, value <= self.max()) return np.logical_and(value_is_integer, value_is_bounded) diff --git a/src/qonnx/core/modelwrapper.py b/src/qonnx/core/modelwrapper.py index 7566ca07..77866f4c 100644 --- a/src/qonnx/core/modelwrapper.py +++ b/src/qonnx/core/modelwrapper.py @@ -128,13 +128,17 @@ def save(self, filename): def analysis(self, analysis_fxn, apply_to_subgraphs=False): """Runs given anaylsis_fxn on this model and return resulting dict.""" - if apply_to_subgraphs == True: - assert "apply_to_subgraphs" in inspect.signature(analysis_fxn), "analysis_fxn must have 'apply_to_subgraphs' argument when apply_to_subgraphs == True" + if apply_to_subgraphs: + assert "apply_to_subgraphs" in inspect.signature( + analysis_fxn + ), "analysis_fxn must have 'apply_to_subgraphs' argument when apply_to_subgraphs == True" return analysis_fxn(self, apply_to_subgraphs) else: return analysis_fxn(self) - def transform_subgraphs(self, transformation, 
make_deepcopy=True, cleanup=True, apply_to_subgraphs=False, use_preorder_traversal=True): + def transform_subgraphs( + self, transformation, make_deepcopy=True, cleanup=True, apply_to_subgraphs=False, use_preorder_traversal=True + ): """Applies given Transformation to all subgraphs of this ModelWrapper instance. - make_deepcopy : operates on a new (deep)copy of model. @@ -144,23 +148,27 @@ def transform_subgraphs(self, transformation, make_deepcopy=True, cleanup=True, otherwise postorder traversal is used. """ for node in self.model.graph.node: - transformed_subgraph_attrs = [] - for idx, attr in enumerate(node.attribute): - if attr.type == onnx.AttributeProto.GRAPH: - # this is a subgraph, add it to the list - subgraph = self.make_subgraph_modelwrapper(attr.g) - # apply the transformation to the subgraph - subgraph = subgraph.transform(transformation, make_deepcopy, cleanup, apply_to_subgraphs, use_preorder_traversal) - # update the new subgraph in the attrubute - transformed_subgraph_attrs.append((idx, onnx.helper.make_attribute(attr.name, subgraph.model.graph))) - # replace the attributes in the node with the transformed subgraph attributes - for idx, new_attr in transformed_subgraph_attrs: - # remove the old attribute - node.attribute.pop(idx) - # add the new attribute - node.attribute.insert(idx, new_attr) - - def transform(self, transformation, make_deepcopy=True, cleanup=True, apply_to_subgraphs=False, use_preorder_traversal=True): + transformed_subgraph_attrs = [] + for idx, attr in enumerate(node.attribute): + if attr.type == onnx.AttributeProto.GRAPH: + # this is a subgraph, add it to the list + subgraph = self.make_subgraph_modelwrapper(attr.g) + # apply the transformation to the subgraph + subgraph = subgraph.transform( + transformation, make_deepcopy, cleanup, apply_to_subgraphs, use_preorder_traversal + ) + # update the new subgraph in the attrubute + transformed_subgraph_attrs.append((idx, onnx.helper.make_attribute(attr.name, subgraph.model.graph))) 
+ # replace the attributes in the node with the transformed subgraph attributes + for idx, new_attr in transformed_subgraph_attrs: + # remove the old attribute + node.attribute.pop(idx) + # add the new attribute + node.attribute.insert(idx, new_attr) + + def transform( + self, transformation, make_deepcopy=True, cleanup=True, apply_to_subgraphs=False, use_preorder_traversal=True + ): """Applies given Transformation repeatedly until no more changes can be made and returns a transformed ModelWrapper instance. @@ -174,8 +182,10 @@ def transform(self, transformation, make_deepcopy=True, cleanup=True, apply_to_s if self.fix_float64: (transformed_model, model_was_changed) = DoubleToSingleFloat().apply(transformed_model) - if apply_to_subgraphs and use_preorder_traversal == False: - transformed_model.transform_subgraphs(transformation, make_deepcopy, cleanup, apply_to_subgraphs, use_preorder_traversal) + if apply_to_subgraphs and (use_preorder_traversal is False): + transformed_model.transform_subgraphs( + transformation, make_deepcopy, cleanup, apply_to_subgraphs, use_preorder_traversal + ) model_was_changed = True while model_was_changed: @@ -184,7 +194,9 @@ def transform(self, transformation, make_deepcopy=True, cleanup=True, apply_to_s transformed_model.cleanup() if apply_to_subgraphs and use_preorder_traversal: - transformed_model.transform_subgraphs(transformation, make_deepcopy, cleanup, apply_to_subgraphs, use_preorder_traversal) + transformed_model.transform_subgraphs( + transformation, make_deepcopy, cleanup, apply_to_subgraphs, use_preorder_traversal + ) return transformed_model diff --git a/src/qonnx/transformation/fixedpt_quantize.py b/src/qonnx/transformation/fixedpt_quantize.py index 0b21c591..f9225719 100644 --- a/src/qonnx/transformation/fixedpt_quantize.py +++ b/src/qonnx/transformation/fixedpt_quantize.py @@ -29,10 +29,10 @@ import numpy as np from warnings import warn +from qonnx.core.datatype import DataType from qonnx.core.modelwrapper import 
ModelWrapper -from qonnx.transformation.base import Transformation from qonnx.custom_op.general.intquant import resolve_rounding_mode -from qonnx.core.datatype import DataType +from qonnx.transformation.base import Transformation def default_op_filter(op): @@ -44,10 +44,12 @@ class FixedPointQuantizeParamsFromDict(Transformation): Quantize model parameters to a given fixed-point representation. The self.max_err dictionary stores the maximum error for each quantized input after calling. Parameters: - fixedpt_dict: Dictionary containing tensor names and their corresponding target fixed-point data type or its canonical name + fixedpt_dict: Dictionary containing tensor names and their corresponding target fixed-point + data type or its canonical name rounding_mode: Rounding mode used for conversion into fixed point. Default is "ROUND", - possible values: ["ROUND", "HALF_EVEN", "CEIL", "FLOOR", "UP", "DOWN", "HALF_UP", "HALF_DOWN"] + possible values: ["ROUND", "HALF_EVEN", "CEIL", "FLOOR", "UP", "DOWN", + "HALF_UP", "HALF_DOWN"] """ def __init__(self, fixedpt_dict, rounding_mode="ROUND"): @@ -63,13 +65,17 @@ def apply(self, model: ModelWrapper): tdtype = DataType[tdtype] current_dtype = model.get_tensor_datatype(tname) if current_dtype.is_fixed_point(): - warn(f"Tensor {tname} is already a {current_dtype.get_canonical_name()} type. Recasting to {tdtype.get_canonical_name()}") + warn( + f"Tensor {tname} is already a {current_dtype.get_canonical_name()} type. " + "Recasting to {tdtype.get_canonical_name()}" + ) in1_t_new = self.round_func(in1_t.astype(np.float32) / tdtype.scale_factor()) * tdtype.scale_factor() if (in1_t_new.max() > tdtype.max()) or (in1_t_new.min() < tdtype.min()): warn( f"Range of {tname} [{in1_t_new.min():.3f}, {in1_t_new.max():.3f}] greater than" - f"{tdtype.get_canonical_name()} [{tdtype.min():.3f}, {tdtype:.max():.3f}], clipping.") + f"{tdtype.get_canonical_name()} [{tdtype.min():.3f}, {tdtype:.max():.3f}], clipping." 
+ ) in1_t_new = np.clip(in1_t_new, tdtype.min(), tdtype.max()) model.set_tensor_datatype(tname, tdtype) model.set_initializer(tname, in1_t_new) @@ -78,6 +84,7 @@ def apply(self, model: ModelWrapper): return (model, False) + class FixedPointQuantizeParams(Transformation): """ Quantize model parameters to a given fixed-point representation. @@ -93,6 +100,7 @@ class FixedPointQuantizeParams(Transformation): Default is "ROUND", possible values: ["ROUND", "HALF_EVEN", "CEIL", "FLOOR", "UP", "DOWN", "HALF_UP", "HALF_DOWN"] """ + def __init__(self, fixedpt_dtype, op_filter=default_op_filter, rounding_mode="ROUND"): super().__init__() if isinstance(fixedpt_dtype, str): diff --git a/src/qonnx/transformation/general.py b/src/qonnx/transformation/general.py index 654bee4e..5126bf27 100644 --- a/src/qonnx/transformation/general.py +++ b/src/qonnx/transformation/general.py @@ -261,8 +261,7 @@ def apply(self, model): # check if node inputs are connected to graph inputs or initializers # if so, we can keep the node in the graph for name in n.input: - if util.get_by_name(model.graph.initializer, name) or \ - util.get_by_name(model.graph.input, name): + if util.get_by_name(model.graph.initializer, name) or util.get_by_name(model.graph.input, name): # this node is connected to graph inputs or initializers # so we can keep it in the graph graph_dependencies[node_idx] = set() diff --git a/tests/core/test_datatypes.py b/tests/core/test_datatypes.py index 0fbd0dea..452c0611 100644 --- a/tests/core/test_datatypes.py +++ b/tests/core/test_datatypes.py @@ -27,6 +27,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
import pytest + import numpy as np from qonnx.core.datatype import DataType, resolve_datatype @@ -187,249 +188,218 @@ def test_resolve_datatype(input): test_resolve_datatype(DataType["INT32"]) test_resolve_datatype(DataType["FLOAT32"]) + vectorize_details = { "BIPOLAR": [ - np.array([ - [-1, +1, 0], - [ 0, +1, -1], - [+1, 0, -1] - ]), - np.array([ - [True, True, False], - [False, True, True], - [True, False, True] - ], dtype=bool) + np.array([[-1, +1, 0], [0, +1, -1], [+1, 0, -1]]), + np.array([[True, True, False], [False, True, True], [True, False, True]], dtype=bool), ], "BINARY": [ - np.array([ - [-1, +1, 0], - [ 0, +1, -1], - [+1, 0, -1] - ]), - np.array([ - [False, True, True], - [True, True, False], - [True, True, False] - ], dtype=bool) + np.array([[-1, +1, 0], [0, +1, -1], [+1, 0, -1]]), + np.array([[False, True, True], [True, True, False], [True, True, False]], dtype=bool), ], "TERNARY": [ - np.array([ - [-1, +2, +1, 0], - [ 0, +1, +2, -1], - [+2, +1, 0, -1] - ]), - np.array([ - [True, False, True, True], - [True, True, False, True], - [False, True, True, True] - ], dtype=bool) + np.array([[-1, +2, +1, 0], [0, +1, +2, -1], [+2, +1, 0, -1]]), + np.array([[True, False, True, True], [True, True, False, True], [False, True, True, True]], dtype=bool), ], "UINT2": [ - np.array([ - [[-1, +2, +1, 0], - [ 0, +1, +2, -1]], - [[+2, +1, 0, -1], - [+4, -1, -2, +3]], - ]), - np.array([ - [[False, True, True, True], - [True, True, True, False]], - [[True, True, True, False], - [False, False, False, True]], - ], dtype=bool) + np.array( + [ + [[-1, +2, +1, 0], [0, +1, +2, -1]], + [[+2, +1, 0, -1], [+4, -1, -2, +3]], + ] + ), + np.array( + [ + [[False, True, True, True], [True, True, True, False]], + [[True, True, True, False], [False, False, False, True]], + ], + dtype=bool, + ), ], "UINT3": [ - np.array([ - [[+9, -6, +3, 0], - [-4, +4, 0, +1]], - [[-1, +3, +10, +4], - [+2, +6, +7, +8]], - ]), - np.array([ - [[False, False, True, True], - [False, True, True, True]], - 
[[False, True, False, True], - [True, True, True, False]], - ], dtype=bool) + np.array( + [ + [[+9, -6, +3, 0], [-4, +4, 0, +1]], + [[-1, +3, +10, +4], [+2, +6, +7, +8]], + ] + ), + np.array( + [ + [[False, False, True, True], [False, True, True, True]], + [[False, True, False, True], [True, True, True, False]], + ], + dtype=bool, + ), ], "UINT4": [ - np.array([ - [[-10, -4, +9, +13], - [+1, +14, +11, +4]], - [[+18, -7, +1, +9], - [-1, -7, +1, -2]], - ]), - np.array([ - [[False, False, True, True], - [True, True, True, True]], - [[False, False, True, True], - [False, False, True, False]], - ], dtype=bool) + np.array( + [ + [[-10, -4, +9, +13], [+1, +14, +11, +4]], + [[+18, -7, +1, +9], [-1, -7, +1, -2]], + ] + ), + np.array( + [ + [[False, False, True, True], [True, True, True, True]], + [[False, False, True, True], [False, False, True, False]], + ], + dtype=bool, + ), ], "UINT8": [ - np.array([ - [[148, 61, 70, 29], - [244, 213, 10, 135]], - [[18, 25, 246, 137], - [236, -31, 220, 359]], - ]), - np.array([ - [[True, True, True, True], - [True, True, True, True]], - [[True, True, True, True], - [True, False, True, False]], - ], dtype=bool) + np.array( + [ + [[148, 61, 70, 29], [244, 213, 10, 135]], + [[18, 25, 246, 137], [236, -31, 220, 359]], + ] + ), + np.array( + [ + [[True, True, True, True], [True, True, True, True]], + [[True, True, True, True], [True, False, True, False]], + ], + dtype=bool, + ), ], "UINT16": [ - np.array([ - [[35261, 129491, 9136, 18643], - [128532, -597, 34768, 248]], - [[21646, 30778, 71076, 21224], - [60657, 52854, -5994, 17295]], - ]), - np.array([ - [[True, False, True, True], - [False, False, True, True]], - [[True, True, False, True], - [True, True, False, True]], - ], dtype=bool) + np.array( + [ + [[35261, 129491, 9136, 18643], [128532, -597, 34768, 248]], + [[21646, 30778, 71076, 21224], [60657, 52854, -5994, 17295]], + ] + ), + np.array( + [ + [[True, False, True, True], [False, False, True, True]], + [[True, True, False, True], 
[True, True, False, True]], + ], + dtype=bool, + ), ], "UINT32": [ - np.array([ - [[-417565331, 3488834022, -1757218812, 591311876], - [1842515574, 4131239283, 2022242400, 1240578991]], - [[609779043, 574774725, 4188472937, 3109757181], - [-767760560, -2100731532, 3794040092, 3223013612]], - ]), - np.array([ - [[False, True, False, True], - [True, True, True, True]], - [[True, True, True, True], - [False, False, True, True]], - ], dtype=bool) + np.array( + [ + [[-417565331, 3488834022, -1757218812, 591311876], [1842515574, 4131239283, 2022242400, 1240578991]], + [[609779043, 574774725, 4188472937, 3109757181], [-767760560, -2100731532, 3794040092, 3223013612]], + ] + ), + np.array( + [ + [[False, True, False, True], [True, True, True, True]], + [[True, True, True, True], [False, False, True, True]], + ], + dtype=bool, + ), ], "INT2": [ - np.array([ - [[ 0, 2, 2, 3], - [-4, 2, -1, 2]], - [[ 1, 2, -4, -1], - [ 2, -1, -1, -2]], - ]), - np.array([ - [[True, False, False, False], - [False, False, True, False]], - [[True, False, False, True], - [False, True, True, True]], - ], dtype=bool) + np.array( + [ + [[0, 2, 2, 3], [-4, 2, -1, 2]], + [[1, 2, -4, -1], [2, -1, -1, -2]], + ] + ), + np.array( + [ + [[True, False, False, False], [False, False, True, False]], + [[True, False, False, True], [False, True, True, True]], + ], + dtype=bool, + ), ], "INT3": [ - np.array([ - [[-4, -6, -7, 3], - [ 2, -8, -7, 3]], - [[-4, -4, 4, -4], - [ 1, -4, 1, -5]], - ]), - np.array([ - [[True, False, False, True], - [True, False, False, True]], - [[True, True, False, True], - [True, True, True, False]], - ], dtype=bool) + np.array( + [ + [[-4, -6, -7, 3], [2, -8, -7, 3]], + [[-4, -4, 4, -4], [1, -4, 1, -5]], + ] + ), + np.array( + [ + [[True, False, False, True], [True, False, False, True]], + [[True, True, False, True], [True, True, True, False]], + ], + dtype=bool, + ), ], "INT4": [ - np.array([ - [[ 5, 9, 3, -6], - [ 1, 5, 9, 10]], - [[ 10, 10, -3, 0], - [ -8, -5, -12, -5]], - ]), - 
np.array([ - [[True, False, True, True], - [True, True, False, False]], - [[False, False, True, True], - [True, True, False, True]], - ], dtype=bool) + np.array( + [ + [[5, 9, 3, -6], [1, 5, 9, 10]], + [[10, 10, -3, 0], [-8, -5, -12, -5]], + ] + ), + np.array( + [ + [[True, False, True, True], [True, True, False, False]], + [[False, False, True, True], [True, True, False, True]], + ], + dtype=bool, + ), ], "INT8": [ - np.array([ - [[-143, 140, 54, -217], - [ 22, 186, 72, -175]], - [[-126, -6, 115, 240], - [-87, -159, 128, -178]], - ]), - np.array([ - [[False, False, True, False], - [True, False, True, False]], - [[True, True, True, False], - [True, False, False, False]], - ], dtype=bool) + np.array( + [ + [[-143, 140, 54, -217], [22, 186, 72, -175]], + [[-126, -6, 115, 240], [-87, -159, 128, -178]], + ] + ), + np.array( + [ + [[False, False, True, False], [True, False, True, False]], + [[True, True, True, False], [True, False, False, False]], + ], + dtype=bool, + ), ], "INT16": [ - np.array([ - [[ 36863, 2676, 2728, -61500], - [ 24314, 18040, -39438, 64013]], - [[ 28824, -38855, 46308, -50728], - [-50275, -48853, -42034, -44384]], - ]), - np.array([ - [[False, True, True, False], - [True, True, False, False]], - [[True, False, False, False], - [False, False, False, False]], - ], dtype=bool) + np.array( + [ + [[36863, 2676, 2728, -61500], [24314, 18040, -39438, 64013]], + [[28824, -38855, 46308, -50728], [-50275, -48853, -42034, -44384]], + ] + ), + np.array( + [ + [[False, True, True, False], [True, True, False, False]], + [[True, False, False, False], [False, False, False, False]], + ], + dtype=bool, + ), ], "FIXED<4,2>": [ - np.array([ - [[1.8, 1.5, -0.25, 0], - [-1.1, -2, 1.75, 0.1]], - [[-1.5, 1.6, 0.5, 0.1], - [0.4, 0.001, 3.03, 1.75]], - ]), - np.array([ - [[False, True, True, True], - [False, True, True, False]], - [[True, False, True, False], - [False, False, False, True]], - ], dtype=bool) + np.array( + [ + [[1.8, 1.5, -0.25, 0], [-1.1, -2, 1.75, 0.1]], + 
[[-1.5, 1.6, 0.5, 0.1], [0.4, 0.001, 3.03, 1.75]], + ] + ), + np.array( + [ + [[False, True, True, True], [False, True, True, False]], + [[True, False, True, False], [False, False, False, True]], + ], + dtype=bool, + ), ], "FLOAT<4,3>": [ - np.array([ - [0.0, 0.5, 1.875, -1.5], - [1.8, -512.0, 0.013671875, 0.0087890625], - [0.001953125, 0.0009765625, 2.0, 1.25] - ]), - np.array([ - [True, True, True, True], - [False, False, True, False], - [True, False, True, True] - ]) + np.array( + [[0.0, 0.5, 1.875, -1.5], [1.8, -512.0, 0.013671875, 0.0087890625], [0.001953125, 0.0009765625, 2.0, 1.25]] + ), + np.array([[True, True, True, True], [False, False, True, False], [True, False, True, True]]), ], "FLOAT<4,0>": [ - np.array([ - [0.0, 0.5, 0.75], - [0.015625, 0.0078125, 0.0625] - ]), - np.array([ - [True, True, False], - [True, False, True] - ]) + np.array([[0.0, 0.5, 0.75], [0.015625, 0.0078125, 0.0625]]), + np.array([[True, True, False], [True, False, True]]), ], "FLOAT<4,3,5>": [ - np.array([ - [0.0, 0.5, 1.875], - [-1.5, 1.8, -512.0] - ]), - np.array([ - [True, True, True], - [True, False, True] - ]) + np.array([[0.0, 0.5, 1.875], [-1.5, 1.8, -512.0]]), + np.array([[True, True, True], [True, False, True]]), ], - "FLOAT<4,0,5>": [ - np.array([0.0, 0.0625, 0.03125]), - np.array([True, True, False]) - ] + "FLOAT<4,0,5>": [np.array([0.0, 0.0625, 0.03125]), np.array([True, True, False])], } + @pytest.mark.parametrize("datatype", vectorize_details.keys()) def test_vectorized_allowed(datatype): input_values, golden_out = vectorize_details[datatype] diff --git a/tests/core/test_subgraph_traversal.py b/tests/core/test_subgraph_traversal.py index 3e8121c2..15d7b1a5 100644 --- a/tests/core/test_subgraph_traversal.py +++ b/tests/core/test_subgraph_traversal.py @@ -1,12 +1,13 @@ import pytest + +import onnx from collections import Counter +from onnx import helper from qonnx.core.modelwrapper import ModelWrapper from qonnx.transformation.base import Transformation +from 
qonnx.util.basic import get_by_name, qonnx_make_model -from qonnx.util.basic import qonnx_make_model, get_by_name -import onnx -from onnx import helper # Helper to recursively build a graph with subgraphs attached to nodes def make_graph(tree): @@ -49,6 +50,7 @@ def make_graph(tree): return graph + def make_subgraph_model(tree): """ Build a ModelWrapper with a graph structure based on the provided tree. @@ -73,7 +75,9 @@ def apply(self, model_wrapper): dummy_name_in = f"{graph_name}_dummy_in" dummy_name_out = f"{graph_name}_dummy_out" model_wrapper.model.graph.input.append(helper.make_tensor_value_info(dummy_name_in, onnx.TensorProto.FLOAT, [4, 4])) - model_wrapper.model.graph.output.append(helper.make_tensor_value_info(dummy_name_out, onnx.TensorProto.FLOAT, [4, 4])) + model_wrapper.model.graph.output.append( + helper.make_tensor_value_info(dummy_name_out, onnx.TensorProto.FLOAT, [4, 4]) + ) model_wrapper.model.graph.node.append( helper.make_node( "DummyNode", # dummy op_type @@ -85,15 +89,18 @@ def apply(self, model_wrapper): # collect the name of the graph being transformed to check how many times each graph was visited self.visited.append(graph_name) - #import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() return model_wrapper, False + class NestedTransform(Transformation): def __init__(self): self.dummy_transform = DummyTransform() + def apply(self, model_wrapper): return model_wrapper.transform(self.dummy_transform), False + def get_subgraph_names(tree): """ Recursively collect the names of all subgraphs in the tree structure. @@ -115,10 +122,11 @@ def check_all_visted_once(tree, transform): """ Check that all subgraphs in the tree structure were visited exactly once. 
""" - visited = transform.visited + visited = transform.visited expected = get_subgraph_names(tree) assert Counter(visited) == Counter(expected), f"Visited: {visited}, Expected: {expected}" + def check_visit_order(tree, transform, order): """ Check that the order of visited subgraphs matches the expected preorder or postorder traversal. @@ -127,6 +135,7 @@ def check_visit_order(tree, transform, order): expected = order(tree) assert visited == expected, f"Visited: {visited}, Expected: {expected}" + def check_all_subgraphs_transformed(graph): """ Check that all subgraphs in the tree structure have been transformed. @@ -149,20 +158,20 @@ def get_metadata_props(graph, key): else: return metadata_prop.value - assert(get_metadata_props(graph, graph.name) == "visited"), f"Metadata for {graph.name} not set correctly" - assert(get_metadata_props(graph, "opset_id") == "10"), "Metadata for opset_id not set correctly" + assert get_metadata_props(graph, graph.name) == "visited", f"Metadata for {graph.name} not set correctly" + assert get_metadata_props(graph, "opset_id") == "10", "Metadata for opset_id not set correctly" # recursively check all subgraphs for node in graph.node: - for attr in node.attribute: + for attr in node.attribute: if attr.type == onnx.AttributeProto.GRAPH: check_all_subgraphs_transformed(attr.g) + @pytest.mark.parametrize("cleanup", [False, True]) @pytest.mark.parametrize("make_deepcopy", [False, True]) -@pytest.mark.parametrize("tree, apply_to_subgraphs", - [(("top", []), True), - (("top", []), False), - (("top", [("sub1", [])]), False)]) +@pytest.mark.parametrize( + "tree, apply_to_subgraphs", [(("top", []), True), (("top", []), False), (("top", [("sub1", [])]), False)] +) def test_no_traversal(tree, cleanup, make_deepcopy, apply_to_subgraphs): # Check that the top-level model is transformed exactly once when there are no subgraphs. # Check that the top-level model is transformed exactly once when there are subgraphs, but apply_to_subgraphs is False. 
@@ -175,6 +184,7 @@ def test_no_traversal(tree, cleanup, make_deepcopy, apply_to_subgraphs): assert transform.visited == ["top"] assert t_model.get_metadata_prop("top") == "visited" + def build_preorder_traversal(tree): """ Build a preorder traversal of the tree structure. @@ -190,6 +200,7 @@ def traverse(node): traverse(tree) return traversal + def build_postorder_traversal(tree): """ Build a postorder traversal of the tree structure. @@ -205,10 +216,16 @@ def traverse(node): traverse(tree) return traversal + @pytest.mark.parametrize("cleanup", [False, True]) @pytest.mark.parametrize("make_deepcopy", [False, True]) -@pytest.mark.parametrize("tree", [("top", [("sub1", []), ("sub2", [])]), - ("top", [("sub1", [("sub1_1", []), ("sub1_2",[])]), ("sub2", [("sub2_1", [])])])]) +@pytest.mark.parametrize( + "tree", + [ + ("top", [("sub1", []), ("sub2", [])]), + ("top", [("sub1", [("sub1_1", []), ("sub1_2", [])]), ("sub2", [("sub2_1", [])])]), + ], +) @pytest.mark.parametrize("use_preorder_traversal", [True, False]) def test_traversal(tree, cleanup, make_deepcopy, use_preorder_traversal): # Check that the top-level model and all subgraphs are transformed when apply_to_subgraphs is True. 
@@ -216,7 +233,9 @@ def test_traversal(tree, cleanup, make_deepcopy, use_preorder_traversal): print(f"Testing tree: {tree}, cleanup: {cleanup}, make_deepcopy: {make_deepcopy}") model = make_subgraph_model(tree) transform = DummyTransform() - t_model = model.transform(transform, cleanup, make_deepcopy, apply_to_subgraphs=True, use_preorder_traversal=use_preorder_traversal) + t_model = model.transform( + transform, cleanup, make_deepcopy, apply_to_subgraphs=True, use_preorder_traversal=use_preorder_traversal + ) check_all_visted_once(tree, transform) check_all_subgraphs_transformed(t_model.model.graph) @@ -230,8 +249,13 @@ def test_traversal(tree, cleanup, make_deepcopy, use_preorder_traversal): @pytest.mark.parametrize("cleanup", [False, True]) @pytest.mark.parametrize("make_deepcopy", [False, True]) -@pytest.mark.parametrize("tree", [("top", [("sub1", []), ("sub2", [])]), - ("top", [("sub1", [("sub1_1", []), ("sub1_2",[])]), ("sub2", [("sub2_1", [])])])]) +@pytest.mark.parametrize( + "tree", + [ + ("top", [("sub1", []), ("sub2", [])]), + ("top", [("sub1", [("sub1_1", []), ("sub1_2", [])]), ("sub2", [("sub2_1", [])])]), + ], +) def test_traversal_nested(tree, cleanup, make_deepcopy): # Check that the top-level model and all subgraphs are transformed when apply_to_subgraphs is True. # This should always be done correctly regardless of cleanup and make_deepcopy. @@ -242,6 +266,7 @@ def test_traversal_nested(tree, cleanup, make_deepcopy): check_all_visted_once(tree, transform.dummy_transform) check_all_subgraphs_transformed(t_model.model.graph) + def dummy_analysis_fxn(model_wrapper): """ A dummy analysis function that simply returns the model wrapper. 
@@ -250,6 +275,7 @@ def dummy_analysis_fxn(model_wrapper): d = {} return d + @pytest.mark.xfail(reason="Analysis functions require apply_to_subgraphs when traversing subgraphs") def test_analysis_fxn_without_apply_to_subgraphs_fails(): # Check that an analysis function fails when apply_to_subgraphs is False diff --git a/tests/transformation/test_fixedpt_quantize.py b/tests/transformation/test_fixedpt_quantize.py index 285e87f8..2b60c735 100644 --- a/tests/transformation/test_fixedpt_quantize.py +++ b/tests/transformation/test_fixedpt_quantize.py @@ -28,16 +28,14 @@ import pytest -import numpy as np import os +from qonnx.core.datatype import DataType from qonnx.core.modelwrapper import ModelWrapper from qonnx.transformation.fixedpt_quantize import FixedPointQuantizeParams, FixedPointQuantizeParamsFromDict -from qonnx.core.datatype import DataType from qonnx.util.cleanup import cleanup_model from qonnx.util.test import download_model - fixedpt_dict_details = { "Conv_bias_example_round": { "test_model": "Conv_bias_example", @@ -47,9 +45,9 @@ "Conv_1_param0": "FIXED<8,1>", "Conv_1_param1": "FIXED<8,1>", "Gemm_0_param0": "FIXED<12,1>", - "Gemm_0_param1": "FIXED<12,1>" + "Gemm_0_param1": "FIXED<12,1>", }, - "rounding_mode": "ROUND" + "rounding_mode": "ROUND", }, "Conv_bias_example_floor": { "test_model": "Conv_bias_example", @@ -59,9 +57,9 @@ "Conv_1_param0": "FIXED<8,1>", "Conv_1_param1": "FIXED<8,1>", "Gemm_0_param0": "FIXED<12,1>", - "Gemm_0_param1": "FIXED<12,1>" + "Gemm_0_param1": "FIXED<12,1>", }, - "rounding_mode": "FLOOR" + "rounding_mode": "FLOOR", }, "FINN-CNV_W2A2_round": { "test_model": "FINN-CNV_W2A2", @@ -97,9 +95,9 @@ "BatchNormalization_7_param0": "FIXED<9,4>", "BatchNormalization_7_param1": "FIXED<10,3>", "BatchNormalization_7_param2": "FIXED<12,8>", - "BatchNormalization_7_param3": "FIXED<14,13>" + "BatchNormalization_7_param3": "FIXED<14,13>", }, - "rounding_mode": "ROUND" + "rounding_mode": "ROUND", }, "FINN-CNV_W2A2_floor": { "test_model": 
"FINN-CNV_W2A2", @@ -135,9 +133,9 @@ "BatchNormalization_7_param0": "FIXED<9,4>", "BatchNormalization_7_param1": "FIXED<10,3>", "BatchNormalization_7_param2": "FIXED<12,8>", - "BatchNormalization_7_param3": "FIXED<14,13>" + "BatchNormalization_7_param3": "FIXED<14,13>", }, - "rounding_mode": "FLOOR" + "rounding_mode": "FLOOR", }, "MobileNetv1-w4a4_round": { "test_model": "MobileNetv1-w4a4", @@ -249,9 +247,9 @@ "BatchNormalization_26_param0": "FIXED<10,3>", "BatchNormalization_26_param1": "FIXED<5,2>", "BatchNormalization_26_param2": "FIXED<4,2>", - "BatchNormalization_26_param3": "FIXED<11,1>" + "BatchNormalization_26_param3": "FIXED<11,1>", }, - "rounding_mode": "ROUND" + "rounding_mode": "ROUND", }, "MobileNetv1-w4a4_floor": { "test_model": "MobileNetv1-w4a4", @@ -363,10 +361,10 @@ "BatchNormalization_26_param0": "FIXED<10,3>", "BatchNormalization_26_param1": "FIXED<5,2>", "BatchNormalization_26_param2": "FIXED<4,2>", - "BatchNormalization_26_param3": "FIXED<11,1>" + "BatchNormalization_26_param3": "FIXED<11,1>", }, - "rounding_mode": "FLOOR" - } + "rounding_mode": "FLOOR", + }, } @@ -401,67 +399,44 @@ def test_fixedpt_quantize_from_dict(test_case): os.unlink(dl_file) + fixedpt_details = { "FINN-CNV_W2A2_round_0": { "test_model": "FINN-CNV_W2A2", "dtype": "FIXED<8,3>", "rounding_mode": "ROUND", - "quant_tensors": [ - "Mul_0_param0", - "Mul_1_param0", - "Add_0_param0" - ] + "quant_tensors": ["Mul_0_param0", "Mul_1_param0", "Add_0_param0"], }, "FINN-CNV_W2A2_floor_0": { "test_model": "FINN-CNV_W2A2", "dtype": "FIXED<8,3>", "rounding_mode": "FLOOR", - "quant_tensors": [ - "Mul_0_param0", - "Mul_1_param0", - "Add_0_param0" - ] + "quant_tensors": ["Mul_0_param0", "Mul_1_param0", "Add_0_param0"], }, "FINN-CNV_W2A2_round_1": { "test_model": "FINN-CNV_W2A2", "dtype": "FIXED<4,3>", "rounding_mode": "ROUND", - "quant_tensors": [ - "Mul_0_param0", - "Mul_1_param0", - "Add_0_param0" - ] + "quant_tensors": ["Mul_0_param0", "Mul_1_param0", "Add_0_param0"], }, 
"FINN-CNV_W2A2_floor_1": { "test_model": "FINN-CNV_W2A2", "dtype": "FIXED<4,3>", "rounding_mode": "FLOOR", - "quant_tensors": [ - "Mul_0_param0", - "Mul_1_param0", - "Add_0_param0" - ] + "quant_tensors": ["Mul_0_param0", "Mul_1_param0", "Add_0_param0"], }, "FINN-CNV_W2A2_round_2": { "test_model": "FINN-CNV_W2A2", "dtype": "FIXED<12,3>", "rounding_mode": "ROUND", - "quant_tensors": [ - "Mul_0_param0", - "Mul_1_param0", - "Add_0_param0" - ] + "quant_tensors": ["Mul_0_param0", "Mul_1_param0", "Add_0_param0"], }, "FINN-CNV_W2A2_floor_2": { "test_model": "FINN-CNV_W2A2", "dtype": "FIXED<12,3>", "rounding_mode": "FLOOR", - "quant_tensors": [ - "Mul_0_param0", - "Mul_1_param0", - "Add_0_param0" - ] - } + "quant_tensors": ["Mul_0_param0", "Mul_1_param0", "Add_0_param0"], + }, } diff --git a/tests/transformation/test_sort_graph.py b/tests/transformation/test_sort_graph.py index 876e2a4b..cb9fd072 100644 --- a/tests/transformation/test_sort_graph.py +++ b/tests/transformation/test_sort_graph.py @@ -167,6 +167,7 @@ def test_sort_nonlinear_graph(): # plt.plot(sizes,times,"--o") # plt.grid(True) + def test_sort_graph_node_only_connected_to_graphio(): """ Test that SortGraph does not remove nodes that are only connected to graph inputs/outputs. 
From 7456919c4e614919aa3003bab6ce1e9c55f1300f Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 25 Sep 2025 11:32:08 +0200 Subject: [PATCH 10/35] [Core] add get_opset_imports utility fxn to ModelWrapper --- src/qonnx/core/modelwrapper.py | 4 ++++ tests/core/test_modelwrapper.py | 1 + 2 files changed, 5 insertions(+) diff --git a/src/qonnx/core/modelwrapper.py b/src/qonnx/core/modelwrapper.py index 77866f4c..a85a1cf0 100644 --- a/src/qonnx/core/modelwrapper.py +++ b/src/qonnx/core/modelwrapper.py @@ -737,3 +737,7 @@ def set_tensor_sparsity(self, tensor_name, sparsity_dict): qa.tensor_name = tensor_name qa.quant_parameter_tensor_names.append(dt) qnt_annotations.append(qa) + + def get_opset_imports(self): + """Returns a list of imported opsets as (domain, version) tuples.""" + return [(opset.domain, opset.version) for opset in self._model_proto.opset_import] diff --git a/tests/core/test_modelwrapper.py b/tests/core/test_modelwrapper.py index 722f0fb1..5ffabb3c 100644 --- a/tests/core/test_modelwrapper.py +++ b/tests/core/test_modelwrapper.py @@ -68,6 +68,7 @@ def test_modelwrapper(): inp_sparsity = {"dw": {"kernel_shape": [3, 3]}} model.set_tensor_sparsity(first_conv_iname, inp_sparsity) assert model.get_tensor_sparsity(first_conv_iname) == inp_sparsity + assert model.get_opset_imports() == [("", 8)] def test_modelwrapper_set_get_rm_initializer(): From 89396cde144cb09fab87b5e8f5f93e599c6e4bf3 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 25 Sep 2025 21:08:03 +0200 Subject: [PATCH 11/35] [Core] return dict from ModelWrapper.get_opset_imports --- src/qonnx/core/modelwrapper.py | 4 ++-- tests/core/test_modelwrapper.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/qonnx/core/modelwrapper.py b/src/qonnx/core/modelwrapper.py index a85a1cf0..6248ac6f 100644 --- a/src/qonnx/core/modelwrapper.py +++ b/src/qonnx/core/modelwrapper.py @@ -739,5 +739,5 @@ def set_tensor_sparsity(self, tensor_name, sparsity_dict): 
qnt_annotations.append(qa) def get_opset_imports(self): - """Returns a list of imported opsets as (domain, version) tuples.""" - return [(opset.domain, opset.version) for opset in self._model_proto.opset_import] + """Returns a list of imported opsets as a {domain, version} dictionary.""" + return {opset.domain: opset.version for opset in self._model_proto.opset_import} diff --git a/tests/core/test_modelwrapper.py b/tests/core/test_modelwrapper.py index 5ffabb3c..995bcb17 100644 --- a/tests/core/test_modelwrapper.py +++ b/tests/core/test_modelwrapper.py @@ -68,7 +68,7 @@ def test_modelwrapper(): inp_sparsity = {"dw": {"kernel_shape": [3, 3]}} model.set_tensor_sparsity(first_conv_iname, inp_sparsity) assert model.get_tensor_sparsity(first_conv_iname) == inp_sparsity - assert model.get_opset_imports() == [("", 8)] + assert model.get_opset_imports() == {"": 8} def test_modelwrapper_set_get_rm_initializer(): From db2994f82227f252f4470e4dbdcbd53bdf579fed Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 25 Sep 2025 21:12:31 +0200 Subject: [PATCH 12/35] [Core] add versioned op to getCustomOp with fallback to old style --- src/qonnx/custom_op/registry.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/qonnx/custom_op/registry.py b/src/qonnx/custom_op/registry.py index 3540bb5a..5d1a52ca 100644 --- a/src/qonnx/custom_op/registry.py +++ b/src/qonnx/custom_op/registry.py @@ -41,7 +41,15 @@ def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset(), brevitas_ex try: opset_module = importlib.import_module(domain) assert type(opset_module.custom_op) is dict, "custom_op dict not found in Python module %s" % domain - inst_wrapper = opset_module.custom_op[op_type] + op_type_with_version = op_type + "_v" + str(onnx_opset_version) + # TODO version handling: use highest available version smaller than requested version + # when the exact match is not found + if op_type_with_version in opset_module.custom_op: + # priority: if it exists, 
load the versioned CustomOp wrapper + inst_wrapper = opset_module.custom_op[op_type_with_version] + else: + # otherwise use the default (non-suffixed) CustomOp wrapper + inst_wrapper = opset_module.custom_op[op_type] inst = inst_wrapper(node, onnx_opset_version=onnx_opset_version) return inst except ModuleNotFoundError: From 8a2db226a8a3efed8f07c41c9da0de8b943e0f7d Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 25 Sep 2025 21:13:46 +0200 Subject: [PATCH 13/35] [Core] inrtoduce ModelWrapper.get_customop_wrapper grabs CustomOp instance with the right opset version from protobuf imported opsets --- src/qonnx/core/modelwrapper.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/qonnx/core/modelwrapper.py b/src/qonnx/core/modelwrapper.py index 6248ac6f..b2308a06 100644 --- a/src/qonnx/core/modelwrapper.py +++ b/src/qonnx/core/modelwrapper.py @@ -39,6 +39,7 @@ import qonnx.util.basic as util import qonnx.util.onnx as onnxutil from qonnx.core.datatype import DataType +from qonnx.custom_op.registry import getCustomOp from qonnx.transformation.double_to_single_float import DoubleToSingleFloat from qonnx.transformation.general import ( RemoveStaticGraphInputs, @@ -741,3 +742,10 @@ def set_tensor_sparsity(self, tensor_name, sparsity_dict): def get_opset_imports(self): """Returns a list of imported opsets as a {domain, version} dictionary.""" return {opset.domain: opset.version for opset in self._model_proto.opset_import} + + def get_customop_wrapper(self, node): + """Return CustomOp instance for given node, respecting the + imported opset version in the model protobuf.""" + opset_imports = self.get_opset_imports() + opset_import = opset_imports[node.domain] + return getCustomOp(node, onnx_opset_version=opset_import) From 402a58056f8af7784065c74cab7b6e58f0e44b4f Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 25 Sep 2025 21:14:48 +0200 Subject: [PATCH 14/35] [Test] add basic unit tests for versioned custom op fetching --- 
tests/custom_op/test_customop_version.py | 106 +++++++++++++++++++++++ 1 file changed, 106 insertions(+) create mode 100644 tests/custom_op/test_customop_version.py diff --git a/tests/custom_op/test_customop_version.py b/tests/custom_op/test_customop_version.py new file mode 100644 index 00000000..bdc660f9 --- /dev/null +++ b/tests/custom_op/test_customop_version.py @@ -0,0 +1,106 @@ +# Copyright (c) 2025 Advanced Micro Devices, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of qonnx nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import onnx.parser as oprs + +import qonnx.custom_op.general as general +from qonnx.core.modelwrapper import ModelWrapper +from qonnx.custom_op.base import CustomOp +from qonnx.custom_op.registry import getCustomOp + + +class VerTestOp_v1(CustomOp): + def get_nodeattr_types(self): + my_attrs = {"v1_attr": ("i", True, 0)} + return my_attrs + + def make_shape_compatible_op(self, model): + ishape = model.get_tensor_shape(self.onnx_node.input[0]) + return super().make_const_shape_op(ishape) + + def infer_node_datatype(self, model): + node = self.onnx_node + # data type stays the same + dtype = model.get_tensor_datatype(node.input[0]) + model.set_tensor_datatype(node.output[0], dtype) + + def execute_node(self, context, graph): + node = self.onnx_node + context[node.output[0]] = context[node.input[0]] + + def verify_node(self): + pass + + +class VerTestOp_v2(VerTestOp_v1): + def get_nodeattr_types(self): + my_attrs = {"v2_attr": ("i", True, 0)} + return my_attrs + + +class VerTestOp_v3(VerTestOp_v2): + def get_nodeattr_types(self): + my_attrs = {"v3_attr": ("i", True, 0)} + return my_attrs + + +def make_vertest_model(vertest_ver): + ishp = (1, 10) + oshp = ishp + ishp_str = str(list(ishp)) + oshp_str = str(list(oshp)) + input = f""" + < + ir_version: 7, + opset_import: ["" : 9, "qonnx.custom_op.general" : {vertest_ver}] + > + agraph (float{ishp_str} in0) => (float{oshp_str} out0) + {{ + out0 = qonnx.custom_op.general.VerTestOp< + v{vertest_ver}_attr={vertest_ver} + >(in0) + }} + """ + model = oprs.parse_model(input) + model = ModelWrapper(model) + return model + + +def test_customop_version(): + general.custom_op["VerTestOp"] = VerTestOp_v1 + general.custom_op["VerTestOp_v2"] = VerTestOp_v2 + general.custom_op["VerTestOp_v3"] = VerTestOp_v3 + for ver in [1, 2, 3]: + model = make_vertest_model(ver) + # explicitly specify onnx_opset_version in getCustomOp + inst = getCustomOp(model.graph.node[0], onnx_opset_version=ver) + assert inst.get_nodeattr(f"v{ver}_attr") == ver 
+ # now use ModelWrapper.get_customop_wrapper for implicit + # fetching of op version + inst = model.get_customop_wrapper(model.graph.node[0]) + assert inst.get_nodeattr(f"v{ver}_attr") == ver From 407fb13c438b47b95205f77101fec9455faaebaf Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 2 Oct 2025 13:30:11 +0200 Subject: [PATCH 15/35] [Test] extend test_customop_version for default v handler --- tests/custom_op/test_customop_version.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/custom_op/test_customop_version.py b/tests/custom_op/test_customop_version.py index bdc660f9..79974037 100644 --- a/tests/custom_op/test_customop_version.py +++ b/tests/custom_op/test_customop_version.py @@ -92,7 +92,10 @@ def make_vertest_model(vertest_ver): def test_customop_version(): + # unspecified version defaults to v1 implementation general.custom_op["VerTestOp"] = VerTestOp_v1 + # v1 version is also explicitly registered + general.custom_op["VerTestOp_v1"] = VerTestOp_v1 general.custom_op["VerTestOp_v2"] = VerTestOp_v2 general.custom_op["VerTestOp_v3"] = VerTestOp_v3 for ver in [1, 2, 3]: @@ -104,3 +107,7 @@ def test_customop_version(): # fetching of op version inst = model.get_customop_wrapper(model.graph.node[0]) assert inst.get_nodeattr(f"v{ver}_attr") == ver + # unspecified version getCustomOp should default to v1 handler + # (even though the node is actually v3 in this case) + inst = getCustomOp(model.graph.node[0]) + assert isinstance(inst, VerTestOp_v1) From feac9f09bcc5dce9137d64d6309504d3e8fd46d4 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 2 Oct 2025 14:27:10 +0200 Subject: [PATCH 16/35] [Core] opset ver. 
fallback for ModelWrapper.get_customop_wrapper --- src/qonnx/core/modelwrapper.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/qonnx/core/modelwrapper.py b/src/qonnx/core/modelwrapper.py index b2308a06..3d3ba0e9 100644 --- a/src/qonnx/core/modelwrapper.py +++ b/src/qonnx/core/modelwrapper.py @@ -743,9 +743,19 @@ def get_opset_imports(self): """Returns a list of imported opsets as a {domain, version} dictionary.""" return {opset.domain: opset.version for opset in self._model_proto.opset_import} - def get_customop_wrapper(self, node): + def get_customop_wrapper(self, node, fallback_customop_version=1): """Return CustomOp instance for given node, respecting the - imported opset version in the model protobuf.""" + imported opset version in the model protobuf. If the node's domain + is not found in the model's opset imports, fallback_customop_version + will be used.""" opset_imports = self.get_opset_imports() - opset_import = opset_imports[node.domain] - return getCustomOp(node, onnx_opset_version=opset_import) + try: + opset_import = opset_imports[node.domain] + return getCustomOp(node, onnx_opset_version=opset_import) + except KeyError: + # domain not found in imports, use fallback version + warnings.warn( + f"Domain {node.domain} not found in model opset imports, " + f"using fallback_customop_version={fallback_customop_version}" + ) + return getCustomOp(node, onnx_opset_version=fallback_customop_version) From 89eea4cfb7a8e421ed5c22ba814fc360f4da4b48 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 2 Oct 2025 14:28:01 +0200 Subject: [PATCH 17/35] [Core] getCustomOp: default v to None, fetch highest available v. 
--- src/qonnx/custom_op/registry.py | 39 ++++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/src/qonnx/custom_op/registry.py b/src/qonnx/custom_op/registry.py index 5d1a52ca..825c7566 100644 --- a/src/qonnx/custom_op/registry.py +++ b/src/qonnx/custom_op/registry.py @@ -28,11 +28,12 @@ import importlib -from qonnx.util.basic import get_preferred_onnx_opset - -def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset(), brevitas_exception=True): - "Return a QONNX CustomOp instance for the given ONNX node, if it exists." +def getCustomOp(node, onnx_opset_version=None, brevitas_exception=True): + "Return a QONNX CustomOp wrapper for the given ONNX node and given opset version," + "if it exists. If opset version is None, the default handler for the op type will be used. " + "If version is specified but the exact version match isn't available, the highest available version " + "smaller than the requested version will be used." op_type = node.op_type domain = node.domain if brevitas_exception: @@ -41,18 +42,30 @@ def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset(), brevitas_ex try: opset_module = importlib.import_module(domain) assert type(opset_module.custom_op) is dict, "custom_op dict not found in Python module %s" % domain - op_type_with_version = op_type + "_v" + str(onnx_opset_version) - # TODO version handling: use highest available version smaller than requested version - # when the exact match is not found - if op_type_with_version in opset_module.custom_op: - # priority: if it exists, load the versioned CustomOp wrapper - inst_wrapper = opset_module.custom_op[op_type_with_version] - else: - # otherwise use the default (non-suffixed) CustomOp wrapper + if onnx_opset_version is None: inst_wrapper = opset_module.custom_op[op_type] + else: + op_type_with_version = op_type + "_v" + str(onnx_opset_version) + if op_type_with_version in opset_module.custom_op: + # priority: if it exists, load the 
versioned CustomOp wrapper + inst_wrapper = opset_module.custom_op[op_type_with_version] + else: + # when the exact version match is not found + # version handling: use highest available version smaller than requested version + available_versions = [ + int(k.split("_v")[-1]) for k in opset_module.custom_op.keys() if k.startswith(op_type + "_v") + ] + suitable_versions = [v for v in available_versions if v <= onnx_opset_version] + if suitable_versions: + highest_version = max(suitable_versions) + inst_wrapper = opset_module.custom_op[f"{op_type}_v{highest_version}"] + else: + raise Exception( + "Op %s version %s not found in custom opset %s" % (op_type, str(onnx_opset_version), domain) + ) inst = inst_wrapper(node, onnx_opset_version=onnx_opset_version) return inst except ModuleNotFoundError: raise Exception("Could not load custom opset %s, check your PYTHONPATH" % domain) except KeyError: - raise Exception("Op %s not found in custom opset %s" % (op_type, domain)) + raise Exception("Op %s version %s not found in custom opset %s" % (op_type, str(onnx_opset_version), domain)) From ec517b517a47db680180586c8063613431297745 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 2 Oct 2025 14:31:15 +0200 Subject: [PATCH 18/35] [Test] cover newly added opset ver behavior in test_customop_version --- tests/custom_op/test_customop_version.py | 37 +++++++++++++++++++----- 1 file changed, 29 insertions(+), 8 deletions(-) diff --git a/tests/custom_op/test_customop_version.py b/tests/custom_op/test_customop_version.py index 79974037..5364df61 100644 --- a/tests/custom_op/test_customop_version.py +++ b/tests/custom_op/test_customop_version.py @@ -69,15 +69,19 @@ def get_nodeattr_types(self): return my_attrs -def make_vertest_model(vertest_ver): +def make_vertest_model(vertest_ver, no_opset_import): ishp = (1, 10) oshp = ishp ishp_str = str(list(ishp)) oshp_str = str(list(oshp)) + if no_opset_import: + opset_import = "" + else: + opset_import = f', "qonnx.custom_op.general" : 
{vertest_ver}' input = f""" < ir_version: 7, - opset_import: ["" : 9, "qonnx.custom_op.general" : {vertest_ver}] + opset_import: ["" : 9{opset_import}] > agraph (float{ishp_str} in0) => (float{oshp_str} out0) {{ @@ -98,16 +102,33 @@ def test_customop_version(): general.custom_op["VerTestOp_v1"] = VerTestOp_v1 general.custom_op["VerTestOp_v2"] = VerTestOp_v2 general.custom_op["VerTestOp_v3"] = VerTestOp_v3 + + # if onnx is lacking the opset import, should default to v1 handler + # (since we set custom_op["VerTestOp"] = VerTestOp_v1) + model = make_vertest_model(1, True) + inst = getCustomOp(model.graph.node[0]) + assert isinstance(inst, VerTestOp_v1) + # alternatively, when using ModelWrapper.get_customop_wrapper and onnx is + # lacking the opset import, should fall back to the specified version + inst = model.get_customop_wrapper(model.graph.node[0], fallback_customop_version=2) + assert isinstance(inst, VerTestOp_v2) + for ver in [1, 2, 3]: - model = make_vertest_model(ver) - # explicitly specify onnx_opset_version in getCustomOp - inst = getCustomOp(model.graph.node[0], onnx_opset_version=ver) - assert inst.get_nodeattr(f"v{ver}_attr") == ver - # now use ModelWrapper.get_customop_wrapper for implicit + model = make_vertest_model(ver, False) + # use ModelWrapper.get_customop_wrapper for implicit # fetching of op version inst = model.get_customop_wrapper(model.graph.node[0]) assert inst.get_nodeattr(f"v{ver}_attr") == ver + # explicitly specify onnx_opset_version in getCustomOp + # note: new code should avoid calling getCustomOp directly like this + # and instead use ModelWrapper.get_customop_wrapper + inst = getCustomOp(model.graph.node[0], onnx_opset_version=ver) + assert inst.get_nodeattr(f"v{ver}_attr") == ver # unspecified version getCustomOp should default to v1 handler - # (even though the node is actually v3 in this case) + model = make_vertest_model(1, False) inst = getCustomOp(model.graph.node[0]) assert isinstance(inst, VerTestOp_v1) + # requesting v4 
should return largest available version (v3 in this case) + model = make_vertest_model(3, False) + inst = getCustomOp(model.graph.node[0], onnx_opset_version=4) + assert isinstance(inst, VerTestOp_v3) From aeeff580d8bc1fa53f910474291997506bade759 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 2 Oct 2025 22:02:17 +0200 Subject: [PATCH 19/35] [Core, Util] distinguish preferred onnx opset from qonnx opset --- src/qonnx/core/modelwrapper.py | 2 +- src/qonnx/util/basic.py | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/qonnx/core/modelwrapper.py b/src/qonnx/core/modelwrapper.py index fa5a968d..2ba2984a 100644 --- a/src/qonnx/core/modelwrapper.py +++ b/src/qonnx/core/modelwrapper.py @@ -743,7 +743,7 @@ def get_opset_imports(self): """Returns a list of imported opsets as a {domain, version} dictionary.""" return {opset.domain: opset.version for opset in self._model_proto.opset_import} - def get_customop_wrapper(self, node, fallback_customop_version=1): + def get_customop_wrapper(self, node, fallback_customop_version=util.get_preferred_qonnx_opset()): """Return CustomOp instance for given node, respecting the imported opset version in the model protobuf. 
If the node's domain is not found in the model's opset imports, fallback_customop_version diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py index 3a3ce2af..e756366d 100644 --- a/src/qonnx/util/basic.py +++ b/src/qonnx/util/basic.py @@ -51,11 +51,19 @@ def get_preferred_onnx_opset(): return 11 +def get_preferred_qonnx_opset(): + "Return preferred ONNX opset version for QONNX" + return 1 + + def qonnx_make_model(graph_proto, **kwargs): "Wrapper around ONNX make_model with preferred qonnx opset version" opset_imports = kwargs.pop("opset_imports", None) if opset_imports is None: - opset_imports = [make_opsetid("", get_preferred_onnx_opset())] + opset_imports = [ + make_opsetid("", get_preferred_onnx_opset()), + make_opsetid("qonnx.custom_op.general", get_preferred_qonnx_opset()), + ] kwargs["opset_imports"] = opset_imports else: kwargs["opset_imports"] = opset_imports From 580150453f16bb8aa08ff4de1eee602f375b323e Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 2 Oct 2025 22:03:21 +0200 Subject: [PATCH 20/35] [Core] respect selected opsets during execution --- src/qonnx/core/execute_custom_node.py | 3 +-- src/qonnx/core/onnx_exec.py | 12 ++++++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/qonnx/core/execute_custom_node.py b/src/qonnx/core/execute_custom_node.py index 7acf3792..cd6bb605 100644 --- a/src/qonnx/core/execute_custom_node.py +++ b/src/qonnx/core/execute_custom_node.py @@ -27,10 +27,9 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import qonnx.custom_op.registry as registry -from qonnx.util.basic import get_preferred_onnx_opset -def execute_custom_node(node, context, graph, onnx_opset_version=get_preferred_onnx_opset()): +def execute_custom_node(node, context, graph, onnx_opset_version): """Call custom implementation to execute a single custom node. 
Input/output provided via context.""" op_type = node.op_type diff --git a/src/qonnx/core/onnx_exec.py b/src/qonnx/core/onnx_exec.py index a8f4774c..ecb808be 100644 --- a/src/qonnx/core/onnx_exec.py +++ b/src/qonnx/core/onnx_exec.py @@ -36,7 +36,7 @@ import qonnx.analysis.topology as ta import qonnx.core.execute_custom_node as ex_cu_node from qonnx.util.basic import ( - get_preferred_onnx_opset, + get_preferred_qonnx_opset, get_sanitize_quant_tensors, is_finn_op, qonnx_make_model, @@ -44,7 +44,7 @@ ) -def execute_node(node, context, graph, return_full_exec_context=False, opset_version=get_preferred_onnx_opset()): +def execute_node(node, context, graph, opset_version, return_full_exec_context=False): """Executes a single node by using onnxruntime or with a custom function. Input/output provided via context.""" @@ -158,7 +158,7 @@ def execute_onnx(model, input_dict, return_full_exec_context=False, start_node=N model_exec_mode = model.get_metadata_prop("exec_mode") if (model_exec_mode is None) or (model_exec_mode == ""): # extract opset version for node-by-node execution - opset_version = model.model.opset_import[0].version + opset_imports = model.get_opset_imports() # execute the model node by node # we can simply walk down the list since the ONNX spec guarantees that it is # topologically sorted @@ -176,7 +176,11 @@ def execute_onnx(model, input_dict, return_full_exec_context=False, start_node=N if get_sanitize_quant_tensors() != 0: # round input values to match quantization annotation execution_context = sanitize_quant_values(model, node.input, execution_context) - execute_node(node, execution_context, graph, return_full_exec_context, opset_version) + if node.domain in opset_imports: + opset_version = opset_imports[node.domain] + else: + opset_version = get_preferred_qonnx_opset() + execute_node(node, execution_context, graph, opset_version, return_full_exec_context) if get_sanitize_quant_tensors() != 0: # round output values to quantization annotation 
execution_context = sanitize_quant_values(model, node.output, execution_context) From 35b8b12d919ae37fe94c317a5585e61fa705b5b6 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 2 Oct 2025 22:03:42 +0200 Subject: [PATCH 21/35] [CustomOp] alias all qonnx.custom_op.general as v1 --- src/qonnx/custom_op/general/__init__.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/qonnx/custom_op/general/__init__.py b/src/qonnx/custom_op/general/__init__.py index 9b14ea8a..e125cbf8 100644 --- a/src/qonnx/custom_op/general/__init__.py +++ b/src/qonnx/custom_op/general/__init__.py @@ -52,3 +52,16 @@ custom_op["Trunc"] = Trunc custom_op["BipolarQuant"] = BipolarQuant custom_op["FloatQuant"] = FloatQuant + +custom_op["DebugMarker_v1"] = DebugMarker +custom_op["QuantAvgPool2d_v1"] = QuantAvgPool2d +custom_op["MaxPoolNHWC_v1"] = MaxPoolNHWC +custom_op["GenericPartition_v1"] = GenericPartition +custom_op["MultiThreshold_v1"] = MultiThreshold +custom_op["XnorPopcountMatMul_v1"] = XnorPopcountMatMul +custom_op["Im2Col_v1"] = Im2Col +custom_op["IntQuant_v1"] = IntQuant +custom_op["Quant_v1"] = IntQuant +custom_op["Trunc_v1"] = Trunc +custom_op["BipolarQuant_v1"] = BipolarQuant +custom_op["FloatQuant_v1"] = FloatQuant From d190a69d813666c1bc55e42a895b84e026630e8a Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 2 Oct 2025 22:04:42 +0200 Subject: [PATCH 22/35] [ChanLast] alias existing channels_last ops as v1 --- src/qonnx/custom_op/channels_last/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/qonnx/custom_op/channels_last/__init__.py b/src/qonnx/custom_op/channels_last/__init__.py index f1d7c39b..60aac1a4 100644 --- a/src/qonnx/custom_op/channels_last/__init__.py +++ b/src/qonnx/custom_op/channels_last/__init__.py @@ -7,3 +7,7 @@ custom_op["Conv"] = Conv custom_op["MaxPool"] = MaxPool custom_op["BatchNormalization"] = BatchNormalization + +custom_op["Conv_v1"] = Conv +custom_op["MaxPool_v1"] = MaxPool 
+custom_op["BatchNormalization_v1"] = BatchNormalization From 5f58f49dbc4110b80f013b919209c758f64dae1b Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 2 Oct 2025 22:49:05 +0200 Subject: [PATCH 23/35] [Test] add opsets for test_custom_onnx_exec --- tests/core/test_custom_onnx_exec.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/core/test_custom_onnx_exec.py b/tests/core/test_custom_onnx_exec.py index 8eec7156..54b71754 100644 --- a/tests/core/test_custom_onnx_exec.py +++ b/tests/core/test_custom_onnx_exec.py @@ -32,6 +32,8 @@ import qonnx.core.execute_custom_node as ex_cu_node from qonnx.custom_op.registry import getCustomOp +mt_node_version = 1 + def test_execute_custom_node_multithreshold(): inputs = np.ndarray( @@ -155,7 +157,7 @@ def test_execute_custom_node_multithreshold(): execution_context["v"] = inputs execution_context["thresholds"] = threshold_values - ex_cu_node.execute_custom_node(node_def, execution_context, graph_def) + ex_cu_node.execute_custom_node(node_def, execution_context, graph_def, mt_node_version) outputs = np.ndarray( shape=(6, 3, 2, 2), @@ -250,7 +252,7 @@ def test_execute_custom_node_multithreshold(): ) graph_def = helper.make_graph([node_def], "test_model", [v, thresholds], [out]) - ex_cu_node.execute_custom_node(node_def, execution_context, graph_def) + ex_cu_node.execute_custom_node(node_def, execution_context, graph_def, mt_node_version) outputs_scaled = 2.0 * outputs - 1.0 assert (execution_context["out"] == outputs_scaled).all() @@ -270,7 +272,7 @@ def test_execute_custom_node_multithreshold(): execution_context["v"] = inputs_nhwc graph_def = helper.make_graph([node_def], "test_model", [v_nhwc, thresholds], [out_nhwc]) - ex_cu_node.execute_custom_node(node_def, execution_context, graph_def) + ex_cu_node.execute_custom_node(node_def, execution_context, graph_def, mt_node_version) assert (execution_context["out"] == outputs_nhwc).all() # check the set of allowed values op_inst = 
getCustomOp(node_def) From db0b15a1f01bdddcb51eaec05131084d6ab9cc49 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Fri, 3 Oct 2025 16:34:44 +0200 Subject: [PATCH 24/35] [ChanLast] emulate op ver agnostic dict for channels last ops --- src/qonnx/custom_op/channels_last/__init__.py | 26 ++++++++++++++----- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/src/qonnx/custom_op/channels_last/__init__.py b/src/qonnx/custom_op/channels_last/__init__.py index 60aac1a4..f5033e9b 100644 --- a/src/qonnx/custom_op/channels_last/__init__.py +++ b/src/qonnx/custom_op/channels_last/__init__.py @@ -2,12 +2,24 @@ from qonnx.custom_op.channels_last.conv import Conv from qonnx.custom_op.channels_last.max_pool import MaxPool -custom_op = dict() +# channels-last ops are defined by the underlying ONNX standard op +# thus, we can define them for any version of the original op +# so we emulate a custom op dictionary that mimics the support for any +# {ChannelsLastOp}_vX instead of hardcoding what versions are supported -custom_op["Conv"] = Conv -custom_op["MaxPool"] = MaxPool -custom_op["BatchNormalization"] = BatchNormalization -custom_op["Conv_v1"] = Conv -custom_op["MaxPool_v1"] = MaxPool -custom_op["BatchNormalization_v1"] = BatchNormalization +class ChannelsLastCustomOpDict: + def __init__(self): + self._custom_ops = {"Conv": Conv, "MaxPool": MaxPool, "BatchNormalization": BatchNormalization} + + def __getitem__(self, key): + base_key = key.split("_v")[0] # Extract base key (e.g., Conv from Conv_v13) + if base_key in self._custom_ops: + return self._custom_ops[base_key] + raise KeyError(f"Channels-last CustomOp '{key}' not found.") + + def keys(self): + return self._custom_ops.keys() + + +custom_op = ChannelsLastCustomOpDict() From 83c53aef80b42ca1528120a932eba64c5d78d3e0 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Fri, 3 Oct 2025 17:04:36 +0200 Subject: [PATCH 25/35] [Core] use isinstance instead of type check for custom_op --- 
src/qonnx/custom_op/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/qonnx/custom_op/registry.py b/src/qonnx/custom_op/registry.py index 825c7566..442089c3 100644 --- a/src/qonnx/custom_op/registry.py +++ b/src/qonnx/custom_op/registry.py @@ -41,7 +41,7 @@ def getCustomOp(node, onnx_opset_version=None, brevitas_exception=True): domain = domain.replace("onnx.brevitas", "qonnx.custom_op.general") try: opset_module = importlib.import_module(domain) - assert type(opset_module.custom_op) is dict, "custom_op dict not found in Python module %s" % domain + assert isinstance(opset_module.custom_op, dict), "custom_op dict not found in Python module %s" % domain if onnx_opset_version is None: inst_wrapper = opset_module.custom_op[op_type] else: From 6bfc2a181b5fad34bf30696a15eef4abbb9f3e06 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Fri, 3 Oct 2025 17:05:20 +0200 Subject: [PATCH 26/35] [ChanLast] derive fake custom_op from dict, ensure domain import --- src/qonnx/custom_op/channels_last/__init__.py | 2 +- src/qonnx/transformation/channels_last.py | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/qonnx/custom_op/channels_last/__init__.py b/src/qonnx/custom_op/channels_last/__init__.py index f5033e9b..02aa0d53 100644 --- a/src/qonnx/custom_op/channels_last/__init__.py +++ b/src/qonnx/custom_op/channels_last/__init__.py @@ -8,7 +8,7 @@ # {ChannelsLastOp}_vX instead of hardcoding what versions are supported -class ChannelsLastCustomOpDict: +class ChannelsLastCustomOpDict(dict): def __init__(self): self._custom_ops = {"Conv": Conv, "MaxPool": MaxPool, "BatchNormalization": BatchNormalization} diff --git a/src/qonnx/transformation/channels_last.py b/src/qonnx/transformation/channels_last.py index 175af058..c352238c 100644 --- a/src/qonnx/transformation/channels_last.py +++ b/src/qonnx/transformation/channels_last.py @@ -270,8 +270,13 @@ def apply(self, model): # Attach to original node n.output[i] = outp_trans_in 
- # Modify domain + # Modify node domain n.domain = "qonnx.custom_op.channels_last" + opset_imports = model.get_opset_imports() + # Ensure channels_last domain is imported in model + if "qonnx.custom_op.channels_last" not in opset_imports: + onnx_opset = opset_imports[""] + model.model.opset_import.append(helper.make_opsetid("qonnx.custom_op.channels_last", onnx_opset)) # Set modified flag graph_modified = True From c9811c5cb7aedea89d1712e00e4d4e64cfe9b1bc Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Sat, 4 Oct 2025 00:19:31 +0200 Subject: [PATCH 27/35] [QuantAvgPool2d] use preferred ONNX opset for exec_node() impl --- src/qonnx/custom_op/general/quantavgpool2d.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/qonnx/custom_op/general/quantavgpool2d.py b/src/qonnx/custom_op/general/quantavgpool2d.py index c0e24071..00617dcf 100644 --- a/src/qonnx/custom_op/general/quantavgpool2d.py +++ b/src/qonnx/custom_op/general/quantavgpool2d.py @@ -33,7 +33,7 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp from qonnx.custom_op.general.maxpoolnhwc import compute_pool_output_dim -from qonnx.util.basic import qonnx_make_model +from qonnx.util.basic import get_preferred_onnx_opset, qonnx_make_model class QuantAvgPool2d(CustomOp): @@ -132,7 +132,7 @@ def execute_node(self, context, graph): outputs=[outp], ) - opset_version = self.onnx_opset_version + opset_version = get_preferred_onnx_opset() opset_imports = [helper.make_opsetid("", opset_version)] onnx_kwargs = {"opset_imports": opset_imports} model_avgpool = qonnx_make_model(graph_avgpool, **onnx_kwargs) From 073985d9c6e93ed1273cb7f089cfeaedbd17a5da Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Sat, 4 Oct 2025 00:30:20 +0200 Subject: [PATCH 28/35] [ChanLast] implement __contains__ for op registration --- src/qonnx/custom_op/channels_last/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/qonnx/custom_op/channels_last/__init__.py 
b/src/qonnx/custom_op/channels_last/__init__.py index 02aa0d53..9ffd4e54 100644 --- a/src/qonnx/custom_op/channels_last/__init__.py +++ b/src/qonnx/custom_op/channels_last/__init__.py @@ -18,6 +18,10 @@ def __getitem__(self, key): return self._custom_ops[base_key] raise KeyError(f"Channels-last CustomOp '{key}' not found.") + def __contains__(self, key): + base_key = key.split("_v")[0] + return base_key in self._custom_ops + def keys(self): return self._custom_ops.keys() From d982e5f2d65a53c30faf34a0ffe52f40344fdcaa Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 16 Oct 2025 14:03:14 +0200 Subject: [PATCH 29/35] [CustomOp] use get_preferred_qonnx_opset as default --- src/qonnx/custom_op/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/qonnx/custom_op/base.py b/src/qonnx/custom_op/base.py index 775d9f95..9cb0de11 100644 --- a/src/qonnx/custom_op/base.py +++ b/src/qonnx/custom_op/base.py @@ -30,7 +30,7 @@ import onnx.numpy_helper as np_helper from abc import ABC, abstractmethod -from qonnx.util.basic import get_by_name, get_preferred_onnx_opset +from qonnx.util.basic import get_by_name, get_preferred_qonnx_opset class CustomOp(ABC): @@ -38,7 +38,7 @@ class CustomOp(ABC): every custom node should have. 
Some as abstract methods, these have to be filled when writing a new custom op node.""" - def __init__(self, onnx_node, onnx_opset_version=get_preferred_onnx_opset()): + def __init__(self, onnx_node, onnx_opset_version=get_preferred_qonnx_opset()): super().__init__() self.onnx_node = onnx_node self.onnx_opset_version = onnx_opset_version From 94cf223b41854b1608c7f087ac3505223d09f74f Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 16 Oct 2025 14:03:39 +0200 Subject: [PATCH 30/35] [Registry] bugfix for getCustomOp inst opset version --- src/qonnx/custom_op/registry.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/qonnx/custom_op/registry.py b/src/qonnx/custom_op/registry.py index 442089c3..258e9ab0 100644 --- a/src/qonnx/custom_op/registry.py +++ b/src/qonnx/custom_op/registry.py @@ -42,6 +42,7 @@ def getCustomOp(node, onnx_opset_version=None, brevitas_exception=True): try: opset_module = importlib.import_module(domain) assert isinstance(opset_module.custom_op, dict), "custom_op dict not found in Python module %s" % domain + found_opset_version = None if onnx_opset_version is None: inst_wrapper = opset_module.custom_op[op_type] else: @@ -49,6 +50,7 @@ def getCustomOp(node, onnx_opset_version=None, brevitas_exception=True): if op_type_with_version in opset_module.custom_op: # priority: if it exists, load the versioned CustomOp wrapper inst_wrapper = opset_module.custom_op[op_type_with_version] + found_opset_version = onnx_opset_version else: # when the exact version match is not found # version handling: use highest available version smaller than requested version @@ -59,11 +61,12 @@ def getCustomOp(node, onnx_opset_version=None, brevitas_exception=True): if suitable_versions: highest_version = max(suitable_versions) inst_wrapper = opset_module.custom_op[f"{op_type}_v{highest_version}"] + found_opset_version = highest_version else: raise Exception( "Op %s version %s not found in custom opset %s" % (op_type, str(onnx_opset_version), 
domain) ) - inst = inst_wrapper(node, onnx_opset_version=onnx_opset_version) + inst = inst_wrapper(node, onnx_opset_version=found_opset_version) return inst except ModuleNotFoundError: raise Exception("Could not load custom opset %s, check your PYTHONPATH" % domain) From 32c0b3c3faab728695575ffc3283233144aef22c Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 16 Oct 2025 14:04:22 +0200 Subject: [PATCH 31/35] [Test] extra opset v checks in test_customop_version --- tests/custom_op/test_customop_version.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/custom_op/test_customop_version.py b/tests/custom_op/test_customop_version.py index 5364df61..3efbde24 100644 --- a/tests/custom_op/test_customop_version.py +++ b/tests/custom_op/test_customop_version.py @@ -119,11 +119,13 @@ def test_customop_version(): # fetching of op version inst = model.get_customop_wrapper(model.graph.node[0]) assert inst.get_nodeattr(f"v{ver}_attr") == ver + assert inst.onnx_opset_version == ver # explicitly specify onnx_opset_version in getCustomOp # note: new code should avoid calling getCustomOp directly like this # and instead use ModelWrapper.get_customop_wrapper inst = getCustomOp(model.graph.node[0], onnx_opset_version=ver) assert inst.get_nodeattr(f"v{ver}_attr") == ver + assert inst.onnx_opset_version == ver # unspecified version getCustomOp should default to v1 handler model = make_vertest_model(1, False) inst = getCustomOp(model.graph.node[0]) @@ -132,3 +134,4 @@ def test_customop_version(): model = make_vertest_model(3, False) inst = getCustomOp(model.graph.node[0], onnx_opset_version=4) assert isinstance(inst, VerTestOp_v3) + assert inst.onnx_opset_version == 3 From 82ed368e2d959a76002c3e8452bfb60166980abd Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 16 Oct 2025 14:39:29 +0200 Subject: [PATCH 32/35] [Trunc] add v1 and v2 versions of the op separately --- src/qonnx/custom_op/general/__init__.py | 8 ++- src/qonnx/custom_op/general/trunc.py | 90 
++++++++++++++++++++++--- 2 files changed, 87 insertions(+), 11 deletions(-) diff --git a/src/qonnx/custom_op/general/__init__.py b/src/qonnx/custom_op/general/__init__.py index e125cbf8..c670c1a3 100644 --- a/src/qonnx/custom_op/general/__init__.py +++ b/src/qonnx/custom_op/general/__init__.py @@ -35,7 +35,7 @@ from qonnx.custom_op.general.maxpoolnhwc import MaxPoolNHWC from qonnx.custom_op.general.multithreshold import MultiThreshold from qonnx.custom_op.general.quantavgpool2d import QuantAvgPool2d -from qonnx.custom_op.general.trunc import Trunc +from qonnx.custom_op.general.trunc import Trunc_v1, Trunc_v2 from qonnx.custom_op.general.xnorpopcount import XnorPopcountMatMul custom_op = dict() @@ -49,7 +49,7 @@ custom_op["Im2Col"] = Im2Col custom_op["IntQuant"] = IntQuant custom_op["Quant"] = IntQuant -custom_op["Trunc"] = Trunc +custom_op["Trunc"] = Trunc_v1 custom_op["BipolarQuant"] = BipolarQuant custom_op["FloatQuant"] = FloatQuant @@ -62,6 +62,8 @@ custom_op["Im2Col_v1"] = Im2Col custom_op["IntQuant_v1"] = IntQuant custom_op["Quant_v1"] = IntQuant -custom_op["Trunc_v1"] = Trunc +custom_op["Trunc_v1"] = Trunc_v1 custom_op["BipolarQuant_v1"] = BipolarQuant custom_op["FloatQuant_v1"] = FloatQuant + +custom_op["Trunc_v2"] = Trunc_v2 diff --git a/src/qonnx/custom_op/general/trunc.py b/src/qonnx/custom_op/general/trunc.py index 36e60b84..7681e16e 100644 --- a/src/qonnx/custom_op/general/trunc.py +++ b/src/qonnx/custom_op/general/trunc.py @@ -32,9 +32,10 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp from qonnx.custom_op.general.quant import max_int, min_int, resolve_rounding_mode +from qonnx.util.basic import get_preferred_qonnx_opset -def trunc(inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode): +def trunc_v2(inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode): # Port of TruncIntQuant class from Brevitas: 
https://bit.ly/3wzIpTR # Scaling @@ -65,18 +66,23 @@ def trunc(inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_sca return y -class Trunc(CustomOp): - """Generic truncation operation for QONNX. Takes four inputs: - - input tensor to truncate - - the scale - - the zero-point - - the truncation scale +class Trunc_v2(CustomOp): + """Generic truncation operation for QONNX. Takes four inputs: + - input tensor to truncate + - the scale + - the zero-point + - the truncation scale - the truncation bit-width The output is a tensor of the same shape as the input tensor, with truncated values. """ + def __init__(self, onnx_node, onnx_opset_version=get_preferred_qonnx_opset()): + super().__init__(onnx_node, onnx_opset_version) + # override any specified opset version, this instance is v2 + self.onnx_opset_version = 2 + def get_nodeattr_types(self): return { # The rounding mode, which is used for the trunc function @@ -107,7 +113,7 @@ def execute_node(self, context, graph): narrow = self.get_nodeattr("narrow") signed = self.get_nodeattr("signed") # calculate output - ret = trunc( + ret = trunc_v2( inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode ) # set context according to output name @@ -115,3 +121,71 @@ def execute_node(self, context, graph): def verify_node(self): pass + + +def trunc_v1(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode): + # Port of TruncIntQuant class from Brevitas: https://bit.ly/3wzIpTR + + # Scaling + y = inp_tensor / scale + y = y + zeropt + # Rounding + y = np.round(y) + # Truncate + trunc_bit_width = input_bit_width - output_bit_width + trunc_scale = 2.0**trunc_bit_width + y = y / trunc_scale + + # To int + rounding_fx = resolve_rounding_mode(rounding_mode) + y = rounding_fx(y) + + # Rescale + y = y - zeropt + y = y * scale + + return y + + +class Trunc_v1(CustomOp): + """Generic truncation operation for QONNX. 
Takes four inputs: + - input tensor to truncate + - the scale + - the zero-point + - the truncation bit-width + + The output is a tensor of the same shape as the input tensor, with truncated + values. + """ + + def get_nodeattr_types(self): + return { + # The rounding mode, which is used for the trunc function + "rounding_mode": ("s", True, "FLOOR"), + } + + def make_shape_compatible_op(self, model): + node = self.onnx_node + return helper.make_node("Identity", [node.input[0]], [node.output[0]]) + + def infer_node_datatype(self, model): + node = self.onnx_node + model.set_tensor_datatype(node.output[0], DataType["FLOAT32"]) + + def execute_node(self, context, graph): + node = self.onnx_node + # save inputs + inp_tensor = context[node.input[0]] + scale = context[node.input[1]] + zeropt = context[node.input[2]] + input_bit_width = context[node.input[3]] + output_bit_width = context[node.input[4]] + # save attributes + rounding_mode = self.get_nodeattr("rounding_mode") + # calculate output + ret = trunc_v1(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode) + # set context according to output name + context[node.output[0]] = ret + + def verify_node(self): + pass From e044c63f67826317b105a5f2fa97b0d61a0c5c68 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 16 Oct 2025 14:59:24 +0200 Subject: [PATCH 33/35] [Trunc] set onnx_opset_version=1 for v1 instance --- src/qonnx/custom_op/general/trunc.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/qonnx/custom_op/general/trunc.py b/src/qonnx/custom_op/general/trunc.py index 7681e16e..10c7e992 100644 --- a/src/qonnx/custom_op/general/trunc.py +++ b/src/qonnx/custom_op/general/trunc.py @@ -158,6 +158,11 @@ class Trunc_v1(CustomOp): values. 
""" + def __init__(self, onnx_node, onnx_opset_version=get_preferred_qonnx_opset()): + super().__init__(onnx_node, onnx_opset_version) + # override any specified opset version, this instance is v1 + self.onnx_opset_version = 1 + def get_nodeattr_types(self): return { # The rounding mode, which is used for the trunc function From 006499b2b2ed0c0ded1f81e7376ca65db83685f0 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Tue, 21 Oct 2025 12:04:53 +0200 Subject: [PATCH 34/35] [Docs] add versioning to all op docs, v2 and v1 for Trunc, overview --- ...bipolar_quant_op.md => bipolarquant_v1.md} | 2 +- .../{floatquant_op.md => floatquant_v1.md} | 2 +- .../{intquant_op.md => intquant_v1.md} | 4 +- docs/qonnx-custom-ops/overview.md | 13 ++ docs/qonnx-custom-ops/trunc_v1.md | 131 ++++++++++++++++++ .../{trunc_op.md => trunc_v2.md} | 0 6 files changed, 148 insertions(+), 4 deletions(-) rename docs/qonnx-custom-ops/{bipolar_quant_op.md => bipolarquant_v1.md} (94%) rename docs/qonnx-custom-ops/{floatquant_op.md => floatquant_v1.md} (98%) rename docs/qonnx-custom-ops/{intquant_op.md => intquant_v1.md} (97%) create mode 100644 docs/qonnx-custom-ops/overview.md create mode 100644 docs/qonnx-custom-ops/trunc_v1.md rename docs/qonnx-custom-ops/{trunc_op.md => trunc_v2.md} (100%) diff --git a/docs/qonnx-custom-ops/bipolar_quant_op.md b/docs/qonnx-custom-ops/bipolarquant_v1.md similarity index 94% rename from docs/qonnx-custom-ops/bipolar_quant_op.md rename to docs/qonnx-custom-ops/bipolarquant_v1.md index 3a70458e..03c0c01e 100644 --- a/docs/qonnx-custom-ops/bipolar_quant_op.md +++ b/docs/qonnx-custom-ops/bipolarquant_v1.md @@ -5,7 +5,7 @@ Additionally, takes one float as input, which define the scaling. #### Version -This operator is not part of the ONNX standard and is not currently versioned. +The description of this operator in this document corresponds to `qonnx.custom_ops.general` opset version 1. 
#### Attributes diff --git a/docs/qonnx-custom-ops/floatquant_op.md b/docs/qonnx-custom-ops/floatquant_v1.md similarity index 98% rename from docs/qonnx-custom-ops/floatquant_op.md rename to docs/qonnx-custom-ops/floatquant_v1.md index fc51b75f..4536194b 100644 --- a/docs/qonnx-custom-ops/floatquant_op.md +++ b/docs/qonnx-custom-ops/floatquant_v1.md @@ -16,7 +16,7 @@ special (symbolic) values. This makes it nontrivial to infer the maximum represe #### Version -This operator is not part of the ONNX standard and is not currently versioned. +The description of this operator in this document corresponds to `qonnx.custom_op.general` opset version 1. #### Attributes diff --git a/docs/qonnx-custom-ops/intquant_op.md b/docs/qonnx-custom-ops/intquant_v1.md similarity index 97% rename from docs/qonnx-custom-ops/intquant_op.md rename to docs/qonnx-custom-ops/intquant_v1.md index fb627efb..4d15c0ec 100644 --- a/docs/qonnx-custom-ops/intquant_op.md +++ b/docs/qonnx-custom-ops/intquant_v1.md @@ -9,11 +9,11 @@ rounding_mode defines how quantized values are rounded. Notes: * This operator was previously named `Quant` but is renamed to `IntQuant` to distinguish it from `FloatQuant`. For a transition period, qonnx will transparently handle `Quant` as `IntQuant` for backwards compatibility reasons, but only `IntQuant` should be used for new models. -* This operator does not work for binary or bipolar quantization, for this purpose the simpler BipolarQuant node exists. +* This operator does not work for binary or bipolar quantization, for this purpose the simpler `BipolarQuant` node exists. #### Version -This operator is not part of the ONNX standard and is not currently versioned. +The description of this operator in this document corresponds to `qonnx.custom_op.general` opset version 1. 
#### Attributes diff --git a/docs/qonnx-custom-ops/overview.md b/docs/qonnx-custom-ops/overview.md new file mode 100644 index 00000000..dfb93c38 --- /dev/null +++ b/docs/qonnx-custom-ops/overview.md @@ -0,0 +1,13 @@ +## Operator Schemas + +This file lists the QONNX custom operators, similar to `Operators.md` for the ONNX standard. +It is manually updated, since QONNX custom operators are relatively few in number. + +### qonnx.custom_op.general + +|**Operator**|**Since version**|| +|-|-|-| +|BipolarQuant|1| +|FloatQuant|1| +|IntQuant|1| +|Trunc|2, 1| diff --git a/docs/qonnx-custom-ops/trunc_v1.md b/docs/qonnx-custom-ops/trunc_v1.md new file mode 100644 index 00000000..04b88443 --- /dev/null +++ b/docs/qonnx-custom-ops/trunc_v1.md @@ -0,0 +1,131 @@ +### **Trunc** + +Truncates the values of one input data (Tensor) at a specified bitwidth and produces one output data (Tensor). +Additionally, takes four float tensors as input, which define the scale, zero-point, input bit-width and output bit-width of the quantization. +The attribute rounding_mode defines how truncated values are rounded. + +#### Version + +The description of this operator in this document corresponds to `qonnx.custom_op.general` opset version 1. + +#### Attributes + 
+
rounding_mode : string (default is "FLOOR")
+
Defines how rounding should be applied during truncation. Currently available modes are: "ROUND", "CEIL" and "FLOOR". Here "ROUND" implies a round-to-even operation. Lowercase variants for the rounding mode string are also supported: "round", "ceil", "floor".
+
+ +#### Inputs + +
+
X (differentiable) : tensor(float32)
+
input tensor to truncate
+
scale : float32
+
The scale factor
+
zeropt : float32
+
The zero-point
+
in_bitwidth : int32
+
The number of bits used at the input of the truncation
+
out_bitwidth : int32
+
The number of bits used at the output of the truncation
+
+ + +#### Outputs + +
+
Y (differentiable) : tensor(float32)
+
Output tensor
+
+ + +#### Examples +
+Trunc + +```python +from onnx import helper +import numpy as np + +# Define node settings and input +x = np.random.randn(100).astype(np.float32)*10. +scale = np.array(1.) +zeropt = np.array(0.) +in_bitwidth = np.array(10) +out_bitwidth = np.array(4) +rounding_mode = "ROUND" + +# Create node +node = helper.make_node( + 'Trunc', + domain='finn.custom_op.general', + inputs=['x', 'scale', 'zeropt', 'in_bitwidth', 'out_bitwidth'], + outputs=['y'], + rounding_mode=rounding_mode, +) + +# Execute the same settings with the reference implementation (trunc) +# See the sample implementation for more details on trunc. +output_ref = trunc(inp_tensor, scale, zeropt, in_bitwidth, out_bitwidth, rounding_mode) + +# Execute node and compare +expect(node, inputs=[x, scale, zeropt, bitwidth], outputs=[output_ref], name='test_trunc') + +``` + +
+ + +#### Sample Implementation + +
+Trunc + +```python +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import numpy as np + +def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode): + # Port of TruncIntQuant class from Brevitas: https://bit.ly/3wzIpTR + + # Scaling + y = inp_tensor / scale + y = y + zeropt + # Rounding + y = np.round(y) + # Truncate + trunc_bit_width = input_bit_width - output_bit_width + trunc_scale = 2.0 ** trunc_bit_width + y = y / trunc_scale + + # To int + rounding_fx = resolve_rounding_mode(rounding_mode) + y = rounding_fx(y) + + # Rescale + y = y - zeropt + y = y * scale + + return y + +def resolve_rounding_mode(mode_string): + """Resolve the rounding mode string of Quant and Trunc ops + to the corresponding numpy functions.""" + if mode_string == "ROUND": + return np.round + elif mode_string == "CEIL": + return np.ceil + elif mode_string == "FLOOR": + return np.floor + else: + raise ValueError(f"Could not resolve rounding mode called: {mode_string}") + +``` + +
diff --git a/docs/qonnx-custom-ops/trunc_op.md b/docs/qonnx-custom-ops/trunc_v2.md similarity index 100% rename from docs/qonnx-custom-ops/trunc_op.md rename to docs/qonnx-custom-ops/trunc_v2.md From 98b738362017d8222db6d4f4710d87366f406d55 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Tue, 21 Oct 2025 14:25:11 +0200 Subject: [PATCH 35/35] update README --- README.md | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index ecf6bd47..108b498c 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ QONNX example -QONNX (Quantized ONNX) introduces several custom operators -- [`IntQuant`](docs/qonnx-custom-ops/intquant_op.md), [`FloatQuant`](docs/qonnx-custom-ops/floatquant_op.md), [`BipolarQuant`](docs/qonnx-custom-ops/bipolar_quant_op.md), and [`Trunc`](docs/qonnx-custom-ops/trunc_op.md) -- in order to represent arbitrary-precision integer and minifloat quantization in ONNX. This enables: +QONNX (Quantized ONNX) introduces several [custom operators](docs/qonnx-custom-ops/overview.md) -- `IntQuant`, `FloatQuant`, `BipolarQuant`, and `Trunc` -- in order to represent arbitrary-precision integer and minifloat quantization in ONNX. This enables: * Representation of binary, ternary, 3-bit, 4-bit, 6-bit or any other integer/fixed-point quantization. * Representation of minifloat quantization with configurable exponent and mantissa bits. * Quantization is an operator itself, and can be applied to any parameter or layer input. 
@@ -29,9 +29,7 @@ This repository contains a set of Python utilities to work with QONNX models, in ### Operator definitions -* [Quant](docs/qonnx-custom-ops/quant_op.md) for 2-to-arbitrary-bit quantization, with scaling and zero-point -* [BipolarQuant](docs/qonnx-custom-ops/bipolar_quant_op.md) for 1-bit (bipolar) quantization, with scaling and zero-point -* [Trunc](docs/qonnx-custom-ops/trunc_op.md) for truncating to a specified number of bits, with scaling and zero-point +Please see the [custom operator overview](docs/qonnx-custom-ops/overview.md) table for more details. ### Installation