Skip to content

Commit 882af66

Browse files
committed
wip
Signed-off-by: Justin Chu <[email protected]>
1 parent 761451a commit 882af66

File tree

1 file changed

+28
-32
lines changed

1 file changed

+28
-32
lines changed

onnxscript/_converter.py

Lines changed: 28 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -140,32 +140,41 @@ def __str__(self) -> str:
140140
class Converter:
141141
"""Main class to translate python code into ONNX operators.
142142
143-
The class uses logger `onnxscript`. Logging can be enabled with the following code:
143+
The converter translates a Python function into an ONNX function by
144+
traversing the Python AST of the function and generating ONNX nodes
145+
that represent the operations in the Python code.
144146
145-
::
147+
..tip::
146148
147-
import logging
148-
logging.basicConfig(level=logging.DEBUG)
149+
The class uses logger `onnxscript`. Logging can be enabled with the following code:
149150
150-
Or if you need to enable only the logger used by this module:
151+
::
152+
153+
import logging
154+
logging.basicConfig(level=logging.DEBUG)
155+
156+
Or if you need to enable only the logger used by this module:
151157
152-
::
158+
::
153159
154-
import logging
155-
logger = logging.getLogger('onnxscript')
156-
logger.setLevel(logging.DEBUG)
157-
console = logging.StreamHandler()
158-
logger.addHandler(console)
160+
import logging
161+
logger = logging.getLogger('onnxscript')
162+
logger.setLevel(logging.DEBUG)
163+
console = logging.StreamHandler()
164+
logger.addHandler(console)
159165
"""
160166

161167
def __init__(
162168
self,
169+
root: ast.FunctionDef,
163170
opset: Optional[values.Opset] = None,
164171
global_names: Optional[dict[str, Any]] = None,
165172
source: Optional[str] = None,
166173
default_opset: Optional[values.Opset] = None,
167174
):
168175
self._source = source
176+
self._root = root
177+
169178
if global_names is not None:
170179
# We make a copy in case function eval modifies it.
171180
self._globals = global_names.copy()
@@ -313,18 +322,6 @@ def generate_unique_name(self, candidate: str = "tmp") -> str:
313322
self._used_vars.add(r)
314323
return r
315324

316-
# def _make_onnx_attr(
317-
# self, attrname: str, attrval: Any, attrtype: int | None = None
318-
# ) -> irbuilder.IRAttributeValue:
319-
# def tensor_name_generator() -> str:
320-
# """Return name to be used for tensor, if we need to create one."""
321-
# return self.generate_unique_name(f"attr_{attrname}")
322-
323-
# proto = autocast.pyvalue_to_onnx_attribute(
324-
# attrname, attrval, tensor_name_generator, attrtype
325-
# )
326-
# return self.ir_builder.make_attr(proto)
327-
328325
def _to_onnx_var(
329326
self,
330327
val: values.SymbolValue | PyValue,
@@ -497,10 +494,6 @@ def _translate_attr(
497494
)
498495
return attr
499496

500-
def _translate_docstring(self, node: ast.FunctionDef) -> None:
501-
if docstring := ast.get_docstring(node):
502-
self._current_fn.doc_string = docstring
503-
504497
def _translate_expr(
505498
self, node: ast.AST, target: Optional[PreferredName] = None
506499
) -> Variable:
@@ -1323,7 +1316,7 @@ def _translate_block(
13231316
def _translate_nested_function_def(self, fn: ast.FunctionDef) -> None:
13241317
"""Translate a nested function definition."""
13251318
self._enter_scope(fn.name, fn)
1326-
self._translate_function_def_common(fn)
1319+
self._translate_function_def(fn)
13271320
function_ir = self._exit_scope()
13281321
outer_scope_vars = analysis.outer_scope_variables(fn, self._message)
13291322
function_ir.outer_scope_variables = [
@@ -1394,16 +1387,16 @@ def _translate_function_signature_common(
13941387

13951388
return self._current_fn
13961389

1397-
def _translate_function_def_common(self, node: ast.FunctionDef) -> ir.Function:
1390+
def _translate_function_def(self, node: ast.FunctionDef) -> ir.Function:
13981391
"""Translate a function definition, including the signature and its body."""
1399-
logger.debug("Converter:_translate_function_def_common:%s", node.name)
1392+
logger.debug("Converter:_translate_function_def:%s", node.name)
14001393
_ = self._translate_function_signature_common(node)
14011394
for i, s in enumerate(node.body):
14021395
self._translate_stmt(s, index_of_stmt=i)
14031396

14041397
# Update docstring if available
14051398
if docstring := ast.get_docstring(node):
1406-
self._current_fn.doc_string = docstring
1399+
self._current_fn.doc_string = docstring
14071400
return self._current_fn
14081401

14091402
def translate_function_def(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction:
@@ -1416,14 +1409,15 @@ def translate_function_def(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction:
14161409
domain = self._this_module.domain
14171410
self._current_fn = self.ir_builder.new_function(stmt.name, domain, True)
14181411
analysis.do_liveness_analysis(stmt, self._message)
1419-
fn_ir = self._translate_function_def_common(stmt)
1412+
fn_ir = self._translate_function_def(stmt)
14201413
fn_ir.debug_print()
14211414
self._this_module.add_function_def(fn_ir)
14221415
return fn_ir
14231416
raise ValueError(f"Unsupported top-level statement type {type(stmt)!r}.")
14241417

14251418
def translate_function_signature(self, fn: ast.FunctionDef) -> irbuilder.IRFunction:
14261419
"""Translate a (top-level) function signature."""
1420+
assert self._this_module is not None
14271421
domain = self._this_module.domain
14281422
self._current_fn = self.ir_builder.new_function(fn.name, domain, True)
14291423
return self._translate_function_signature_common(fn)
@@ -1534,6 +1528,8 @@ def _separate_inputs_and_attrs(
15341528

15351529
def _to_onnx_ref_attr(val: values.AttrRef, info: sourceinfo.SourceInfo | None) -> ir.Attr:
15361530
"""Convert an attribute reference to an ONNX ref attribute."""
1531+
1532+
# TODO(justinchuby): Consider using a convenience function
15371533
pytype = val.typeinfo
15381534
attrtype = _schemas.get_attr_type(pytype)
15391535
attrname = None

0 commit comments

Comments
 (0)