Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
e2f11b0
[trunc] Updated Trunc to match the new numerics / export from Brevitas
nickfraser Feb 11, 2025
e59177f
Update trunc_op description.
nickfraser Mar 13, 2025
b791c7b
Minor fixes.
nickfraser Mar 13, 2025
c611ae1
Improved formatting in RTD
nickfraser Mar 13, 2025
05cd37f
Feat (trunc): Switch output_scale, output_zero_point to be inputs ins…
nickfraser Mar 24, 2025
a30aaf1
[trunc] Removed redundant output zero-point input
nickfraser Mar 27, 2025
5ecc349
[trunc] Update docstring
nickfraser Mar 27, 2025
7e9f49d
[docs] Updated definition of the trunc operator
nickfraser Mar 27, 2025
7dfc4b8
[Lint] rerun linter, fix errors
maltanar Sep 25, 2025
7456919
[Core] add get_opset_imports utility fxn to ModelWrapper
maltanar Sep 25, 2025
89396cd
[Core] return dict from ModelWrapper.get_opset_imports
maltanar Sep 25, 2025
db2994f
[Core] add versioned op to getCustomOp with fallback to old style
maltanar Sep 25, 2025
8a2db22
[Core] inrtoduce ModelWrapper.get_customop_wrapper
maltanar Sep 25, 2025
402a580
[Test] add basic unit tests for versioned custom op fetching
maltanar Sep 25, 2025
ad80561
Merge branch 'main' into feature/op_version
maltanar Oct 2, 2025
407fb13
[Test] extend test_customop_version for default v handler
maltanar Oct 2, 2025
feac9f0
[Core] opset ver. fallback for ModelWrapper.get_customop_wrapper
maltanar Oct 2, 2025
89eea4c
[Core] getCustomOp: default v to None, fetch highest available v.
maltanar Oct 2, 2025
ec517b5
[Test] cover newly added opset ver behavior in test_customop_version
maltanar Oct 2, 2025
7406dcf
Merge branch 'main' into feature/op_version
maltanar Oct 2, 2025
aeeff58
[Core, Util] distinguish preferred onnx opset from qonnx opset
maltanar Oct 2, 2025
5801504
[Core] respect selected opsets during execution
maltanar Oct 2, 2025
35b8b12
[CustomOp] alias all qonnx.custom_op.general as v1
maltanar Oct 2, 2025
d190a69
[ChanLast] alias existing channels_last ops as v1
maltanar Oct 2, 2025
5f58f49
[Test] add opsets for test_custom_onnx_exec
maltanar Oct 2, 2025
db0b15a
[ChanLast] emulate op ver agnostic dict for channels last ops
maltanar Oct 3, 2025
83c53ae
[Core] use isinstance instead of type check for custom_op
maltanar Oct 3, 2025
6bfc2a1
[ChanLast] derive fake custom_op from dict, ensure domain import
maltanar Oct 3, 2025
c9811c5
[QuantAvgPool2d] use preferred ONNX opset for exec_node() impl
maltanar Oct 3, 2025
073985d
[ChanLast] implement __contains__ for op registration
maltanar Oct 3, 2025
0260d98
Merge branch 'main' into feature/op_version
maltanar Oct 16, 2025
d982e5f
[CustomOp] use get_preferred_qonnx_opset as default
maltanar Oct 16, 2025
94cf223
[Registry] bugfix for getCustomOp inst opset version
maltanar Oct 16, 2025
32c0b3c
[Test] extra opset v checks in test_customop_version
maltanar Oct 16, 2025
efdc74a
Merge branch 'fix/trunc_avg_pool' of https://github.com/nickfraser/qo…
maltanar Oct 16, 2025
82ed368
[Trunc] add v1 and v2 versions of the op separately
maltanar Oct 16, 2025
e044c63
[Trunc] set onnx_opset_version=1 for v1 instance
maltanar Oct 16, 2025
006499b
[Docs] add versioning to all op docs, v2 and v1 for Trunc, overview
maltanar Oct 21, 2025
98b7383
update README
maltanar Oct 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
<img align="left" src="https://xilinx.github.io/finn/img/TFC_1W2A.onnx.png" alt="QONNX example" style="margin-right: 20px" width="200"/>


QONNX (Quantized ONNX) introduces several custom operators -- [`IntQuant`](docs/qonnx-custom-ops/intquant_op.md), [`FloatQuant`](docs/qonnx-custom-ops/floatquant_op.md), [`BipolarQuant`](docs/qonnx-custom-ops/bipolar_quant_op.md), and [`Trunc`](docs/qonnx-custom-ops/trunc_op.md) -- in order to represent arbitrary-precision integer and minifloat quantization in ONNX. This enables:
QONNX (Quantized ONNX) introduces several [custom operators](docs/qonnx-custom-ops/overview.md) -- `IntQuant`, `FloatQuant`, `BipolarQuant`, and `Trunc` -- in order to represent arbitrary-precision integer and minifloat quantization in ONNX. This enables:
* Representation of binary, ternary, 3-bit, 4-bit, 6-bit or any other integer/fixed-point quantization.
* Representation of minifloat quantization with configurable exponent and mantissa bits.
* Quantization is an operator itself, and can be applied to any parameter or layer input.
Expand All @@ -29,9 +29,7 @@ This repository contains a set of Python utilities to work with QONNX models, in

### Operator definitions

* [Quant](docs/qonnx-custom-ops/quant_op.md) for 2-to-arbitrary-bit quantization, with scaling and zero-point
* [BipolarQuant](docs/qonnx-custom-ops/bipolar_quant_op.md) for 1-bit (bipolar) quantization, with scaling and zero-point
* [Trunc](docs/qonnx-custom-ops/trunc_op.md) for truncating to a specified number of bits, with scaling and zero-point
Please see the [custom operator overview](docs/qonnx-custom-ops/overview.md) table for more details.

### Installation

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Additionally, takes one float as input, which define the scaling.

#### Version

This operator is not part of the ONNX standard and is not currently versioned.
The description of this operator in this document corresponds to `qonnx.custom_ops.general` opset version 1.

#### Attributes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ special (symbolic) values. This makes it nontrivial to infer the maximum represe

#### Version

This operator is not part of the ONNX standard and is not currently versioned.
The description of this operator in this document corresponds to `qonnx.custom_ops.general` opset version 1.

#### Attributes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ rounding_mode defines how quantized values are rounded.

Notes:
* This operator was previously named `Quant` but is renamed to `IntQuant` to distinguish it from `FloatQuant`. For a transition period, qonnx will transparently handle `Quant` as `IntQuant` for backwards compatibility reasons, but only `IntQuant` should be used for new models.
* This operator does not work for binary or bipolar quantization, for this purpose the simpler BipolarQuant node exists.
* This operator does not work for binary or bipolar quantization, for this purpose the simpler `BipolarQuant` node exists.

#### Version

This operator is not part of the ONNX standard and is not currently versioned.
The description of this operator in this document corresponds to `qonnx.custom_ops.general` opset version 1.

#### Attributes

Expand Down
13 changes: 13 additions & 0 deletions docs/qonnx-custom-ops/overview.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
## Operator Schemas

This file lists the QONNX custom operators, similar to `Operators.md` for the ONNX standard.
It is manually updated, since QONNX custom operators are relatively few in number.

### qonnx.custom_op.general

|**Operator**|**Since version**||
|-|-|-|
|<a href="bipolarquant_v1.md">BipolarQuant</a>|<a href="bipolarquant_v1.md">1</a>|
|<a href="floatquant_v1.md">FloatQuant</a>|<a href="floatquant_v1.md">1</a>|
|<a href="intquant_v1.md">IntQuant</a>|<a href="intquant_v1.md">1</a>|
|<a href="trunc_v2.md">Trunc</a>|<a href="trunc_v2.md">2</a>, <a href="trunc_v1.md">1</a>|
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ The attribute rounding_mode defines how truncated values are rounded.

#### Version

This operator is not part of the ONNX standard and is not currently versioned.
The description of this operator in this document corresponds to `qonnx.custom_ops.general` opset version 1.

#### Attributes

Expand Down
144 changes: 144 additions & 0 deletions docs/qonnx-custom-ops/trunc_v2.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
### <a name="Trunc"></a><a name="abs">**Trunc**</a>

Truncates the values of one input data (Tensor<T>) at a specified bitwidth and produces one output data (Tensor<T>).
Additionally, takes four float tensors as input, which define the scale, zero-point, input bit-width and output bit-width of the quantization.
The attribute rounding_mode defines how truncated values are rounded.

#### Version

This operator is not part of the ONNX standard.
The description of this operator in this document corresponds to `qonnx.custom_ops.general` opset version 2.

#### Attributes

<dl>
<dt><tt>rounding_mode</tt> : string (default is "FLOOR")</dt>
<dd>Defines how rounding should be applied during truncation. Currently available modes are: "ROUND", "CEIL" and "FLOOR". Here "ROUND" implies a round-to-even operation. Lowercase variants for the rounding mode string are also supported: "round", "ceil", "floor".</dd>
<dt><tt>signed</tt> : int (default is 1)</dt>
<dd>Defines if the quantization includes a signed bit. E.g. at 8b unsigned=[0, 255] vs signed=[-128, 127].</dd>
<dt><tt>narrow</tt> : int (default is 0)</dt>
<dd>Defines if the value range should be interpreted as narrow, when signed=1. E.g. at 8b regular=[-128, 127] vs narrow=[-127, 127].</dd>
</dl>

#### Inputs

<dl>
<dt><tt>X</tt> (differentiable) : tensor(float32)</dt>
<dd>input tensor to truncate</dd>
<dt><tt>scale</tt> : float32</dt>
<dd>The scale factor at the input of the truncation</dd>
<dt><tt>zeropt</tt> : float32</dt>
<dd>The zero-point at the input of the truncation</dd>
<dt><tt>in_bitwidth</tt> : int32</dt>
<dd>The number of bits used at the input of the truncation</dd>
<dt><tt>out_scale</tt> : float32</dt>
<dd>The scale factor of the output of the truncation</dd>
<dt><tt>out_bitwidth</tt> : int32</dt>
<dd>The number of bits used at the output of the truncation</dd>
</dl>


#### Outputs

<dl>
<dt><tt>Y</tt> (differentiable) : tensor(float32)</dt>
<dd>Output tensor</dd>
</dl>


#### Examples
<details>
<summary>Trunc</summary>

```python
from onnx import helper
import numpy as np

# Define node settings and input
x = np.random.randn(100).astype(np.float32)*10.
scale = np.array(1.)
zeropt = np.array(0.)
in_bitwidth = np.array(10)
out_bitwidth = np.array(4)
rounding_mode = "ROUND"

# Create node
node = helper.make_node(
'Trunc',
domain='finn.custom_op.general',
inputs=['x', 'scale', 'zeropt', 'in_bitwidth', 'out_bitwidth'],
outputs=['y'],
rounding_mode=rounding_mode,
)

# Execute the same settings with the reference implementation (trunc)
# See the sample implementation for more details on trunc.
output_ref = trunc(inp_tensor, scale, zeropt, in_bitwidth, out_bitwidth, rounding_mode)

# Execute node and compare
expect(node, inputs=[x, scale, zeropt, bitwidth], outputs=[output_ref], name='test_trunc')

```

</details>


#### Sample Implementation

<details>
<summary>Trunc</summary>

```python
# SPDX-License-Identifier: Apache-2.0

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import numpy as np

def trunc(inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode):

# Scaling
y = inp_tensor / scale
y = y + zeropt
# Rounding
y = np.round(y)
# Rescale
trunc_scale = 2 ** np.round(
np.log2(output_scale / scale)
) # Trunc scale should be a power-of-two - ensure that is the case
y = y / trunc_scale

# Clamping
min_int_val = min_int(signed, narrow, output_bit_width)
max_int_val = max_int(signed, narrow, output_bit_width)
y = np.where(y > max_int_val, max_int_val.astype(y.dtype), y)
y = np.where(y < min_int_val, min_int_val.astype(y.dtype), y)
# To int (truncate)
rounding_fx = resolve_rounding_mode(rounding_mode)
y = rounding_fx(y)

# Rescale
output_zeropt = zeropt / trunc_scale # Rescale zero-point
y = y - output_zeropt
y = y * output_scale

return y

def resolve_rounding_mode(mode_string):
"""Resolve the rounding mode string of Quant and Trunc ops
to the corresponding numpy functions."""
if mode_string == "ROUND":
return np.round
elif mode_string == "CEIL":
return np.ceil
elif mode_string == "FLOOR":
return np.floor
else:
raise ValueError(f"Could not resolve rounding mode called: {mode_string}")

```

</details>
3 changes: 1 addition & 2 deletions src/qonnx/core/execute_custom_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,9 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import qonnx.custom_op.registry as registry
from qonnx.util.basic import get_preferred_onnx_opset


def execute_custom_node(node, context, graph, onnx_opset_version=get_preferred_onnx_opset()):
def execute_custom_node(node, context, graph, onnx_opset_version):
"""Call custom implementation to execute a single custom node.
Input/output provided via context."""
op_type = node.op_type
Expand Down
22 changes: 22 additions & 0 deletions src/qonnx/core/modelwrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import qonnx.util.basic as util
import qonnx.util.onnx as onnxutil
from qonnx.core.datatype import DataType
from qonnx.custom_op.registry import getCustomOp
from qonnx.transformation.double_to_single_float import DoubleToSingleFloat
from qonnx.transformation.general import (
RemoveStaticGraphInputs,
Expand Down Expand Up @@ -737,3 +738,24 @@ def set_tensor_sparsity(self, tensor_name, sparsity_dict):
qa.tensor_name = tensor_name
qa.quant_parameter_tensor_names.append(dt)
qnt_annotations.append(qa)

def get_opset_imports(self):
"""Returns a list of imported opsets as a {domain, version} dictionary."""
return {opset.domain: opset.version for opset in self._model_proto.opset_import}

def get_customop_wrapper(self, node, fallback_customop_version=util.get_preferred_qonnx_opset()):
"""Return CustomOp instance for given node, respecting the
imported opset version in the model protobuf. If the node's domain
is not found in the model's opset imports, fallback_customop_version
will be used."""
opset_imports = self.get_opset_imports()
try:
opset_import = opset_imports[node.domain]
return getCustomOp(node, onnx_opset_version=opset_import)
except KeyError:
# domain not found in imports, use fallback version
warnings.warn(
f"Domain {node.domain} not found in model opset imports, "
f"using fallback_customop_version={fallback_customop_version}"
)
return getCustomOp(node, onnx_opset_version=fallback_customop_version)
12 changes: 8 additions & 4 deletions src/qonnx/core/onnx_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@
import qonnx.analysis.topology as ta
import qonnx.core.execute_custom_node as ex_cu_node
from qonnx.util.basic import (
get_preferred_onnx_opset,
get_preferred_qonnx_opset,
get_sanitize_quant_tensors,
is_finn_op,
qonnx_make_model,
sanitize_quant_values,
)


def execute_node(node, context, graph, return_full_exec_context=False, opset_version=get_preferred_onnx_opset()):
def execute_node(node, context, graph, opset_version, return_full_exec_context=False):
"""Executes a single node by using onnxruntime or with a custom function.

Input/output provided via context."""
Expand Down Expand Up @@ -158,7 +158,7 @@ def execute_onnx(model, input_dict, return_full_exec_context=False, start_node=N
model_exec_mode = model.get_metadata_prop("exec_mode")
if (model_exec_mode is None) or (model_exec_mode == ""):
# extract opset version for node-by-node execution
opset_version = model.model.opset_import[0].version
opset_imports = model.get_opset_imports()
# execute the model node by node
# we can simply walk down the list since the ONNX spec guarantees that it is
# topologically sorted
Expand All @@ -176,7 +176,11 @@ def execute_onnx(model, input_dict, return_full_exec_context=False, start_node=N
if get_sanitize_quant_tensors() != 0:
# round input values to match quantization annotation
execution_context = sanitize_quant_values(model, node.input, execution_context)
execute_node(node, execution_context, graph, return_full_exec_context, opset_version)
if node.domain in opset_imports:
opset_version = opset_imports[node.domain]
else:
opset_version = get_preferred_qonnx_opset()
execute_node(node, execution_context, graph, opset_version, return_full_exec_context)
if get_sanitize_quant_tensors() != 0:
# round output values to quantization annotation
execution_context = sanitize_quant_values(model, node.output, execution_context)
Expand Down
4 changes: 2 additions & 2 deletions src/qonnx/custom_op/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,15 @@
import onnx.numpy_helper as np_helper
from abc import ABC, abstractmethod

from qonnx.util.basic import get_by_name, get_preferred_onnx_opset
from qonnx.util.basic import get_by_name, get_preferred_qonnx_opset


class CustomOp(ABC):
"""CustomOp class all custom op nodes are based on. Contains different functions
every custom node should have. Some as abstract methods, these have to be
filled when writing a new custom op node."""

def __init__(self, onnx_node, onnx_opset_version=get_preferred_onnx_opset()):
def __init__(self, onnx_node, onnx_opset_version=get_preferred_qonnx_opset()):
super().__init__()
self.onnx_node = onnx_node
self.onnx_opset_version = onnx_opset_version
Expand Down
28 changes: 24 additions & 4 deletions src/qonnx/custom_op/channels_last/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,28 @@
from qonnx.custom_op.channels_last.conv import Conv
from qonnx.custom_op.channels_last.max_pool import MaxPool

custom_op = dict()
# channels-last ops are defined by the underlying ONNX standard op
# thus, we can define them for any version of the original op
# so we emulate a custom op dictionary that mimics the support for any
# {ChannelsLastOp}_vX instead of hardcoding what versions are supported

custom_op["Conv"] = Conv
custom_op["MaxPool"] = MaxPool
custom_op["BatchNormalization"] = BatchNormalization

class ChannelsLastCustomOpDict(dict):
def __init__(self):
self._custom_ops = {"Conv": Conv, "MaxPool": MaxPool, "BatchNormalization": BatchNormalization}

def __getitem__(self, key):
base_key = key.split("_v")[0] # Extract base key (e.g., Conv from Conv_v13)
if base_key in self._custom_ops:
return self._custom_ops[base_key]
raise KeyError(f"Channels-last CustomOp '{key}' not found.")

def __contains__(self, key):
base_key = key.split("_v")[0]
return base_key in self._custom_ops

def keys(self):
return self._custom_ops.keys()


custom_op = ChannelsLastCustomOpDict()
19 changes: 17 additions & 2 deletions src/qonnx/custom_op/general/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from qonnx.custom_op.general.maxpoolnhwc import MaxPoolNHWC
from qonnx.custom_op.general.multithreshold import MultiThreshold
from qonnx.custom_op.general.quantavgpool2d import QuantAvgPool2d
from qonnx.custom_op.general.trunc import Trunc
from qonnx.custom_op.general.trunc import Trunc_v1, Trunc_v2
from qonnx.custom_op.general.xnorpopcount import XnorPopcountMatMul

custom_op = dict()
Expand All @@ -49,6 +49,21 @@
custom_op["Im2Col"] = Im2Col
custom_op["IntQuant"] = IntQuant
custom_op["Quant"] = IntQuant
custom_op["Trunc"] = Trunc
custom_op["Trunc"] = Trunc_v1
custom_op["BipolarQuant"] = BipolarQuant
custom_op["FloatQuant"] = FloatQuant

custom_op["DebugMarker_v1"] = DebugMarker
custom_op["QuantAvgPool2d_v1"] = QuantAvgPool2d
custom_op["MaxPoolNHWC_v1"] = MaxPoolNHWC
custom_op["GenericPartition_v1"] = GenericPartition
custom_op["MultiThreshold_v1"] = MultiThreshold
custom_op["XnorPopcountMatMul_v1"] = XnorPopcountMatMul
custom_op["Im2Col_v1"] = Im2Col
custom_op["IntQuant_v1"] = IntQuant
custom_op["Quant_v1"] = IntQuant
custom_op["Trunc_v1"] = Trunc_v1
custom_op["BipolarQuant_v1"] = BipolarQuant
custom_op["FloatQuant_v1"] = FloatQuant

custom_op["Trunc_v2"] = Trunc_v2
Loading
Loading