Skip to content

Commit 329ed94

Browse files
authored
Merge pull request #226 from LinusJungemann/refactor/type-hints
Add type hints to base classes to fix Python's automatic type hinting for FINN/FINN+
2 parents 19b058a + 1828eb5 commit 329ed94

File tree

9 files changed

+506
-311
lines changed

9 files changed

+506
-311
lines changed

src/qonnx/core/datatype.py

Lines changed: 149 additions & 132 deletions
Large diffs are not rendered by default.

src/qonnx/core/modelwrapper.py

Lines changed: 230 additions & 102 deletions
Large diffs are not rendered by default.

src/qonnx/core/onnx_exec.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,14 @@ def execute_node(node, context, graph, opset_version, return_full_exec_context=F
8787
outp = node.output[output_ind]
8888

8989
# retrieve the index of that name in node_outputs
90+
list_ind = None
9091
for i in range(len(node_outputs)):
9192
if outp == node_outputs[i].name:
9293
list_ind = i
9394

9495
# use that index to index output_list
96+
if list_ind is None:
97+
raise Exception("Output %s not found in node outputs." % outp)
9598
if output_list[list_ind].shape != context[outp].shape:
9699
warnings.warn(
97100
"""Output shapes disagree after node %s execution:

src/qonnx/custom_op/base.py

Lines changed: 56 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,17 @@
3030
import onnx.helper as helper
3131
import onnx.numpy_helper as np_helper
3232
from abc import ABC, abstractmethod
33+
from collections.abc import Mapping
34+
from typing import TYPE_CHECKING, Sequence, cast
35+
36+
import numpy.typing as npt
37+
from onnx import NodeProto, GraphProto, TensorProto
3338

3439
from qonnx.util.basic import get_by_name, get_preferred_qonnx_opset
3540

41+
if TYPE_CHECKING:
42+
from qonnx.core.modelwrapper import ModelWrapper
43+
3644

3745
class CustomOp(ABC):
3846
"""CustomOp class all custom op nodes are based on. Contains different functions
@@ -59,15 +67,21 @@ class IntQuant_v4(CustomOp):
5967
pass # Version 4, covers opset v4+
6068
"""
6169

62-
def __init__(self, onnx_node, onnx_opset_version=get_preferred_qonnx_opset()):
70+
def __init__(
71+
self,
72+
onnx_node: NodeProto,
73+
onnx_opset_version: int = get_preferred_qonnx_opset(),
74+
) -> None:
6375
super().__init__()
64-
self.onnx_node = onnx_node
65-
self.onnx_opset_version = onnx_opset_version
76+
self.onnx_node: NodeProto = onnx_node
77+
self.onnx_opset_version: int = onnx_opset_version
6678

67-
def get_nodeattr_def(self, name):
79+
def get_nodeattr_def(
80+
self, name: str
81+
) -> tuple[str, bool, int | float | str | bool | npt.NDArray | list[str | int | float], set | None]:
6882
"""Return 4-tuple (dtype, required, default_val, allowed_values) for attribute
6983
with name. allowed_values will be None if not specified."""
70-
allowed_values = None
84+
allowed_values: set | None = None
7185
attrdef = self.get_nodeattr_types()[name]
7286
if len(attrdef) == 3:
7387
(dtype, req, def_val) = attrdef
@@ -79,11 +93,15 @@ def get_nodeattr_def(self, name):
7993
)
8094
return (dtype, req, def_val, allowed_values)
8195

82-
def get_nodeattr_allowed_values(self, name):
96+
def get_nodeattr_allowed_values(
97+
self, name: str
98+
) -> str | bool | int | float | npt.NDArray | list[str | int | float] | set | None:
8399
"Return set of allowed values for given attribute, None if not specified."
84100
return self.get_nodeattr_def(name)[3]
85101

86-
def get_nodeattr(self, name):
102+
def get_nodeattr(
103+
self, name: str
104+
) -> int | float | str | bool | npt.NDArray | list[str | int | float] | None:
87105
"""Get a node attribute by name. Data is stored inside the ONNX node's
88106
AttributeProto container. Attribute must be part of get_nodeattr_types.
89107
Default value is returned if attribute is not set."""
@@ -128,9 +146,13 @@ def get_nodeattr(self, name):
128146
# not set, return default value
129147
return def_val
130148
except KeyError:
131-
raise AttributeError("Op has no such attribute: " + name)
149+
raise AttributeError(
150+
f"{self.onnx_node.name} has no such attribute: " + name
151+
)
132152

133-
def set_nodeattr(self, name, value):
153+
def set_nodeattr(
154+
self, name: str, value: int | float | str | bool | npt.NDArray | list[str | int | float] | None
155+
) -> None:
134156
"""Set a node attribute by name. Data is stored inside the ONNX node's
135157
AttributeProto container. Attribute must be part of get_nodeattr_types."""
136158
try:
@@ -142,7 +164,7 @@ def set_nodeattr(self, name, value):
142164
% (str(name), str(value), str(allowed_values))
143165
)
144166
attr = get_by_name(self.onnx_node.attribute, name)
145-
167+
tensor_value : TensorProto | None = None
146168
# Verify value type matches dtype before setting/converting
147169
if dtype == "i":
148170
if not isinstance(value, int):
@@ -185,23 +207,25 @@ def set_nodeattr(self, name, value):
185207
f"Attribute {name} expects numpy array, got {type(value)}"
186208
)
187209
# Convert numpy array to TensorProto
188-
value = np_helper.from_array(value)
189-
210+
tensor_value = np_helper.from_array(cast(npt.NDArray, value))
190211
if attr is not None:
191212
# dtype indicates which ONNX Attribute member to use
192213
# (such as i, f, s...)
193214
if dtype == "s":
194215
# encode string attributes
195-
value = value.encode("utf-8")
196-
attr.__setattr__(dtype, value)
216+
val = cast(str, value).encode("utf-8")
217+
attr.__setattr__(dtype, val)
197218
elif dtype == "strings":
198-
attr.strings[:] = [x.encode("utf-8") for x in value]
219+
attr.strings[:] = [
220+
x.encode("utf-8") for x in cast(list[str], value)
221+
]
199222
elif dtype == "floats": # list of floats
200-
attr.floats[:] = value
223+
attr.floats[:] = cast(list[float], value)
201224
elif dtype == "ints": # list of integers
202-
attr.ints[:] = value
225+
attr.ints[:] = cast(list[int], value)
203226
elif dtype == "t": # single tensor
204-
attr.t.CopyFrom(value)
227+
assert tensor_value is not None
228+
attr.t.CopyFrom(tensor_value)
205229
elif dtype in ["tensors", "graphs", "sparse_tensors"]:
206230
# untested / unsupported attribute types
207231
# add testcases & appropriate getters before enabling
@@ -211,12 +235,13 @@ def set_nodeattr(self, name, value):
211235
attr.__setattr__(dtype, value)
212236
else:
213237
# not set, create and insert AttributeProto
214-
attr_proto = helper.make_attribute(name, value)
238+
attr_value = tensor_value if tensor_value is not None else value
239+
attr_proto = helper.make_attribute(name, attr_value)
215240
self.onnx_node.attribute.append(attr_proto)
216241
except KeyError:
217242
raise AttributeError("Op has no such attribute: " + name)
218243

219-
def make_const_shape_op(self, shape):
244+
def make_const_shape_op(self, shape: Sequence[int] | npt.NDArray) -> NodeProto:
220245
"""Return an ONNX node that generates the desired output shape for
221246
shape inference."""
222247
return helper.make_node(
@@ -230,7 +255,13 @@ def make_const_shape_op(self, shape):
230255
)
231256

232257
@abstractmethod
233-
def get_nodeattr_types(self):
258+
def get_nodeattr_types(
259+
self,
260+
) -> Mapping[
261+
str,
262+
tuple[str, bool, int | float | str | bool | npt.NDArray | list[str | int | float]]
263+
| tuple[str, bool, int | float | str | bool | npt.NDArray | list[str | int | float], set | None],
264+
]:
234265
"""Returns a dict of permitted attributes for node, where:
235266
ret_dict[attribute_name] = (dtype, require, default_value, <allowed_values>)
236267
- dtype indicates which member of the ONNX AttributeProto
@@ -245,25 +276,25 @@ def get_nodeattr_types(self):
245276
pass
246277

247278
@abstractmethod
248-
def make_shape_compatible_op(self, model):
279+
def make_shape_compatible_op(self, model: "ModelWrapper") -> NodeProto:
249280
"""Returns a standard ONNX op which is compatible with this CustomOp
250281
for performing shape inference."""
251282
pass
252283

253284
@abstractmethod
254-
def infer_node_datatype(self, model):
285+
def infer_node_datatype(self, model: "ModelWrapper") -> None:
255286
"""Set the DataType annotations corresponding to the outputs of this
256287
node."""
257288
pass
258289

259290
@abstractmethod
260-
def execute_node(self, context, graph):
291+
def execute_node(self, context: dict[str, npt.NDArray], graph: GraphProto) -> None:
261292
"""Execute this CustomOp instance, given the execution context and
262293
ONNX graph."""
263294
pass
264295

265296
@abstractmethod
266-
def verify_node(self):
297+
def verify_node(self) -> None:
267298
"""Verifies that all attributes the node needs are there and
268299
that particular attributes are set correctly. Also checks if
269300
the number of inputs is equal to the expected number."""

src/qonnx/custom_op/registry.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
import warnings
3232
from threading import RLock
3333
from typing import Dict, List, Optional, Tuple, Type
34-
34+
from onnx import NodeProto
3535
from qonnx.custom_op.base import CustomOp
3636

3737
# Nested registry for O(1) lookups: domain -> op_type -> version -> CustomOp class
@@ -320,7 +320,7 @@ def add_op_to_domain(domain: str, op_class: Type[CustomOp]) -> None:
320320
_OP_REGISTRY[domain][op_type][op_version] = op_class
321321

322322

323-
def getCustomOp(node, onnx_opset_version=None):
323+
def getCustomOp(node: NodeProto, onnx_opset_version: int | None = None) -> CustomOp:
324324
"""Get a custom op instance for an ONNX node.
325325
326326
Uses "since version" semantics: selects highest version <= requested opset.

src/qonnx/transformation/base.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,16 @@
4747
manually re-apply the transform.
4848
"""
4949

50+
from __future__ import annotations
51+
5052
import copy
5153
import multiprocessing as mp
5254
from abc import ABC, abstractmethod
55+
from typing import TYPE_CHECKING
56+
57+
if TYPE_CHECKING:
58+
from onnx import NodeProto
59+
from qonnx.core.modelwrapper import ModelWrapper
5360

5461
from qonnx.util.basic import get_num_default_workers
5562

@@ -58,11 +65,11 @@ class Transformation(ABC):
5865
"""Transformation class all transformations are based on. Contains only
5966
abstract method apply() every transformation has to fill."""
6067

61-
def __init__(self):
68+
def __init__(self) -> None:
6269
super().__init__()
6370

6471
@abstractmethod
65-
def apply(self, model):
72+
def apply(self, model: ModelWrapper) -> tuple[ModelWrapper, bool]:
6673
pass
6774

6875

@@ -83,7 +90,7 @@ class NodeLocalTransformation(Transformation):
8390
* (any other int>0): set number of parallel workers
8491
"""
8592

86-
def __init__(self, num_workers=None):
93+
def __init__(self, num_workers: int | None = None) -> None:
8794
super().__init__()
8895
if num_workers is None:
8996
self._num_workers = get_num_default_workers()
@@ -94,15 +101,15 @@ def __init__(self, num_workers=None):
94101
self._num_workers = mp.cpu_count()
95102

96103
@abstractmethod
97-
def applyNodeLocal(self, node):
104+
def applyNodeLocal(self, node) -> tuple[NodeProto, bool]:
98105
pass
99106

100-
def apply(self, model):
107+
def apply(self, model: ModelWrapper) -> tuple[ModelWrapper, bool]:
101108
# make a detached copy of the input model that applyNodeLocal
102109
# can use for read-only access
103110
self.ref_input_model = copy.deepcopy(model)
104111
# Remove old nodes from the current model
105-
old_nodes = []
112+
old_nodes: list[NodeProto] = []
106113
for i in range(len(model.graph.node)):
107114
old_nodes.append(model.graph.node.pop())
108115

0 commit comments

Comments (0)