3030import onnx .helper as helper
3131import onnx .numpy_helper as np_helper
3232from abc import ABC , abstractmethod
33+ from collections .abc import Mapping
34+ from typing import TYPE_CHECKING , Sequence , cast
35+
36+ import numpy .typing as npt
37+ from onnx import NodeProto , GraphProto , TensorProto
3338
3439from qonnx .util .basic import get_by_name , get_preferred_qonnx_opset
3540
41+ if TYPE_CHECKING :
42+ from qonnx .core .modelwrapper import ModelWrapper
43+
3644
3745class CustomOp (ABC ):
3846 """CustomOp class all custom op nodes are based on. Contains different functions
@@ -59,15 +67,21 @@ class IntQuant_v4(CustomOp):
5967 pass # Version 4, covers opset v4+
6068 """
6169
62- def __init__ (self , onnx_node , onnx_opset_version = get_preferred_qonnx_opset ()):
70+ def __init__ (
71+ self ,
72+ onnx_node : NodeProto ,
73+ onnx_opset_version : int = get_preferred_qonnx_opset (),
74+ ) -> None :
6375 super ().__init__ ()
64- self .onnx_node = onnx_node
65- self .onnx_opset_version = onnx_opset_version
76+ self .onnx_node : NodeProto = onnx_node
77+ self .onnx_opset_version : int = onnx_opset_version
6678
67- def get_nodeattr_def (self , name ):
79+ def get_nodeattr_def (
80+ self , name : str
81+ ) -> tuple [str , bool , int | float | str | bool | npt .NDArray | list [str | int | float ], set | None ]:
6882 """Return 4-tuple (dtype, required, default_val, allowed_values) for attribute
6983 with name. allowed_values will be None if not specified."""
70- allowed_values = None
84+ allowed_values : set | None = None
7185 attrdef = self .get_nodeattr_types ()[name ]
7286 if len (attrdef ) == 3 :
7387 (dtype , req , def_val ) = attrdef
@@ -79,11 +93,15 @@ def get_nodeattr_def(self, name):
7993 )
8094 return (dtype , req , def_val , allowed_values )
8195
82- def get_nodeattr_allowed_values (self , name ):
96+ def get_nodeattr_allowed_values (
97+ self , name : str
98+ ) -> str | bool | int | float | npt .NDArray | list [str | int | float ] | set | None :
8399 "Return set of allowed values for given attribute, None if not specified."
84100 return self .get_nodeattr_def (name )[3 ]
85101
86- def get_nodeattr (self , name ):
102+ def get_nodeattr (
103+ self , name : str
104+ ) -> int | float | str | bool | npt .NDArray | list [str | int | float ] | None :
87105 """Get a node attribute by name. Data is stored inside the ONNX node's
88106 AttributeProto container. Attribute must be part of get_nodeattr_types.
89107 Default value is returned if attribute is not set."""
@@ -128,9 +146,13 @@ def get_nodeattr(self, name):
128146 # not set, return default value
129147 return def_val
130148 except KeyError :
131- raise AttributeError ("Op has no such attribute: " + name )
149+ raise AttributeError (
150+ f"{ self .onnx_node .name } has no such attribute: " + name
151+ )
132152
133- def set_nodeattr (self , name , value ):
153+ def set_nodeattr (
154+ self , name : str , value : int | float | str | bool | npt .NDArray | list [str | int | float ] | None
155+ ) -> None :
134156 """Set a node attribute by name. Data is stored inside the ONNX node's
135157 AttributeProto container. Attribute must be part of get_nodeattr_types."""
136158 try :
@@ -142,7 +164,7 @@ def set_nodeattr(self, name, value):
142164 % (str (name ), str (value ), str (allowed_values ))
143165 )
144166 attr = get_by_name (self .onnx_node .attribute , name )
145-
167+ tensor_value : TensorProto | None = None
146168 # Verify value type matches dtype before setting/converting
147169 if dtype == "i" :
148170 if not isinstance (value , int ):
@@ -185,23 +207,25 @@ def set_nodeattr(self, name, value):
185207 f"Attribute { name } expects numpy array, got { type (value )} "
186208 )
187209 # Convert numpy array to TensorProto
188- value = np_helper .from_array (value )
189-
210+ tensor_value = np_helper .from_array (cast (npt .NDArray , value ))
190211 if attr is not None :
191212 # dtype indicates which ONNX Attribute member to use
192213 # (such as i, f, s...)
193214 if dtype == "s" :
194215 # encode string attributes
195- value = value .encode ("utf-8" )
196- attr .__setattr__ (dtype , value )
216+ val = cast ( str , value ) .encode ("utf-8" )
217+ attr .__setattr__ (dtype , val )
197218 elif dtype == "strings" :
198- attr .strings [:] = [x .encode ("utf-8" ) for x in value ]
219+ attr .strings [:] = [
220+ x .encode ("utf-8" ) for x in cast (list [str ], value )
221+ ]
199222 elif dtype == "floats" : # list of floats
200- attr .floats [:] = value
223+ attr .floats [:] = cast ( list [ float ], value )
201224 elif dtype == "ints" : # list of integers
202- attr .ints [:] = value
225+ attr .ints [:] = cast ( list [ int ], value )
203226 elif dtype == "t" : # single tensor
204- attr .t .CopyFrom (value )
227+ assert tensor_value is not None
228+ attr .t .CopyFrom (tensor_value )
205229 elif dtype in ["tensors" , "graphs" , "sparse_tensors" ]:
206230 # untested / unsupported attribute types
207231 # add testcases & appropriate getters before enabling
@@ -211,12 +235,13 @@ def set_nodeattr(self, name, value):
211235 attr .__setattr__ (dtype , value )
212236 else :
213237 # not set, create and insert AttributeProto
214- attr_proto = helper .make_attribute (name , value )
238+ attr_value = tensor_value if tensor_value is not None else value
239+ attr_proto = helper .make_attribute (name , attr_value )
215240 self .onnx_node .attribute .append (attr_proto )
216241 except KeyError :
217242 raise AttributeError ("Op has no such attribute: " + name )
218243
219- def make_const_shape_op (self , shape ) :
244+ def make_const_shape_op (self , shape : Sequence [ int ] | npt . NDArray ) -> NodeProto :
220245 """Return an ONNX node that generates the desired output shape for
221246 shape inference."""
222247 return helper .make_node (
@@ -230,7 +255,13 @@ def make_const_shape_op(self, shape):
230255 )
231256
232257 @abstractmethod
233- def get_nodeattr_types (self ):
258+ def get_nodeattr_types (
259+ self ,
260+ ) -> Mapping [
261+ str ,
262+ tuple [str , bool , int | float | str | bool | npt .NDArray | list [str | int | float ]]
263+ | tuple [str , bool , int | float | str | bool | npt .NDArray | list [str | int | float ], set | None ],
264+ ]:
234265 """Returns a dict of permitted attributes for node, where:
235266 ret_dict[attribute_name] = (dtype, require, default_value, <allowed_values>)
236267 - dtype indicates which member of the ONNX AttributeProto
@@ -245,25 +276,25 @@ def get_nodeattr_types(self):
245276 pass
246277
247278 @abstractmethod
248- def make_shape_compatible_op (self , model ) :
279+ def make_shape_compatible_op (self , model : "ModelWrapper" ) -> NodeProto :
249280 """Returns a standard ONNX op which is compatible with this CustomOp
250281 for performing shape inference."""
251282 pass
252283
253284 @abstractmethod
254- def infer_node_datatype (self , model ) :
285+ def infer_node_datatype (self , model : "ModelWrapper" ) -> None :
255286 """Set the DataType annotations corresponding to the outputs of this
256287 node."""
257288 pass
258289
259290 @abstractmethod
260- def execute_node (self , context , graph ) :
291+ def execute_node (self , context : dict [ str , npt . NDArray ], graph : GraphProto ) -> None :
261292 """Execute this CustomOp instance, given the execution context and
262293 ONNX graph."""
263294 pass
264295
265296 @abstractmethod
266- def verify_node (self ):
297+ def verify_node (self ) -> None :
267298 """Verifies that all attributes the node needs are there and
268299 that particular attributes are set correctly. Also checks if
269300 the number of inputs is equal to the expected number."""
0 commit comments