281 changes: 149 additions & 132 deletions src/qonnx/core/datatype.py

Large diffs are not rendered by default.

332 changes: 230 additions & 102 deletions src/qonnx/core/modelwrapper.py

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions src/qonnx/core/onnx_exec.py
@@ -87,11 +87,14 @@ def execute_node(node, context, graph, opset_version, return_full_exec_context=F
outp = node.output[output_ind]

# retrieve the index of that name in node_outputs
list_ind = None
for i in range(len(node_outputs)):
if outp == node_outputs[i].name:
list_ind = i

# use that index to index output_list
if list_ind is None:
raise Exception("Output %s not found in node outputs." % outp)
if output_list[list_ind].shape != context[outp].shape:
warnings.warn(
"""Output shapes disagree after node %s execution:
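
As an aside, the change above is what makes a missing output name fail loudly: list_ind starts as None and an exception is raised when no entry in node_outputs matches, instead of indexing output_list with an undefined or stale value. A minimal standalone sketch of that lookup pattern (find_output_index is a hypothetical helper, not part of onnx_exec.py):

def find_output_index(outp_name, node_outputs):
    # map an output tensor name to its position in the node's output list
    list_ind = None
    for i, out in enumerate(node_outputs):
        if out.name == outp_name:
            list_ind = i
    if list_ind is None:
        raise Exception("Output %s not found in node outputs." % outp_name)
    return list_ind
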
81 changes: 56 additions & 25 deletions src/qonnx/custom_op/base.py
@@ -30,9 +30,17 @@
import onnx.helper as helper
import onnx.numpy_helper as np_helper
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import TYPE_CHECKING, Sequence, cast

import numpy.typing as npt
from onnx import NodeProto, GraphProto, TensorProto

from qonnx.util.basic import get_by_name, get_preferred_qonnx_opset

if TYPE_CHECKING:
from qonnx.core.modelwrapper import ModelWrapper


class CustomOp(ABC):
"""CustomOp class all custom op nodes are based on. Contains different functions
@@ -59,15 +67,21 @@ class IntQuant_v4(CustomOp):
pass # Version 4, covers opset v4+
"""

def __init__(self, onnx_node, onnx_opset_version=get_preferred_qonnx_opset()):
def __init__(
self,
onnx_node: NodeProto,
onnx_opset_version: int = get_preferred_qonnx_opset(),
) -> None:
super().__init__()
self.onnx_node = onnx_node
self.onnx_opset_version = onnx_opset_version
self.onnx_node: NodeProto = onnx_node
self.onnx_opset_version: int = onnx_opset_version

def get_nodeattr_def(self, name):
def get_nodeattr_def(
self, name: str
) -> tuple[str, bool, int | float | str | bool | npt.NDArray | list[str | int | float], set | None]:
"""Return 4-tuple (dtype, required, default_val, allowed_values) for attribute
with name. allowed_values will be None if not specified."""
allowed_values = None
allowed_values: set | None = None
attrdef = self.get_nodeattr_types()[name]
if len(attrdef) == 3:
(dtype, req, def_val) = attrdef
@@ -79,11 +93,15 @@ def get_nodeattr_def(self, name):
)
return (dtype, req, def_val, allowed_values)

def get_nodeattr_allowed_values(self, name):
def get_nodeattr_allowed_values(
self, name: str
) -> str | bool | int | float | npt.NDArray | list[str | int | float] | set | None:
"Return set of allowed values for given attribute, None if not specified."
return self.get_nodeattr_def(name)[3]

def get_nodeattr(self, name):
def get_nodeattr(
self, name: str
) -> int | float | str | bool | npt.NDArray | list[str | int | float] | None:
"""Get a node attribute by name. Data is stored inside the ONNX node's
AttributeProto container. Attribute must be part of get_nodeattr_types.
Default value is returned if attribute is not set."""
@@ -128,9 +146,13 @@ def get_nodeattr(self, name):
# not set, return default value
return def_val
except KeyError:
raise AttributeError("Op has no such attribute: " + name)
raise AttributeError(
f"{self.onnx_node.name} has no such attribute: " + name
)

def set_nodeattr(self, name, value):
def set_nodeattr(
self, name: str, value: int | float | str | bool | npt.NDArray | list[str | int | float] | None
) -> None:
"""Set a node attribute by name. Data is stored inside the ONNX node's
AttributeProto container. Attribute must be part of get_nodeattr_types."""
try:
@@ -142,7 +164,7 @@ def set_nodeattr(self, name, value):
% (str(name), str(value), str(allowed_values))
)
attr = get_by_name(self.onnx_node.attribute, name)

tensor_value : TensorProto | None = None
# Verify value type matches dtype before setting/converting
if dtype == "i":
if not isinstance(value, int):
@@ -185,23 +207,25 @@
f"Attribute {name} expects numpy array, got {type(value)}"
)
# Convert numpy array to TensorProto
value = np_helper.from_array(value)

tensor_value = np_helper.from_array(cast(npt.NDArray, value))
if attr is not None:
# dtype indicates which ONNX Attribute member to use
# (such as i, f, s...)
if dtype == "s":
# encode string attributes
value = value.encode("utf-8")
attr.__setattr__(dtype, value)
val = cast(str, value).encode("utf-8")
attr.__setattr__(dtype, val)
elif dtype == "strings":
attr.strings[:] = [x.encode("utf-8") for x in value]
attr.strings[:] = [
x.encode("utf-8") for x in cast(list[str], value)
]
elif dtype == "floats": # list of floats
attr.floats[:] = value
attr.floats[:] = cast(list[float], value)
elif dtype == "ints": # list of integers
attr.ints[:] = value
attr.ints[:] = cast(list[int], value)
elif dtype == "t": # single tensor
attr.t.CopyFrom(value)
assert tensor_value is not None
attr.t.CopyFrom(tensor_value)
elif dtype in ["tensors", "graphs", "sparse_tensors"]:
# untested / unsupported attribute types
# add testcases & appropriate getters before enabling
@@ -211,12 +235,13 @@
attr.__setattr__(dtype, value)
else:
# not set, create and insert AttributeProto
attr_proto = helper.make_attribute(name, value)
attr_value = tensor_value if tensor_value is not None else value
attr_proto = helper.make_attribute(name, attr_value)
self.onnx_node.attribute.append(attr_proto)
except KeyError:
raise AttributeError("Op has no such attribute: " + name)
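
For orientation, the attribute accessors round-trip values through the node's AttributeProto: ints land in .i, strings are UTF-8 encoded into .s, and numpy arrays are now converted to a TensorProto before being stored in .t. A rough usage sketch, assuming a registered custom op node with these (hypothetical) attribute names:

import numpy as np
from qonnx.custom_op.registry import getCustomOp

op = getCustomOp(node)  # node is an onnx.NodeProto for a registered custom op
op.set_nodeattr("num_bits", 8)                     # stored in AttributeProto.i
op.set_nodeattr("rounding_mode", "ROUND")          # stored in AttributeProto.s, UTF-8 encoded
op.set_nodeattr("scale", np.asarray([0.5, 0.25]))  # converted to TensorProto, stored in .t
assert op.get_nodeattr("num_bits") == 8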

def make_const_shape_op(self, shape):
def make_const_shape_op(self, shape: Sequence[int] | npt.NDArray) -> NodeProto:
"""Return an ONNX node that generates the desired output shape for
shape inference."""
return helper.make_node(
@@ -230,7 +255,13 @@ def make_const_shape_op(self, shape):
)

@abstractmethod
def get_nodeattr_types(self):
def get_nodeattr_types(
self,
) -> Mapping[
str,
tuple[str, bool, int | float | str | bool | npt.NDArray | list[str | int | float]]
| tuple[str, bool, int | float | str | bool | npt.NDArray | list[str | int | float], set | None],
]:
"""Returns a dict of permitted attributes for node, where:
ret_dict[attribute_name] = (dtype, require, default_value, <allowed_values>)
- dtype indicates which member of the ONNX AttributeProto
@@ -245,25 +276,25 @@ def get_nodeattr_types(self):
pass
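
To make the tuple format concrete, a subclass would declare something like the following; 3-tuples omit allowed_values, while 4-tuples restrict the attribute to a fixed set (all names below are illustrative, not an existing op):

import numpy as np

def get_nodeattr_types(self):  # method of a hypothetical CustomOp subclass
    return {
        # (dtype, required, default_value)
        "num_bits": ("i", True, 8),
        # (dtype, required, default_value, allowed_values)
        "rounding_mode": ("s", False, "ROUND", {"ROUND", "FLOOR", "CEIL"}),
        # tensor attribute, stored as a TensorProto by set_nodeattr
        "scale": ("t", False, np.asarray([1.0], dtype=np.float32)),
    }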

@abstractmethod
def make_shape_compatible_op(self, model):
def make_shape_compatible_op(self, model: "ModelWrapper") -> NodeProto:
"""Returns a standard ONNX op which is compatible with this CustomOp
for performing shape inference."""
pass

@abstractmethod
def infer_node_datatype(self, model):
def infer_node_datatype(self, model: "ModelWrapper") -> None:
"""Set the DataType annotations corresponding to the outputs of this
node."""
pass

@abstractmethod
def execute_node(self, context, graph):
def execute_node(self, context: dict[str, npt.NDArray], graph: GraphProto) -> None:
"""Execute this CustomOp instance, given the execution context and
ONNX graph."""
pass

@abstractmethod
def verify_node(self):
def verify_node(self) -> None:
"""Verifies that all attributes the node needs are there and
that particular attributes are set correctly. Also checks if
the number of inputs is equal to the expected number."""
4 changes: 2 additions & 2 deletions src/qonnx/custom_op/registry.py
@@ -31,7 +31,7 @@
import warnings
from threading import RLock
from typing import Dict, List, Optional, Tuple, Type

from onnx import NodeProto
from qonnx.custom_op.base import CustomOp

# Nested registry for O(1) lookups: domain -> op_type -> version -> CustomOp class
@@ -320,7 +320,7 @@ def add_op_to_domain(domain: str, op_class: Type[CustomOp]) -> None:
_OP_REGISTRY[domain][op_type][op_version] = op_class


def getCustomOp(node, onnx_opset_version=None):
def getCustomOp(node: NodeProto, onnx_opset_version: int | None = None) -> CustomOp:
"""Get a custom op instance for an ONNX node.

Uses "since version" semantics: selects highest version <= requested opset.
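
The "since version" rule mirrors ONNX operator versioning: among all registered versions of an op, the highest one that does not exceed the requested opset is chosen. A toy sketch of that selection rule (not the registry's actual internals):

def pick_since_version(registered_versions, requested_opset):
    # registered_versions: version -> CustomOp class, e.g. {1: IntQuant_v1, 4: IntQuant_v4}
    candidates = [v for v in registered_versions if v <= requested_opset]
    if not candidates:
        raise KeyError("no registered version <= opset %d" % requested_opset)
    return registered_versions[max(candidates)]

With versions 1 and 4 registered, requesting opset 3 resolves to version 1, and opset 5 resolves to version 4.
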
19 changes: 13 additions & 6 deletions src/qonnx/transformation/base.py
@@ -47,9 +47,16 @@
manually re-apply the transform.
"""

from __future__ import annotations

import copy
import multiprocessing as mp
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from onnx import NodeProto
from qonnx.core.modelwrapper import ModelWrapper

from qonnx.util.basic import get_num_default_workers

@@ -58,11 +65,11 @@ class Transformation(ABC):
"""Transformation class all transformations are based on. Contains only
abstract method apply() every transformation has to fill."""

def __init__(self):
def __init__(self) -> None:
super().__init__()

@abstractmethod
def apply(self, model):
def apply(self, model: ModelWrapper) -> tuple[ModelWrapper, bool]:
pass
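
The return type added above spells out the contract: apply() hands back the (possibly modified) model together with a flag telling the caller whether the transformation should be applied again before moving on. A minimal sketch of a subclass under that contract (the cleanup performed here is purely illustrative):

from qonnx.transformation.base import Transformation


class ClearDocStrings(Transformation):
    """Toy transformation: removes doc_string text from every node."""

    def apply(self, model):
        graph_modified = False
        for node in model.graph.node:
            if node.doc_string:
                node.doc_string = ""
                graph_modified = True
        # False on the second pass signals that a fixed point has been reached
        return (model, graph_modified)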


@@ -83,7 +90,7 @@ class NodeLocalTransformation(Transformation):
* (any other int>0): set number of parallel workers
"""

def __init__(self, num_workers=None):
def __init__(self, num_workers: int | None = None) -> None:
super().__init__()
if num_workers is None:
self._num_workers = get_num_default_workers()
@@ -94,15 +101,15 @@ def __init__(self, num_workers=None):
self._num_workers = mp.cpu_count()

@abstractmethod
def applyNodeLocal(self, node):
def applyNodeLocal(self, node: NodeProto) -> tuple[NodeProto, bool]:
pass

def apply(self, model):
def apply(self, model: ModelWrapper) -> tuple[ModelWrapper, bool]:
# make a detached copy of the input model that applyNodeLocal
# can use for read-only access
self.ref_input_model = copy.deepcopy(model)
# Remove old nodes from the current model
old_nodes = []
old_nodes: list[NodeProto] = []
for i in range(len(model.graph.node)):
old_nodes.append(model.graph.node.pop())

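Because applyNodeLocal() only sees one node at a time (with self.ref_input_model available for read-only lookups), these transformations can be distributed over a worker pool. A rough sketch of a subclass and its invocation, using a hypothetical op_type and tag:

from qonnx.transformation.base import NodeLocalTransformation


class TagQuantNodes(NodeLocalTransformation):
    """Toy node-local transformation: marks Quant nodes via their doc_string."""

    def applyNodeLocal(self, node):
        node_modified = False
        if node.op_type == "Quant" and node.doc_string != "quant-tagged":
            node.doc_string = "quant-tagged"
            node_modified = True
        # apply() collects the returned nodes back into the model
        return (node, node_modified)


# model: a qonnx.core.modelwrapper.ModelWrapper instance
# num_workers=0 falls back to mp.cpu_count(), i.e. all available cores
model = model.transform(TagQuantNodes(num_workers=0))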