@@ -74,7 +74,7 @@ def __init__(self, varname: str, typeinfo: IRTypeLike, sourceinfo: SourceInfo) -
7474 self .name = varname
7575 self .info = sourceinfo
7676 self .typeinfo = typeinfo
77- if typeinfo is None :
77+ if typeinfo is None or not hasattr ( typeinfo , "to_type_proto" ) :
7878 self .value = ir .Value (name = varname )
7979 else :
8080 type_and_shape = ir .from_proto (typeinfo .to_type_proto ())
@@ -135,6 +135,7 @@ class IRAttributeParameter:
135135
136136 name : str
137137 type : onnx .AttributeProto .AttributeType
138+ attr : ir .Attr
138139 default_value : str | int | float | None = None
139140
140141 # TODO(justinchuby): Validate the default_value is the same type as specified in AttributeType.
@@ -193,9 +194,8 @@ def debug_print(self):
193194 if logger .isEnabledFor (logging .DEBUG ):
194195 logger .debug ("%s: %s" , type (self ), self )
195196
196- def to_node_proto (self , node_name : str ) -> onnx .NodeProto :
197+ def to_node_proto (self ) -> onnx .NodeProto :
197198 n = ir .to_proto (self .node )
198- n .name = node_name
199199 return n
200200
201201 @property
@@ -208,8 +208,8 @@ class IRFunction:
208208 """Represents a function in the IR."""
209209
210210 def __init__ (self , name : str , domain : str = "" ) -> None :
211- self . ir_graph = ir .Graph (inputs = [], outputs = [], nodes = [], name = name )
212- self .domain = domain
211+ graph = ir .Graph (inputs = [], outputs = [], nodes = [], name = name )
212+ self .ir_function = ir . Function ( domain , name , graph = graph , attributes = [])
213213 self .outputs : list [IRVar ] = []
214214 self .stmts : list [IRStmt ] = []
215215 self .called_functions : dict [str , onnx .FunctionProto ] = {}
@@ -218,15 +218,20 @@ def __init__(self, name: str, domain: str = "") -> None:
218218 self .outer_scope_variables : dict [Any , Any ] = {}
219219 self .ordered_inputs_and_attrs : list [Union [IRVar , IRAttributeParameter ]] = []
220220
221+ @property
222+ def domain (self ) -> str :
223+ """Returns the domain of this function."""
224+ return self .ir_function .domain
225+
221226 @property
222227 def docstring (self ) -> str :
223228 """Returns the docstring of this function."""
224- return self .ir_graph .doc_string or ""
229+ return self .ir_function .doc_string or ""
225230
226231 @property
227232 def name (self ) -> str :
228233 """Returns the name of this function."""
229- return self .ir_graph .name
234+ return self .ir_function .name
230235
231236 @property
232237 def assigned_names (self ) -> Sequence [str ]:
@@ -253,16 +258,23 @@ def __str__(self):
253258 return f"{ self .name } { attrs } { inputs } => { outputs } { stmts } "
254259
255260 def append_stmt (self , stmt : IRStmt ) -> None :
261+ count = len (self .stmts )
262+ node_name = f"n{ count } "
263+ stmt .node .name = node_name
256264 self .stmts .append (stmt )
265+ self .ir_function .append (stmt .node )
257266
258- def append_input (self , name : IRVar ) -> None :
259- self .ordered_inputs_and_attrs .append (name )
267+ def append_input (self , var : IRVar ) -> None :
268+ self .ordered_inputs_and_attrs .append (var )
269+ self .ir_function .inputs .append (var .value )
260270
261- def append_output (self , name : IRVar ) -> None :
262- self .outputs .append (name )
271+ def append_output (self , var : IRVar ) -> None :
272+ self .outputs .append (var )
273+ self .ir_function .outputs .append (var .value )
263274
264275 def add_attr_parameter (self , attr : IRAttributeParameter ) -> None :
265276 self .ordered_inputs_and_attrs .append (attr )
277+ self .ir_function .attributes .add (attr .attr )
266278
267279 def debug_print (self ):
268280 if logger .isEnabledFor (logging .DEBUG ):
@@ -407,7 +419,7 @@ def _to_graph_and_functions(
407419 called_functions .update (s .functions )
408420 called_functions .update (self .called_functions )
409421 graph = helper .make_graph (
410- [s .to_node_proto (f"n { i } " ) for i , s in enumerate ( self .stmts ) ],
422+ [s .to_node_proto () for s in self .stmts ],
411423 self .name ,
412424 [x .to_value_info (use_default_type ) for x in self .inputs ],
413425 [y .to_value_info (use_default_type ) for y in self .outputs ],
@@ -450,7 +462,7 @@ def to_function_proto(self) -> onnx.FunctionProto:
450462 doesn't support it.
451463 """
452464 opsets = self .get_opset_import ()
453- nodes = [s .to_node_proto (f"n { i } " ) for i , s in enumerate ( self .stmts ) ]
465+ nodes = [s .to_node_proto () for s in self .stmts ]
454466 for n in nodes :
455467 if n .domain not in opsets :
456468 opsets [n .domain ] = 1 # TODO: how to get n.version?
@@ -494,7 +506,7 @@ def new_function(self, name: str, domain: str = "", register: bool = False) -> I
494506 return function
495507
496508 def add_docstring (self , fn : IRFunction , docstring : str ):
497- fn .ir_graph .doc_string = docstring
509+ fn .ir_function .doc_string = docstring
498510
499511 def add_stmt (
500512 self ,
@@ -533,7 +545,10 @@ def add_attr_parameter(
533545 attribute_type : onnx .AttributeProto .AttributeType ,
534546 default_value : int | float | str | None ,
535547 ) -> None :
536- fn .add_attr_parameter (IRAttributeParameter (varname , attribute_type , default_value ))
548+ attr = ir .Attr (varname , ir .AttributeType (attribute_type ), None , None )
549+ fn .add_attr_parameter (
550+ IRAttributeParameter (varname , attribute_type , attr , default_value )
551+ )
537552
538553 def add_output (self , fn : IRFunction , varname : str , typeinfo , sourceinfo ) -> None :
539554 var = IRVar (varname , typeinfo , sourceinfo )
0 commit comments