Skip to content

Commit 852cc42

Browse files
committed
update
Signed-off-by: Justin Chu <[email protected]>
1 parent 882af66 commit 852cc42

File tree

1 file changed

+25
-10
lines changed

1 file changed

+25
-10
lines changed

onnxscript/_converter.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -167,27 +167,42 @@ class Converter:
167167
def __init__(
168168
self,
169169
root: ast.FunctionDef,
170+
*,
170171
opset: Optional[values.Opset] = None,
171172
global_names: Optional[dict[str, Any]] = None,
172173
source: Optional[str] = None,
173174
default_opset: Optional[values.Opset] = None,
174175
):
175-
self._source = source
176+
"""Initialize the converter.
177+
178+
Args:
179+
root: The root AST node of the function to be converted.
180+
opset: The ONNX opset to use for the conversion. If None, the default opset is used.
181+
global_names: A dictionary of global names available in the script.
182+
source: Optional source code string for error reporting.
183+
default_opset: The default ONNX opset to use if no ONNX opset is specified in the script.
184+
"""
185+
176186
self._root = root
187+
self._opset = opset
177188

178189
if global_names is not None:
179190
# We make a copy in case function eval modifies it.
180191
self._globals = global_names.copy()
181-
self._this_module = opset
192+
else:
193+
self._globals = {}
194+
195+
self._source = source
182196
self._default_opset = default_opset
183197

184198
# TODO(justinchuby): Update ir version to be user defined
199+
# TODO(justinchuby): Maybe just store a list of functions
185200
self._model = ir.Model(ir.Graph((), (), nodes=()), ir_version=10)
186201

187202
# A stack of functions in the outer scope
188203
self._outer: list[ir.Function] = []
189204
self._current_fn: ir.Function = ir.Function(
190-
domain=self._this_module.domain,
205+
domain=self._opset.domain,
191206
name="",
192207
graph=ir.Graph((), (), nodes=[]),
193208
attributes={},
@@ -241,7 +256,7 @@ def _init_function_translation(self) -> None:
241256
self._outer = []
242257
# TODO(justinchuby): Update this
243258
self._current_fn = ir.Function(
244-
domain=self._this_module.domain,
259+
domain=self._opset.domain,
245260
name="",
246261
graph=ir.Graph((), (), nodes=[]),
247262
attributes={},
@@ -275,9 +290,9 @@ def _enter_scope(self, name: str, parent_node: ast.AST):
275290
The block is translated into a nested-scope in ONNX.
276291
"""
277292
self._outer.append(self._current_fn)
278-
assert self._this_module is not None
293+
assert self._opset is not None
279294
self._current_fn = ir.Function(
280-
domain=self._this_module.domain,
295+
domain=self._opset.domain,
281296
name=name,
282297
graph=ir.Graph((), (), nodes=[]),
283298
attributes={},
@@ -1406,19 +1421,19 @@ def translate_function_def(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction:
14061421
opset = self._find_onnx_opset(stmt)
14071422
if opset:
14081423
self._set_default_opset(opset, stmt)
1409-
domain = self._this_module.domain
1424+
domain = self._opset.domain
14101425
self._current_fn = self.ir_builder.new_function(stmt.name, domain, True)
14111426
analysis.do_liveness_analysis(stmt, self._message)
14121427
fn_ir = self._translate_function_def(stmt)
14131428
fn_ir.debug_print()
1414-
self._this_module.add_function_def(fn_ir)
1429+
self._opset.add_function_def(fn_ir)
14151430
return fn_ir
14161431
raise ValueError(f"Unsupported top-level statement type {type(stmt)!r}.")
14171432

14181433
def translate_function_signature(self, fn: ast.FunctionDef) -> irbuilder.IRFunction:
14191434
"""Translate a (top-level) function signature."""
1420-
assert self._this_module is not None
1421-
domain = self._this_module.domain
1435+
assert self._opset is not None
1436+
domain = self._opset.domain
14221437
self._current_fn = self.ir_builder.new_function(fn.name, domain, True)
14231438
return self._translate_function_signature_common(fn)
14241439

0 commit comments

Comments
 (0)