@@ -140,32 +140,41 @@ def __str__(self) -> str:
140140class Converter :
141141 """Main class to translate python code into ONNX operators.
142142
143- The class uses logger `onnxscript`. Logging can be enabled with the following code:
143+ The converter translates a Python function into an ONNX function by
144+ traversing the Python AST of the function and generating ONNX nodes
145+ that represent the operations in the Python code.
144146
145- ::
147+ ..tip ::
146148
147- import logging
148- logging.basicConfig(level=logging.DEBUG)
149+ The class uses logger `onnxscript`. Logging can be enabled with the following code:
149150
150- Or if you need to enable only the logger used by this module:
151+ ::
152+
153+ import logging
154+ logging.basicConfig(level=logging.DEBUG)
155+
156+ Or if you need to enable only the logger used by this module:
151157
152- ::
158+ ::
153159
154- import logging
155- logger = logging.getLogger('onnxscript')
156- logger.setLevel(logging.DEBUG)
157- console = logging.StreamHandler()
158- logger.addHandler(console)
160+ import logging
161+ logger = logging.getLogger('onnxscript')
162+ logger.setLevel(logging.DEBUG)
163+ console = logging.StreamHandler()
164+ logger.addHandler(console)
159165 """
160166
161167 def __init__ (
162168 self ,
169+ root : ast .FunctionDef ,
163170 opset : Optional [values .Opset ] = None ,
164171 global_names : Optional [dict [str , Any ]] = None ,
165172 source : Optional [str ] = None ,
166173 default_opset : Optional [values .Opset ] = None ,
167174 ):
168175 self ._source = source
176+ self ._root = root
177+
169178 if global_names is not None :
170179 # We make a copy in case function eval modifies it.
171180 self ._globals = global_names .copy ()
@@ -313,18 +322,6 @@ def generate_unique_name(self, candidate: str = "tmp") -> str:
313322 self ._used_vars .add (r )
314323 return r
315324
316- # def _make_onnx_attr(
317- # self, attrname: str, attrval: Any, attrtype: int | None = None
318- # ) -> irbuilder.IRAttributeValue:
319- # def tensor_name_generator() -> str:
320- # """Return name to be used for tensor, if we need to create one."""
321- # return self.generate_unique_name(f"attr_{attrname}")
322-
323- # proto = autocast.pyvalue_to_onnx_attribute(
324- # attrname, attrval, tensor_name_generator, attrtype
325- # )
326- # return self.ir_builder.make_attr(proto)
327-
328325 def _to_onnx_var (
329326 self ,
330327 val : values .SymbolValue | PyValue ,
@@ -497,10 +494,6 @@ def _translate_attr(
497494 )
498495 return attr
499496
500- def _translate_docstring (self , node : ast .FunctionDef ) -> None :
501- if docstring := ast .get_docstring (node ):
502- self ._current_fn .doc_string = docstring
503-
504497 def _translate_expr (
505498 self , node : ast .AST , target : Optional [PreferredName ] = None
506499 ) -> Variable :
@@ -1323,7 +1316,7 @@ def _translate_block(
13231316 def _translate_nested_function_def (self , fn : ast .FunctionDef ) -> None :
13241317 """Translate a nested function definition."""
13251318 self ._enter_scope (fn .name , fn )
1326- self ._translate_function_def_common (fn )
1319+ self ._translate_function_def (fn )
13271320 function_ir = self ._exit_scope ()
13281321 outer_scope_vars = analysis .outer_scope_variables (fn , self ._message )
13291322 function_ir .outer_scope_variables = [
@@ -1394,16 +1387,16 @@ def _translate_function_signature_common(
13941387
13951388 return self ._current_fn
13961389
1397- def _translate_function_def_common (self , node : ast .FunctionDef ) -> ir .Function :
1390+ def _translate_function_def (self , node : ast .FunctionDef ) -> ir .Function :
13981391 """Translate a function definition, including the signature and its body."""
1399- logger .debug ("Converter:_translate_function_def_common :%s" , node .name )
1392+ logger .debug ("Converter:_translate_function_def :%s" , node .name )
14001393 _ = self ._translate_function_signature_common (node )
14011394 for i , s in enumerate (node .body ):
14021395 self ._translate_stmt (s , index_of_stmt = i )
14031396
14041397 # Update docstring if available
14051398 if docstring := ast .get_docstring (node ):
1406- self ._current_fn .doc_string = docstring
1399+ self ._current_fn .doc_string = docstring
14071400 return self ._current_fn
14081401
14091402 def translate_function_def (self , stmt : ast .FunctionDef ) -> irbuilder .IRFunction :
@@ -1416,14 +1409,15 @@ def translate_function_def(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction:
14161409 domain = self ._this_module .domain
14171410 self ._current_fn = self .ir_builder .new_function (stmt .name , domain , True )
14181411 analysis .do_liveness_analysis (stmt , self ._message )
1419- fn_ir = self ._translate_function_def_common (stmt )
1412+ fn_ir = self ._translate_function_def (stmt )
14201413 fn_ir .debug_print ()
14211414 self ._this_module .add_function_def (fn_ir )
14221415 return fn_ir
14231416 raise ValueError (f"Unsupported top-level statement type { type (stmt )!r} ." )
14241417
14251418 def translate_function_signature (self , fn : ast .FunctionDef ) -> irbuilder .IRFunction :
14261419 """Translate a (top-level) function signature."""
1420+ assert self ._this_module is not None
14271421 domain = self ._this_module .domain
14281422 self ._current_fn = self .ir_builder .new_function (fn .name , domain , True )
14291423 return self ._translate_function_signature_common (fn )
@@ -1534,6 +1528,8 @@ def _separate_inputs_and_attrs(
15341528
15351529def _to_onnx_ref_attr (val : values .AttrRef , info : sourceinfo .SourceInfo | None ) -> ir .Attr :
15361530 """Convert an attribute reference to an ONNX ref attribute."""
1531+
1532+ # TODO(justinchuby): Consider using a convenience function
15371533 pytype = val .typeinfo
15381534 attrtype = _schemas .get_attr_type (pytype )
15391535 attrname = None
0 commit comments