55from __future__ import annotations
66
77import ast
8+ from collections import defaultdict
9+ import dataclasses
810import logging
911from typing import (
1012 TYPE_CHECKING ,
2628import onnxscript
2729from onnxscript import irbuilder , onnx_types , sourceinfo , values
2830from onnxscript import type_annotation as ta
29- from onnxscript ._internal import analysis , ast_utils , autocast , param_manipulation
31+ from onnxscript ._internal import _analysis , ast_utils , autocast , param_manipulation
3032
3133if TYPE_CHECKING :
3234 # The type-alias LocalSymValue represents the types of values that local names in a
@@ -137,6 +139,18 @@ def __str__(self) -> str:
137139 return self .name
138140
139141
142+ @dataclasses .dataclass
143+ class ASTMeta :
144+ """Metadata for an AST node.
145+
146+ This class is used to store metadata about an AST node.
147+ """
148+
149+ # For liveness analysis,
150+ live_out : set [ast .AST ] = dataclasses .field (default_factory = set )
151+ live_in : set [ast .AST ] = dataclasses .field (default_factory = set )
152+
153+
140154class Converter :
141155 """Main class to translate python code into ONNX operators.
142156
@@ -182,8 +196,11 @@ def __init__(
182196 source: Optional source code string for error reporting.
183197 default_opset: The default ONNX opset to use if no ONNX opset is specified in the script.
184198 """
185-
186- self ._root = root
199+ if not isinstance (root , ast .FunctionDef ):
200+ raise TypeError (
201+ f"Converter expects an AST FunctionDef node, got { type (root )} ."
202+ )
203+ self ._ast_root = root
187204 self ._opset = opset
188205
189206 if global_names is not None :
@@ -193,7 +210,12 @@ def __init__(
193210 self ._globals = {}
194211
195212 self ._source = source
196- self ._default_opset = default_opset
213+ self ._default_opset = default_opset or _find_onnx_opset (root , self ._globals )
214+ if self ._default_opset is None :
215+ raise ValueError (
216+ "default_opset must be specified in script for functions "
217+ "that do not contain any use of an ONNX opset."
218+ )
197219
198220 # TODO(justinchuby): Update ir version to be user defined
199221 # TODO(justinchuby): Maybe just store a list of functions
@@ -210,46 +232,8 @@ def __init__(
210232 self ._nextvar : int = 0
211233 self ._used_vars : set [str ] = set ()
212234 self ._locals : list [dict [str , LocalSymValue ]] = [{}]
213-
214- @property
215- def default_opset (self ) -> values .Opset :
216- if self ._default_opset is None :
217- raise RuntimeError (
218- "default_opset must be specified in script for functions "
219- "that do not contain any use of an ONNX opset."
220- )
221- return self ._default_opset
222-
223- def _set_default_opset (self , opset : values .Opset , node : ast .AST ) -> None :
224- if opset .domain != "" :
225- return
226- if self ._default_opset is not None :
227- if (
228- opset .domain != self ._default_opset .domain
229- or opset .version != self ._default_opset .version
230- ):
231- self .fail (
232- node , f"Two distinct opset were used ({ opset } != { self ._default_opset } )."
233- )
234- else :
235- self ._default_opset = opset
236-
237- def _find_onnx_opset (self , node : ast .AST ) -> Optional [values .Opset ]:
238- """Find the (first) ONNX opset used in the function, if any."""
239- # Search for a Call expression of form "op.OpName(...)"
240- if isinstance (node , ast .Call ):
241- if isinstance (node .func , ast .Attribute ):
242- opset_expr = node .func .value
243- if isinstance (opset_expr , ast .Name ):
244- if opset_expr .id in self ._globals :
245- opset = self ._globals [opset_expr .id ]
246- if isinstance (opset , values .Opset ) and opset .domain == "" :
247- return opset
248- for child in ast .iter_child_nodes (node ):
249- res = self ._find_onnx_opset (child )
250- if res is not None :
251- return res
252- return None
235+ self ._finalized = False
236+ self .meta : defaultdict [ast .AST , ASTMeta ] = defaultdict (ASTMeta )
253237
254238 # def _init_function_translation(self) -> None:
255239 # """Initialize self for translating a new (top-level) function."""
@@ -638,7 +622,7 @@ def translate_slice_component(
638622 reshaped = self ._generate_unique_name (f"{ name } _reshaped" )
639623 self .emit (
640624 [reshaped ],
641- values .Op (self .default_opset , "Reshape" ),
625+ values .Op (self ._default_opset , "Reshape" ),
642626 [name , one_1d ().name ],
643627 [],
644628 )
@@ -827,7 +811,7 @@ def _translate_binary_op_expr(self, node: ast.BinOp):
827811 if isinstance (cst , float ):
828812 attr = [self ._make_onnx_attr ("fmod" , 1 )]
829813
830- op = values .Op (self .default_opset , _PRIMOP_MAP [op ])
814+ op = values .Op (self ._default_opset , _PRIMOP_MAP [op ])
831815 left , right = self ._cast_like_binary_expression (
832816 op , self ._translate_expr (node .left ), self ._translate_expr (node .right )
833817 )
@@ -858,7 +842,7 @@ def _translate_unary_op_expr(self, node):
858842 return self ._translate_expr (node .operand )
859843 opname = _PRIMOP_MAP [op ]
860844 operand = self ._translate_expr (node .operand )
861- return values .Op (self .default_opset , opname ), [operand ], []
845+ return values .Op (self ._default_opset , opname ), [operand ], []
862846
863847 def _translate_compare_expr (self , node ):
864848 # TODO: handle multiple comparisons in one expression
@@ -873,12 +857,12 @@ def _translate_compare_expr(self, node):
873857
874858 # NotEqual is not a standard ONNX op, and needs to be translated into
875859 # an Equal op/node followed by a Not op/node.
876- op = values .Op (self .default_opset , opname if opname != "NotEqual" else "Equal" )
860+ op = values .Op (self ._default_opset , opname if opname != "NotEqual" else "Equal" )
877861 left , right = self ._cast_like_binary_expression (op , left , right )
878862 if opname == "NotEqual" :
879863 tmp = self ._generate_unique_name ()
880864 self .emit ([tmp ], op , [left , right ])
881- not_op = values .Op (self .default_opset , "Not" )
865+ not_op = values .Op (self ._default_opset , "Not" )
882866 return not_op , [tmp ], []
883867
884868 return op , [left , right ], []
@@ -918,12 +902,12 @@ def _translate_callee_expr(self, node: ast.AST) -> values.Op: # pylint: disable
918902 if isinstance (found , values .Op ):
919903 return found
920904 if not found :
921- if function_name not in self .default_opset :
905+ if function_name not in self ._default_opset :
922906 warn (
923907 f"Unknown function name { function_name !r} . "
924908 f"The ONNX graph may not work."
925909 )
926- return values .Op (self .default_opset , function_name )
910+ return values .Op (self ._default_opset , function_name )
927911 self .fail (node , "Invalid callee" )
928912
929913 def _translate_stmt (self , node : ast .stmt , index_of_stmt = None ) -> None :
@@ -1062,10 +1046,10 @@ def ret(exp, i, suffix):
10621046 def _translate_if_stmt (self , stmt : ast .If ) -> None :
10631047 if hasattr (stmt , "live_out" ):
10641048 live_defs = list (
1065- stmt .live_out .intersection (analysis .assigned_vars (stmt , self ._message ))
1049+ stmt .live_out .intersection (_analysis .assigned_vars (stmt , self ._message ))
10661050 )
10671051 else :
1068- live_defs = list (analysis .assigned_vars (stmt , self ._message ))
1052+ live_defs = list (_analysis .assigned_vars (stmt , self ._message ))
10691053 test = self ._translate_expr (stmt .test , "cond" ).name
10701054 lineno = self ._source_of (stmt ).lineno
10711055 thenGraph , sub_fct_then = self ._translate_block (
@@ -1097,7 +1081,7 @@ def rename(x):
10971081 self .fail (stmt , f"Input and output cannot be the same { renamed !r} ." )
10981082 self .emit (
10991083 renamed ,
1100- values .Op (self .default_opset , "If" ),
1084+ values .Op (self ._default_opset , "If" ),
11011085 [test ],
11021086 [thenAttr , elseAttr ],
11031087 sub_functions = sub_functions ,
@@ -1145,8 +1129,8 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
11451129 else :
11461130 self .fail (loop_stmt , f"Unexpected loop type { type (loop_stmt )!r} ." )
11471131 # analyze loop body
1148- exposed_uses = analysis .exposed_uses (loop_stmt .body , self ._message )
1149- vars_def_in_loop = analysis .assigned_vars (loop_stmt .body , self ._message )
1132+ exposed_uses = _analysis .exposed_uses (loop_stmt .body , self ._message )
1133+ vars_def_in_loop = _analysis .assigned_vars (loop_stmt .body , self ._message )
11501134 loop_state_vars = vars_def_in_loop .intersection (exposed_uses | loop_stmt .live_out )
11511135 scan_outputs = set () # TODO
11521136 outputs = list (loop_state_vars | scan_outputs )
@@ -1232,7 +1216,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
12321216
12331217 self .emit (
12341218 [o_cond_out ],
1235- values .Op (self .default_opset , operator_name ),
1219+ values .Op (self ._default_opset , operator_name ),
12361220 [condition_name or o_cond_var ],
12371221 [],
12381222 )
@@ -1333,7 +1317,7 @@ def _translate_nested_function_def(self, fn: ast.FunctionDef) -> None:
13331317 self ._enter_scope (fn .name , fn )
13341318 self ._translate_function_def (fn )
13351319 function_ir = self ._exit_scope ()
1336- outer_scope_vars = analysis .outer_scope_variables (fn , self ._message )
1320+ outer_scope_vars = _analysis .outer_scope_variables (fn , self ._message )
13371321 function_ir .outer_scope_variables = [
13381322 (var , self ._lookup (var , self ._source_of (fn ))) for var in outer_scope_vars
13391323 ]
@@ -1343,7 +1327,7 @@ def _translate_nested_function_def(self, fn: ast.FunctionDef) -> None:
13431327
13441328 def _translate_function_signature_common (
13451329 self , fn : ast .FunctionDef
1346- ) -> irbuilder . IRFunction :
1330+ ) -> ir . Function :
13471331 """Translate a function signature (top-level or nested)."""
13481332 args = fn .args
13491333 if args .vararg or args .kwonlyargs or args .kw_defaults or args .kwarg :
@@ -1414,28 +1398,19 @@ def _translate_function_def(self, node: ast.FunctionDef) -> ir.Function:
14141398 self ._current_fn .doc_string = docstring
14151399 return self ._current_fn
14161400
1417- def translate_function_def (self , stmt : ast .FunctionDef ) -> irbuilder .IRFunction :
1418- if isinstance (stmt , ast .FunctionDef ):
1419- self ._init_function_translation ()
1420- if self ._default_opset is None :
1421- opset = self ._find_onnx_opset (stmt )
1422- if opset :
1423- self ._set_default_opset (opset , stmt )
1424- domain = self ._opset .domain
1425- self ._current_fn = self .ir_builder .new_function (stmt .name , domain , True )
1426- analysis .do_liveness_analysis (stmt , self ._message )
1427- fn_ir = self ._translate_function_def (stmt )
1428- fn_ir .debug_print ()
1429- self ._opset .add_function_def (fn_ir )
1430- return fn_ir
1431- raise ValueError (f"Unsupported top-level statement type { type (stmt )!r} ." )
1432-
1433- def translate_function_signature (self , fn : ast .FunctionDef ) -> irbuilder .IRFunction :
1434- """Translate a (top-level) function signature."""
1435- assert self ._opset is not None
1436- domain = self ._opset .domain
1437- self ._current_fn = self .ir_builder .new_function (fn .name , domain , True )
1438- return self ._translate_function_signature_common (fn )
1401+ def _finalize (self ) -> None :
1402+ self ._finalized = True
1403+
1404+ def convert (self ) -> ir .Function :
1405+ """Convert the Python AST to an ONNX IR function."""
1406+ if self ._finalized :
1407+ return self ._current_fn
1408+
1409+ func_def = self ._ast_root
1410+ _analysis .do_liveness_analysis (func_def , self ._message , self .meta )
1411+ return self ._translate_function_def (func_def )
1412+ # TODO(justinchuby): Handle function registration to the opset
1413+ # self._opset.add_function_def(fn_ir)
14391414
14401415
14411416def _is_constant_expr (node : ast .AST ) -> bool :
@@ -1561,3 +1536,21 @@ def _to_onnx_ref_attr(val: values.AttrRef, info: sourceinfo.SourceInfo | None) -
15611536 fail (info .msg (msg ) if info else msg )
15621537 # TODO(justinchuby): What is the ref attr name?
15631538 return ir .RefAttr (attrname , val .value , attrtype )
1539+
1540+
1541+ def _find_onnx_opset (node : ast .AST , globals : dict [str , Any ]) -> values .Opset | None :
1542+ """Find the (first) ONNX opset used in the function, if any."""
1543+ # Search for a Call expression of form "op.OpName(...)"
1544+ if isinstance (node , ast .Call ):
1545+ if isinstance (node .func , ast .Attribute ):
1546+ opset_expr = node .func .value
1547+ if isinstance (opset_expr , ast .Name ):
1548+ if opset_expr .id in globals :
1549+ opset = globals [opset_expr .id ]
1550+ if isinstance (opset , values .Opset ) and opset .domain == "" :
1551+ return opset
1552+ for child in ast .iter_child_nodes (node ):
1553+ res = _find_onnx_opset (child , globals )
1554+ if res is not None :
1555+ return res
1556+ return None
0 commit comments