@@ -167,27 +167,42 @@ class Converter:
167167 def __init__ (
168168 self ,
169169 root : ast .FunctionDef ,
170+ * ,
170171 opset : Optional [values .Opset ] = None ,
171172 global_names : Optional [dict [str , Any ]] = None ,
172173 source : Optional [str ] = None ,
173174 default_opset : Optional [values .Opset ] = None ,
174175 ):
175- self ._source = source
176+ """Initialize the converter.
177+
178+ Args:
179+ root: The root AST node of the function to be converted.
180+ opset: The ONNX opset to use for the conversion. If None, the default opset is used.
181+ global_names: A dictionary of global names available in the script.
182+ source: Optional source code string for error reporting.
183+ default_opset: The default ONNX opset to use if no ONNX opset is specified in the script.
184+ """
185+
176186 self ._root = root
187+ self ._opset = opset
177188
178189 if global_names is not None :
179190 # We make a copy in case function eval modifies it.
180191 self ._globals = global_names .copy ()
181- self ._this_module = opset
192+ else :
193+ self ._globals = {}
194+
195+ self ._source = source
182196 self ._default_opset = default_opset
183197
184198 # TODO(justinchuby): Update ir version to be user defined
199+ # TODO(justinchuby): Maybe just store a list of functions
185200 self ._model = ir .Model (ir .Graph ((), (), nodes = ()), ir_version = 10 )
186201
187202 # A stack of functions in the outer scope
188203 self ._outer : list [ir .Function ] = []
189204 self ._current_fn : ir .Function = ir .Function (
190- domain = self ._this_module .domain ,
205+ domain = self ._opset .domain ,
191206 name = "" ,
192207 graph = ir .Graph ((), (), nodes = []),
193208 attributes = {},
@@ -241,7 +256,7 @@ def _init_function_translation(self) -> None:
241256 self ._outer = []
242257 # TODO(justinchuby): Update this
243258 self ._current_fn = ir .Function (
244- domain = self ._this_module .domain ,
259+ domain = self ._opset .domain ,
245260 name = "" ,
246261 graph = ir .Graph ((), (), nodes = []),
247262 attributes = {},
@@ -275,9 +290,9 @@ def _enter_scope(self, name: str, parent_node: ast.AST):
275290 The block is translated into a nested-scope in ONNX.
276291 """
277292 self ._outer .append (self ._current_fn )
278- assert self ._this_module is not None
293+ assert self ._opset is not None
279294 self ._current_fn = ir .Function (
280- domain = self ._this_module .domain ,
295+ domain = self ._opset .domain ,
281296 name = name ,
282297 graph = ir .Graph ((), (), nodes = []),
283298 attributes = {},
@@ -1406,19 +1421,19 @@ def translate_function_def(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction:
14061421 opset = self ._find_onnx_opset (stmt )
14071422 if opset :
14081423 self ._set_default_opset (opset , stmt )
1409- domain = self ._this_module .domain
1424+ domain = self ._opset .domain
14101425 self ._current_fn = self .ir_builder .new_function (stmt .name , domain , True )
14111426 analysis .do_liveness_analysis (stmt , self ._message )
14121427 fn_ir = self ._translate_function_def (stmt )
14131428 fn_ir .debug_print ()
1414- self ._this_module .add_function_def (fn_ir )
1429+ self ._opset .add_function_def (fn_ir )
14151430 return fn_ir
14161431 raise ValueError (f"Unsupported top-level statement type { type (stmt )!r} ." )
14171432
14181433 def translate_function_signature (self , fn : ast .FunctionDef ) -> irbuilder .IRFunction :
14191434 """Translate a (top-level) function signature."""
1420- assert self ._this_module is not None
1421- domain = self ._this_module .domain
1435+ assert self ._opset is not None
1436+ domain = self ._opset .domain
14221437 self ._current_fn = self .ir_builder .new_function (fn .name , domain , True )
14231438 return self ._translate_function_signature_common (fn )
14241439
0 commit comments