@@ -266,7 +266,7 @@ def _enter_scope(self, name: str, parent_node: ast.AST):
266266 The block is translated into a nested-scope in ONNX.
267267 """
268268 self ._outer .insert (0 , self ._current_fn )
269- self ._current_fn = self . new_function (name )
269+ self ._current_fn = irbuilder . IRFunction (name )
270270 self ._locals .insert (0 , {})
271271 logger .debug ("Converter:_enter_scope:%d:node:%s" , len (self ._locals ), type (parent_node ))
272272
@@ -379,22 +379,18 @@ def _to_onnx_var(
379379 def _py_var_to_onnx_var (self , py_var : str , info : sourceinfo .SourceInfo ) -> ir .Value :
380380 return self ._to_onnx_var (self ._lookup (py_var , info ), target = py_var , info = info )
381381
382- def new_function (self , name : str , domain : str = "" , register : bool = False ) -> irbuilder .IRFunction :
383- if register and (domain , name ) in self .ir_builder .functions :
384- raise RuntimeError (f"Function '{ name } ' already exists in domain '{ domain } '." )
385- function = irbuilder .IRFunction (name , domain )
386- if register :
387- self .ir_builder .functions [domain , name ] = function
388- return function
389-
390- def add_stmt (
382+ def emit (
391383 self ,
392- results : Sequence [str ],
393- callee : values .Op ,
384+ outputs : Sequence [str ],
385+ callee : values .Op | str ,
394386 inputs : Sequence [Optional [ir .Value ]],
395- attrs : Sequence [ir .Attr ],
396- ) -> Sequence [ir .Value ]:
397- output_values = [ir .Value (name = o ) for o in results ]
387+ attrs : Optional [Sequence [irbuilder .IRAttributeValue ]] = None ,
388+ ) -> Sequence [ir .Value ] | ir .Value :
389+ if not isinstance (callee , values .Op ):
390+ callee = values .Op (self .default_opset , callee )
391+ if attrs is None :
392+ attrs = []
393+ output_values = [ir .Value (name = o ) for o in outputs ]
398394 node = ir .Node (
399395 domain = callee .opset .domain ,
400396 version = callee .opset .version ,
@@ -407,50 +403,7 @@ def add_stmt(
407403 raise TypeError (f"Unexpected type { type (callee )} for callee." )
408404 node .meta .setdefault ("callee" , callee )
409405 self ._current_fn .append_node (node )
410- return output_values
411406
412- def add_attr_parameter (
413- self ,
414- varname : str ,
415- attribute_type : onnx .AttributeProto .AttributeType ,
416- default_value : int | float | str | None ,
417- ) -> None :
418- attr = ir .Attr (varname , ir .AttributeType (attribute_type ), default_value , None )
419- self ._current_fn .append_parameter (attr )
420-
421- def add_input (
422- self ,
423- varname : str ,
424- typeinfo : ta .TypeAnnotationValue ,
425- source_info : sourceinfo .SourceInfo ,
426- ) -> None :
427- self ._current_fn .append_parameter (make_value (varname , typeinfo , source_info ))
428-
429- def add_output (
430- self ,
431- varname : str ,
432- typeinfo : ta .TypeAnnotationValue ,
433- source_info : sourceinfo .SourceInfo ,
434- ) -> None :
435- self ._current_fn .append_output (make_value (varname , typeinfo , source_info ))
436-
437- def emit (
438- self ,
439- outputs : Sequence [str ],
440- callee : values .Op | str ,
441- inputs : Sequence [Optional [ir .Value ]],
442- attrs : Optional [Sequence [irbuilder .IRAttributeValue ]] = None ,
443- ) -> Sequence [ir .Value ] | ir .Value :
444- if not isinstance (callee , values .Op ):
445- callee = values .Op (self .default_opset , callee )
446- if attrs is None :
447- attrs = []
448- output_values = self .add_stmt (
449- outputs ,
450- callee ,
451- inputs ,
452- attrs ,
453- )
454407 return output_values if len (output_values ) > 1 else output_values [0 ]
455408
456409 def emit1 (self , * args , ** kwargs ) -> ir .Value :
@@ -1162,7 +1115,9 @@ def ret(exp, i, suffix):
11621115 t = None
11631116 else :
11641117 t = self .returntype [i ]
1165- self .add_output (return_var .name , t , self ._source_of (stmt ))
1118+ self ._current_fn .append_output (
1119+ make_value (return_var .name , t , self ._source_of (stmt ))
1120+ )
11661121 return return_var
11671122
11681123 val = stmt .value
@@ -1283,10 +1238,12 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
12831238 # build loop_body
12841239 self ._enter_scope ("loop_body" , loop_stmt )
12851240 o_loop_var = self .generate_unique_name (p_loop_var )
1286- self .add_input (
1287- o_loop_var ,
1288- onnx_types .INT64 ,
1289- self ._source_of (loop_stmt ),
1241+ self ._current_fn .append_parameter (
1242+ make_value (
1243+ o_loop_var ,
1244+ onnx_types .INT64 ,
1245+ self ._source_of (loop_stmt ),
1246+ )
12901247 )
12911248 self ._bind (
12921249 p_loop_var ,
@@ -1295,18 +1252,22 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
12951252 ),
12961253 )
12971254
1298- self .add_input (
1299- i_cond_var .name ,
1300- onnx_types .BOOL ,
1301- self ._source_of (loop_stmt ),
1255+ self ._current_fn .append_parameter (
1256+ make_value (
1257+ i_cond_var .name ,
1258+ onnx_types .BOOL ,
1259+ self ._source_of (loop_stmt ),
1260+ )
13021261 )
13031262
13041263 for pv in loop_state_vars :
13051264 ov = self .generate_unique_name (pv )
13061265 # TODO: retrieve the annotation for variable pv is any is specified.
13071266 # typeinfo = self._eval_constant_expr(pv.annotation)
13081267 typeinfo = None
1309- self .add_input (ov , typeinfo , self ._source_of (loop_stmt ))
1268+ self ._current_fn .append_parameter (
1269+ make_value (ov , typeinfo , self ._source_of (loop_stmt ))
1270+ )
13101271 self ._bind (
13111272 pv ,
13121273 values .Dynamic (
@@ -1363,10 +1324,12 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
13631324 [],
13641325 )
13651326
1366- self .add_output (
1367- o_cond_out ,
1368- onnx_types .BOOL ,
1369- self ._source_of (loop_stmt ),
1327+ self ._current_fn .append_output (
1328+ make_value (
1329+ o_cond_out ,
1330+ onnx_types .BOOL ,
1331+ self ._source_of (loop_stmt ),
1332+ )
13701333 )
13711334 for pv in loop_state_vars :
13721335 ov = self ._py_var_to_onnx_var (pv , self ._source_of (loop_stmt ))
@@ -1379,7 +1342,9 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
13791342 ov = self ._emit_copy (ov , pv )
13801343 # TODO: retrieve variable type for the annotation if any.
13811344 typeinfo = None
1382- self .add_output (ov .name , typeinfo , self ._source_of (loop_stmt ))
1345+ self ._current_fn .append_output (
1346+ make_value (ov .name , typeinfo , self ._source_of (loop_stmt ))
1347+ )
13831348 body = self ._exit_scope ()
13841349 inputs = [o_loop_bound , o_loop_condition ] + [
13851350 self ._py_var_to_onnx_var (pv , self ._source_of (loop_stmt )) for pv in loop_state_vars
@@ -1424,10 +1389,12 @@ def _translate_block(
14241389 # To return an outer-scope variable, an ONNX Graph has to
14251390 # use an explicit copy via Identity.
14261391 output = self ._emit_copy (output , pvar )
1427- self .add_output (
1428- output .name ,
1429- pv_val .typeinfo ,
1430- source ,
1392+ self ._current_fn .append_output (
1393+ make_value (
1394+ output .name ,
1395+ pv_val .typeinfo ,
1396+ source ,
1397+ )
14311398 )
14321399 else :
14331400 pv_val = None
@@ -1446,7 +1413,7 @@ def _translate_block(
14461413
14471414 # TODO: retrieve the annotation if any.
14481415 typeinfo = None
1449- self .add_output ( ovar .name , typeinfo , source )
1416+ self ._current_fn . append_output ( make_value ( ovar .name , typeinfo , source ) )
14501417 graph = self ._exit_scope ()
14511418 return graph .graph
14521419
@@ -1484,14 +1451,14 @@ def _translate_function_signature_common(
14841451 # The code can only be exported as a function.
14851452 typeinfo = None
14861453 if typeinfo and ta .is_attr_type (typeinfo ):
1487- self .add_attr_parameter (
1488- x .arg ,
1489- ta .pytype_to_attrtype (typeinfo ),
1490- default_value ,
1491- )
1454+ attribute_type = ta .pytype_to_attrtype (typeinfo )
1455+ attr = ir .Attr (x .arg , ir .AttributeType (attribute_type ), default_value , None )
1456+ self ._current_fn .append_parameter (attr )
14921457 self ._bind (x .arg , values .AttrRef (x .arg , typeinfo , self ._source_of (x )))
14931458 else :
1494- self .add_input (x .arg , typeinfo , self ._source_of (x ))
1459+ self ._current_fn .append_parameter (
1460+ make_value (x .arg , typeinfo , self ._source_of (x ))
1461+ )
14951462 self ._used_vars .add (x .arg )
14961463 self ._bind (
14971464 x .arg ,
@@ -1533,7 +1500,7 @@ def translate_function_def(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction:
15331500 if opset :
15341501 self ._set_default_opset (opset , stmt )
15351502 domain = self .this_module .domain
1536- self ._current_fn = self . new_function (stmt .name , domain , True )
1503+ self ._current_fn = irbuilder . IRFunction (stmt .name , domain )
15371504 self ._analyzer = analysis .AstAnalyzer (stmt , self ._message , self .globals )
15381505 fn_ir = self ._translate_function_def_common (stmt )
15391506 self .this_module .add_function_def (fn_ir )
@@ -1544,5 +1511,5 @@ def translate_function_def(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction:
15441511 def translate_function_signature (self , fn : ast .FunctionDef ) -> irbuilder .IRFunction :
15451512 """Translate a (top-level) function signature."""
15461513 domain = self .this_module .domain
1547- self ._current_fn = self . new_function (fn .name , domain , True )
1514+ self ._current_fn = irbuilder . IRFunction (fn .name , domain )
15481515 return self ._translate_function_signature_common (fn )
0 commit comments