@@ -331,14 +331,17 @@ def _to_onnx_var(
331331 # promote attribute to value
332332 result = self ._generate_unique_name (target )
333333 attr = _to_onnx_ref_attr (val , info )
334- self .emit ([], "Constant" , [result ], [attr ])
334+ self .emit ([], "Constant" , [result ], attrs = [attr ])
335335 if ta .base_type_is_bool (val .typeinfo ):
336336 # ONNX attributes use an int-encoding for bools, but ONNX tensor types
337337 # distinguish between int and bool. So we cast the int tensor to a bool tensor,
338338 # to promote a (python) bool attribute to a ONNX bool tensor.
339339 result_as_bool = self ._generate_unique_name (result + "_as_bool" )
340340 self .emit (
341- [result ], "Cast" , [result_as_bool ], [ir .AttrInt64 ("to" , ir .DataType .BOOL )]
341+ [result ],
342+ "Cast" ,
343+ [result_as_bool ],
344+ attrs = [ir .AttrInt64 ("to" , ir .DataType .BOOL )],
342345 )
343346 return Variable (result_as_bool , castable = True )
344347 return Variable (result , castable = True )
@@ -358,6 +361,7 @@ def emit(
358361 outputs : Sequence [str ],
359362 op_type : str ,
360363 inputs : Sequence [str ],
364+ * ,
361365 attrs : Sequence [ir .Attr ] = (),
362366 domain : str = "" ,
363367 ):
@@ -396,7 +400,7 @@ def _emit_const(
396400 except Exception as e :
397401 fail (info .msg (str (e )))
398402
399- self .emit ([], "Constant" , [var_name ], [ir .AttrTensor ("value" , tensor )])
403+ self .emit ([], "Constant" , [var_name ], attrs = [ir .AttrTensor ("value" , tensor )])
400404 return Variable (var_name , True )
401405
402406 def _emit_copy (self , original_var : str , suggested_name : str ) -> str :
@@ -522,7 +526,7 @@ def _translate_expr(
522526 target = "tmp" if target is None else target
523527 assert isinstance (target , str )
524528 result = self ._generate_unique_name (target )
525- self .emit ([result ], callee , args , attrs )
529+ self .emit ([result ], callee , args , attrs = attrs )
526530 return Variable (result )
527531
528532 def _translate_opt_expr (self , node : ast .expr ) -> Optional [Variable ]:
@@ -620,9 +624,8 @@ def translate_slice_component(
620624 reshaped = self ._generate_unique_name (f"{ name } _reshaped" )
621625 self .emit (
622626 [reshaped ],
623- values . Op ( self . _default_opset , "Reshape" ) ,
627+ "Reshape" ,
624628 [name , one_1d ().name ],
625- [],
626629 )
627630 return reshaped , None
628631
@@ -704,16 +707,16 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]:
704707 if len (starts ) > 1 :
705708 axis_0_attr = self ._make_onnx_attr ("axis" , 0 )
706709 start_name = self ._generate_unique_name (f"{ var_name } _start" )
707- self .emit ([start_name ], "Concat" , starts , [axis_0_attr ])
710+ self .emit ([start_name ], "Concat" , starts , attrs = [axis_0_attr ])
708711
709712 end_name = self ._generate_unique_name (f"{ var_name } _end" )
710- self .emit ([end_name ], "Concat" , ends , [axis_0_attr ])
713+ self .emit ([end_name ], "Concat" , ends , attrs = [axis_0_attr ])
711714
712715 axes_name = self ._generate_unique_name (f"{ var_name } _axis" )
713- self .emit ([axes_name ], "Concat" , axes , [axis_0_attr ])
716+ self .emit ([axes_name ], "Concat" , axes , attrs = [axis_0_attr ])
714717
715718 steps_name = self ._generate_unique_name (f"{ var_name } _step" )
716- self .emit ([steps_name ], "Concat" , steps , [axis_0_attr ])
719+ self .emit ([steps_name ], "Concat" , steps , attrs = [axis_0_attr ])
717720 else :
718721 start_name = starts [0 ]
719722 end_name = ends [0 ]
@@ -759,7 +762,7 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]:
759762 gathered = self ._generate_unique_name (f"{ var_name } _axis_{ axis } " )
760763 else : # store result of Gather in final target
761764 gathered = target
762- self .emit ([gathered ], "Gather" , [str (result ), index_value ], [axis_attr ])
765+ self .emit ([gathered ], "Gather" , [str (result ), index_value ], attrs = [axis_attr ])
763766 result = gathered
764767
765768 return Variable (result )
@@ -971,7 +974,7 @@ def generate_onnx_name(x: ast.AST):
971974 return onnx_name
972975
973976 outputs = [generate_onnx_name (x ) for x in lhs .elts ]
974- self .emit (outputs , callee , inputs , attrs )
977+ self .emit (outputs , callee , inputs , attrs = attrs )
975978 else :
976979 self .fail (lhs , f"Unsupported construct in LHS of assignment: '{ type (lhs )!r} '" )
977980
@@ -1085,7 +1088,7 @@ def rename(x):
10851088 [test ],
10861089 "If" ,
10871090 renamed ,
1088- [then_attr , else_attr ],
1091+ attrs = [then_attr , else_attr ],
10891092 )
10901093
10911094 def _translate_loop_stmt (self , loop_stmt : Union [ast .For , ast .While ]) -> None :
@@ -1218,9 +1221,8 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
12181221
12191222 self .emit (
12201223 [o_cond_out ],
1221- values . Op ( self . _default_opset , operator_name ) ,
1224+ operator_name ,
12221225 [condition_name or o_cond_var ],
1223- [],
12241226 )
12251227
12261228 self .ir_builder .add_output (
@@ -1262,8 +1264,8 @@ def rename(x):
12621264 onnx_outputs ,
12631265 "Loop" ,
12641266 inputs ,
1265- attrs ,
1266- sub_functions = sub_functions ,
1267+ attrs = attrs ,
1268+ # sub_functions=sub_functions,
12671269 )
12681270
12691271 def _translate_block (
0 commit comments