@@ -229,7 +229,7 @@ def _set_default_opset(self, opset: values.Opset, node: ast.AST) -> None:
229229 or opset .version != self ._default_opset .version
230230 ):
231231 self .fail (
232- node , f"Two distincts opset were used ({ opset } != { self ._default_opset } )."
232+ node , f"Two distinct opset were used ({ opset } != { self ._default_opset } )."
233233 )
234234 else :
235235 self ._default_opset = opset
@@ -251,19 +251,19 @@ def _find_onnx_opset(self, node: ast.AST) -> Optional[values.Opset]:
251251 return res
252252 return None
253253
254- def _init_function_translation (self ) -> None :
255- """Initialize self for translating a new (top-level) function."""
256- self ._outer = []
257- # TODO(justinchuby): Update this
258- self ._current_fn = ir .Function (
259- domain = self ._opset .domain ,
260- name = "" ,
261- graph = ir .Graph ((), (), nodes = []),
262- attributes = {},
263- )
264- self ._nextvar = 0
265- self ._used_vars = set ()
266- self ._locals : List [Dict [str , LocalSymValue ]] = [{}]
254+ # def _init_function_translation(self) -> None:
255+ # """Initialize self for translating a new (top-level) function."""
256+ # self._outer = []
257+ # # TODO(justinchuby): Update this
258+ # self._current_fn = ir.Function(
259+ # domain=self._opset.domain,
260+ # name="",
261+ # graph=ir.Graph((), (), nodes=[]),
262+ # attributes={},
263+ # )
264+ # self._nextvar = 0
265+ # self._used_vars = set()
266+ # self._locals: List[Dict[str, LocalSymValue]] = [{}]
267267
268268 def _source_of (self , node : ast .AST ) -> sourceinfo .SourceInfo :
269269 return sourceinfo .SourceInfo (node , self ._source , self ._current_fn .name )
@@ -328,7 +328,7 @@ def _lookup(
328328 raise ValueError (info .msg (f"Unbound name: { name } ." ))
329329 return None
330330
331- def generate_unique_name (self , candidate : str = "tmp" ) -> str :
331+ def _generate_unique_name (self , candidate : str = "tmp" ) -> str :
332332 # TODO(justinchuby): Can we reduce the O complexity of this function?
333333 r = candidate
334334 while r in self ._used_vars :
@@ -347,14 +347,14 @@ def _to_onnx_var(
347347 """Convert a value to an ONNX variable."""
348348 if isinstance (val , values .AttrRef ):
349349 # promote attribute to value
350- result = self .generate_unique_name (target )
350+ result = self ._generate_unique_name (target )
351351 attr = _to_onnx_ref_attr (val , info )
352352 self .emit ("Constant" , [], [result ], [attr ])
353353 if ta .base_type_is_bool (val .typeinfo ):
354354 # ONNX attributes use an int-encoding for bools, but ONNX tensor types
355355 # distinguish between int and bool. So we cast the int tensor to a bool tensor,
356356 # to promote a (python) bool attribute to a ONNX bool tensor.
357- result_as_bool = self .generate_unique_name (result + "_as_bool" )
357+ result_as_bool = self ._generate_unique_name (result + "_as_bool" )
358358 self .emit (
359359 "Cast" , [result ], [result_as_bool ], [ir .AttrInt64 ("to" , ir .DataType .BOOL )]
360360 )
@@ -406,7 +406,7 @@ def _emit_const(
406406 suggested_name = f"int64_{ pyvalue [0 ]} _1d"
407407 else :
408408 suggested_name = "const"
409- var_name = self .generate_unique_name (suggested_name )
409+ var_name = self ._generate_unique_name (suggested_name )
410410
411411 # Create a tensor from the python value
412412 try :
@@ -419,7 +419,7 @@ def _emit_const(
419419
420420 def _emit_copy (self , original_var : str , suggested_name : str ) -> str :
421421 """Emits a copy statement, using the ONNX Identity operator."""
422- new_var = self .generate_unique_name (suggested_name )
422+ new_var = self ._generate_unique_name (suggested_name )
423423 self .emit ("Identity" , [original_var ], [new_var ])
424424 return new_var
425425
@@ -539,7 +539,7 @@ def _translate_expr(
539539 callee , args , attrs = r
540540 target = "tmp" if target is None else target
541541 assert isinstance (target , str )
542- result = self .generate_unique_name (target )
542+ result = self ._generate_unique_name (target )
543543 self .emit ([result ], callee , args , attrs )
544544 return Variable (result )
545545
@@ -594,7 +594,7 @@ def _translate_subscript_expr(
594594 var_name = var .name
595595 if target is None :
596596 target = f"{ var_name } _subscripted"
597- target = self .generate_unique_name (target )
597+ target = self ._generate_unique_name (target )
598598 indices = ast_utils .normalize_subscript_expr (node )
599599 info = self ._source_of (node .slice )
600600
@@ -635,7 +635,7 @@ def translate_slice_component(
635635 raise RuntimeError (f"Slice component type must be int, not { type (cst )} " )
636636 else :
637637 name = self ._translate_expr (node_arg ).name
638- reshaped = self .generate_unique_name (f"{ name } _reshaped" )
638+ reshaped = self ._generate_unique_name (f"{ name } _reshaped" )
639639 self .emit (
640640 [reshaped ],
641641 values .Op (self .default_opset , "Reshape" ),
@@ -721,16 +721,16 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]:
721721
722722 if len (starts ) > 1 :
723723 axis_0_attr = self ._make_onnx_attr ("axis" , 0 )
724- start_name = self .generate_unique_name (f"{ var_name } _start" )
724+ start_name = self ._generate_unique_name (f"{ var_name } _start" )
725725 self .emit ([start_name ], "Concat" , starts , [axis_0_attr ])
726726
727- end_name = self .generate_unique_name (f"{ var_name } _end" )
727+ end_name = self ._generate_unique_name (f"{ var_name } _end" )
728728 self .emit ([end_name ], "Concat" , ends , [axis_0_attr ])
729729
730- axes_name = self .generate_unique_name (f"{ var_name } _axis" )
730+ axes_name = self ._generate_unique_name (f"{ var_name } _axis" )
731731 self .emit ([axes_name ], "Concat" , axes , [axis_0_attr ])
732732
733- steps_name = self .generate_unique_name (f"{ var_name } _step" )
733+ steps_name = self ._generate_unique_name (f"{ var_name } _step" )
734734 self .emit ([steps_name ], "Concat" , steps , [axis_0_attr ])
735735 else :
736736 start_name = starts [0 ]
@@ -739,7 +739,7 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]:
739739 steps_name = steps [0 ]
740740
741741 if squeezed_axes :
742- sliced_name = self .generate_unique_name (f"{ var_name } _sliced" )
742+ sliced_name = self ._generate_unique_name (f"{ var_name } _sliced" )
743743 self .emit (
744744 [sliced_name ],
745745 "Slice" ,
@@ -748,14 +748,14 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]:
748748 squeezed_axes = self ._emit_const (squeezed_axes , "squeezed_axes" , info )
749749
750750 if non_scalar_indices : # use temporary to store result of squeeze
751- result = self .generate_unique_name (f"{ var_name } _squeezed" )
751+ result = self ._generate_unique_name (f"{ var_name } _squeezed" )
752752 else : # store squeezed result in final target
753753 result = target
754754
755755 self .emit ([result ], "Squeeze" , [sliced_name , squeezed_axes ])
756756 else :
757757 if non_scalar_indices : # use temporary to store result of Slice
758- result = self .generate_unique_name (f"{ var_name } _sliced" )
758+ result = self ._generate_unique_name (f"{ var_name } _sliced" )
759759 else : # store result of Slice in final target
760760 result = target
761761 slice_inputs = [var_name , start_name , end_name , axes_name , steps_name ]
@@ -774,7 +774,7 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]:
774774 # use Gather to perform indexing
775775 # Assign gathered value to either temporary or final target
776776 if axis != last_axis : # use temporary to store result of Gather
777- gathered = self .generate_unique_name (f"{ var_name } _axis_{ axis } " )
777+ gathered = self ._generate_unique_name (f"{ var_name } _axis_{ axis } " )
778778 else : # store result of Gather in final target
779779 gathered = target
780780 self .emit ([gathered ], "Gather" , [str (result ), index_value ], [axis_attr ])
@@ -876,7 +876,7 @@ def _translate_compare_expr(self, node):
876876 op = values .Op (self .default_opset , opname if opname != "NotEqual" else "Equal" )
877877 left , right = self ._cast_like_binary_expression (op , left , right )
878878 if opname == "NotEqual" :
879- tmp = self .generate_unique_name ()
879+ tmp = self ._generate_unique_name ()
880880 self .emit ([tmp ], op , [left , right ])
881881 not_op = values .Op (self .default_opset , "Not" )
882882 return not_op , [tmp ], []
@@ -979,7 +979,7 @@ def assign(lhs: ast.AST, rhs: ast.AST) -> None:
979979 def generate_onnx_name (x : ast .AST ):
980980 if not isinstance (x , ast .Name ):
981981 self .fail (x , f"LHS must be a Name for unpacking, found: '{ type (x )!r} '" )
982- onnx_name = self .generate_unique_name (x .id )
982+ onnx_name = self ._generate_unique_name (x .id )
983983 self ._bind (
984984 x .id ,
985985 values .Dynamic (
@@ -1078,7 +1078,7 @@ def _translate_if_stmt(self, stmt: ast.If) -> None:
10781078 elseAttr = self ._make_onnx_attr ("else_branch" , elseGraph )
10791079
10801080 def rename (x ):
1081- r = self .generate_unique_name (x )
1081+ r = self ._generate_unique_name (x )
10821082 self ._bind (
10831083 x ,
10841084 values .Dynamic (r , values .DynamicKind .Intermediate , self ._source_of (stmt )),
@@ -1122,7 +1122,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
11221122 self .fail (loop_stmt , "Unsupported loop bound, it should be 'range(?)'." )
11231123 assert not iter .keywords , "Unsupported loop bound."
11241124 o_loop_bound = self ._translate_expr (iter .args [0 ], "loop_bound" ).name
1125- o_cond_var = self .generate_unique_name ("cond_in" )
1125+ o_cond_var = self ._generate_unique_name ("cond_in" )
11261126 i_cond_var = o_cond_var
11271127 cond_while = None
11281128 o_loop_condition = "" # No condition for a for loop.
@@ -1156,7 +1156,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
11561156
11571157 # build loop_body
11581158 self ._enter_scope ("loop_body" , loop_stmt )
1159- o_loop_var = self .generate_unique_name (p_loop_var )
1159+ o_loop_var = self ._generate_unique_name (p_loop_var )
11601160 self .ir_builder .add_input (
11611161 self ._current_fn ,
11621162 o_loop_var ,
@@ -1176,7 +1176,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
11761176 )
11771177
11781178 for pv in loop_state_vars :
1179- ov = self .generate_unique_name (pv )
1179+ ov = self ._generate_unique_name (pv )
11801180 # TODO: retrieve the annotation for variable pv is any is specified.
11811181 # typeinfo = self._eval_constant_expr(pv.annotation)
11821182 typeinfo = None
@@ -1217,7 +1217,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
12171217 continue
12181218 self ._translate_stmt (s )
12191219
1220- o_cond_out = self .generate_unique_name ("cond_out" )
1220+ o_cond_out = self ._generate_unique_name ("cond_out" )
12211221
12221222 if cond_while is not None :
12231223 # Loop while
@@ -1267,7 +1267,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
12671267 info = self ._source_of (loop_stmt )
12681268
12691269 def rename (x ):
1270- r = self .generate_unique_name (x )
1270+ r = self ._generate_unique_name (x )
12711271 self ._bind (x , values .Dynamic (r , values .DynamicKind .Output , info ))
12721272 return r
12731273
0 commit comments