Skip to content

Commit 98754ab

Browse files
committed
Update analysis pass
Signed-off-by: Justin Chu <[email protected]>
1 parent be610e9 commit 98754ab

File tree

3 files changed

+105
-101
lines changed

3 files changed

+105
-101
lines changed

onnxscript/_converter.py

Lines changed: 74 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from __future__ import annotations
66

77
import ast
8+
from collections import defaultdict
9+
import dataclasses
810
import logging
911
from typing import (
1012
TYPE_CHECKING,
@@ -26,7 +28,7 @@
2628
import onnxscript
2729
from onnxscript import irbuilder, onnx_types, sourceinfo, values
2830
from onnxscript import type_annotation as ta
29-
from onnxscript._internal import analysis, ast_utils, autocast, param_manipulation
31+
from onnxscript._internal import _analysis, ast_utils, autocast, param_manipulation
3032

3133
if TYPE_CHECKING:
3234
# The type-alias LocalSymValue represents the types of values that local names in a
@@ -137,6 +139,18 @@ def __str__(self) -> str:
137139
return self.name
138140

139141

142+
@dataclasses.dataclass
class ASTMeta:
    """Metadata recorded for a single AST node.

    Instances are created on demand by the converter's ``meta`` mapping
    (a ``defaultdict`` keyed by AST node) and populated by the liveness
    analysis pass.
    """

    # Liveness analysis results. Both sets hold variable *names* (the
    # analysis computes ``set[str]``), not AST nodes:
    #   live_out -- variables live at the exit of the node
    #   live_in  -- variables live at the entry of the node
    live_out: set[str] = dataclasses.field(default_factory=set)
    live_in: set[str] = dataclasses.field(default_factory=set)
152+
153+
140154
class Converter:
141155
"""Main class to translate python code into ONNX operators.
142156
@@ -182,8 +196,11 @@ def __init__(
182196
source: Optional source code string for error reporting.
183197
default_opset: The default ONNX opset to use if no ONNX opset is specified in the script.
184198
"""
185-
186-
self._root = root
199+
if not isinstance(root, ast.FunctionDef):
200+
raise TypeError(
201+
f"Converter expects an AST FunctionDef node, got {type(root)}."
202+
)
203+
self._ast_root = root
187204
self._opset = opset
188205

189206
if global_names is not None:
@@ -193,7 +210,12 @@ def __init__(
193210
self._globals = {}
194211

195212
self._source = source
196-
self._default_opset = default_opset
213+
self._default_opset = default_opset or _find_onnx_opset(root, self._globals)
214+
if self._default_opset is None:
215+
raise ValueError(
216+
"default_opset must be specified in script for functions "
217+
"that do not contain any use of an ONNX opset."
218+
)
197219

198220
# TODO(justinchuby): Update ir version to be user defined
199221
# TODO(justinchuby): Maybe just store a list of functions
@@ -210,46 +232,8 @@ def __init__(
210232
self._nextvar: int = 0
211233
self._used_vars: set[str] = set()
212234
self._locals: list[dict[str, LocalSymValue]] = [{}]
213-
214-
@property
215-
def default_opset(self) -> values.Opset:
216-
if self._default_opset is None:
217-
raise RuntimeError(
218-
"default_opset must be specified in script for functions "
219-
"that do not contain any use of an ONNX opset."
220-
)
221-
return self._default_opset
222-
223-
def _set_default_opset(self, opset: values.Opset, node: ast.AST) -> None:
224-
if opset.domain != "":
225-
return
226-
if self._default_opset is not None:
227-
if (
228-
opset.domain != self._default_opset.domain
229-
or opset.version != self._default_opset.version
230-
):
231-
self.fail(
232-
node, f"Two distinct opset were used ({opset} != {self._default_opset})."
233-
)
234-
else:
235-
self._default_opset = opset
236-
237-
def _find_onnx_opset(self, node: ast.AST) -> Optional[values.Opset]:
238-
"""Find the (first) ONNX opset used in the function, if any."""
239-
# Search for a Call expression of form "op.OpName(...)"
240-
if isinstance(node, ast.Call):
241-
if isinstance(node.func, ast.Attribute):
242-
opset_expr = node.func.value
243-
if isinstance(opset_expr, ast.Name):
244-
if opset_expr.id in self._globals:
245-
opset = self._globals[opset_expr.id]
246-
if isinstance(opset, values.Opset) and opset.domain == "":
247-
return opset
248-
for child in ast.iter_child_nodes(node):
249-
res = self._find_onnx_opset(child)
250-
if res is not None:
251-
return res
252-
return None
235+
self._finalized = False
236+
self.meta: defaultdict[ast.AST, ASTMeta] = defaultdict(ASTMeta)
253237

254238
# def _init_function_translation(self) -> None:
255239
# """Initialize self for translating a new (top-level) function."""
@@ -638,7 +622,7 @@ def translate_slice_component(
638622
reshaped = self._generate_unique_name(f"{name}_reshaped")
639623
self.emit(
640624
[reshaped],
641-
values.Op(self.default_opset, "Reshape"),
625+
values.Op(self._default_opset, "Reshape"),
642626
[name, one_1d().name],
643627
[],
644628
)
@@ -827,7 +811,7 @@ def _translate_binary_op_expr(self, node: ast.BinOp):
827811
if isinstance(cst, float):
828812
attr = [self._make_onnx_attr("fmod", 1)]
829813

830-
op = values.Op(self.default_opset, _PRIMOP_MAP[op])
814+
op = values.Op(self._default_opset, _PRIMOP_MAP[op])
831815
left, right = self._cast_like_binary_expression(
832816
op, self._translate_expr(node.left), self._translate_expr(node.right)
833817
)
@@ -858,7 +842,7 @@ def _translate_unary_op_expr(self, node):
858842
return self._translate_expr(node.operand)
859843
opname = _PRIMOP_MAP[op]
860844
operand = self._translate_expr(node.operand)
861-
return values.Op(self.default_opset, opname), [operand], []
845+
return values.Op(self._default_opset, opname), [operand], []
862846

863847
def _translate_compare_expr(self, node):
864848
# TODO: handle multiple comparisons in one expression
@@ -873,12 +857,12 @@ def _translate_compare_expr(self, node):
873857

874858
# NotEqual is not a standard ONNX op, and needs to be translated into
875859
# an Equal op/node followed by a Not op/node.
876-
op = values.Op(self.default_opset, opname if opname != "NotEqual" else "Equal")
860+
op = values.Op(self._default_opset, opname if opname != "NotEqual" else "Equal")
877861
left, right = self._cast_like_binary_expression(op, left, right)
878862
if opname == "NotEqual":
879863
tmp = self._generate_unique_name()
880864
self.emit([tmp], op, [left, right])
881-
not_op = values.Op(self.default_opset, "Not")
865+
not_op = values.Op(self._default_opset, "Not")
882866
return not_op, [tmp], []
883867

884868
return op, [left, right], []
@@ -918,12 +902,12 @@ def _translate_callee_expr(self, node: ast.AST) -> values.Op: # pylint: disable
918902
if isinstance(found, values.Op):
919903
return found
920904
if not found:
921-
if function_name not in self.default_opset:
905+
if function_name not in self._default_opset:
922906
warn(
923907
f"Unknown function name {function_name!r}. "
924908
f"The ONNX graph may not work."
925909
)
926-
return values.Op(self.default_opset, function_name)
910+
return values.Op(self._default_opset, function_name)
927911
self.fail(node, "Invalid callee")
928912

929913
def _translate_stmt(self, node: ast.stmt, index_of_stmt=None) -> None:
@@ -1062,10 +1046,10 @@ def ret(exp, i, suffix):
10621046
def _translate_if_stmt(self, stmt: ast.If) -> None:
10631047
if hasattr(stmt, "live_out"):
10641048
live_defs = list(
1065-
stmt.live_out.intersection(analysis.assigned_vars(stmt, self._message))
1049+
stmt.live_out.intersection(_analysis.assigned_vars(stmt, self._message))
10661050
)
10671051
else:
1068-
live_defs = list(analysis.assigned_vars(stmt, self._message))
1052+
live_defs = list(_analysis.assigned_vars(stmt, self._message))
10691053
test = self._translate_expr(stmt.test, "cond").name
10701054
lineno = self._source_of(stmt).lineno
10711055
thenGraph, sub_fct_then = self._translate_block(
@@ -1097,7 +1081,7 @@ def rename(x):
10971081
self.fail(stmt, f"Input and output cannot be the same {renamed!r}.")
10981082
self.emit(
10991083
renamed,
1100-
values.Op(self.default_opset, "If"),
1084+
values.Op(self._default_opset, "If"),
11011085
[test],
11021086
[thenAttr, elseAttr],
11031087
sub_functions=sub_functions,
@@ -1145,8 +1129,8 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
11451129
else:
11461130
self.fail(loop_stmt, f"Unexpected loop type {type(loop_stmt)!r}.")
11471131
# analyze loop body
1148-
exposed_uses = analysis.exposed_uses(loop_stmt.body, self._message)
1149-
vars_def_in_loop = analysis.assigned_vars(loop_stmt.body, self._message)
1132+
exposed_uses = _analysis.exposed_uses(loop_stmt.body, self._message)
1133+
vars_def_in_loop = _analysis.assigned_vars(loop_stmt.body, self._message)
11501134
loop_state_vars = vars_def_in_loop.intersection(exposed_uses | loop_stmt.live_out)
11511135
scan_outputs = set() # TODO
11521136
outputs = list(loop_state_vars | scan_outputs)
@@ -1232,7 +1216,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
12321216

12331217
self.emit(
12341218
[o_cond_out],
1235-
values.Op(self.default_opset, operator_name),
1219+
values.Op(self._default_opset, operator_name),
12361220
[condition_name or o_cond_var],
12371221
[],
12381222
)
@@ -1333,7 +1317,7 @@ def _translate_nested_function_def(self, fn: ast.FunctionDef) -> None:
13331317
self._enter_scope(fn.name, fn)
13341318
self._translate_function_def(fn)
13351319
function_ir = self._exit_scope()
1336-
outer_scope_vars = analysis.outer_scope_variables(fn, self._message)
1320+
outer_scope_vars = _analysis.outer_scope_variables(fn, self._message)
13371321
function_ir.outer_scope_variables = [
13381322
(var, self._lookup(var, self._source_of(fn))) for var in outer_scope_vars
13391323
]
@@ -1343,7 +1327,7 @@ def _translate_nested_function_def(self, fn: ast.FunctionDef) -> None:
13431327

13441328
def _translate_function_signature_common(
13451329
self, fn: ast.FunctionDef
1346-
) -> irbuilder.IRFunction:
1330+
) -> ir.Function:
13471331
"""Translate a function signature (top-level or nested)."""
13481332
args = fn.args
13491333
if args.vararg or args.kwonlyargs or args.kw_defaults or args.kwarg:
@@ -1414,28 +1398,19 @@ def _translate_function_def(self, node: ast.FunctionDef) -> ir.Function:
14141398
self._current_fn.doc_string = docstring
14151399
return self._current_fn
14161400

1417-
def translate_function_def(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction:
1418-
if isinstance(stmt, ast.FunctionDef):
1419-
self._init_function_translation()
1420-
if self._default_opset is None:
1421-
opset = self._find_onnx_opset(stmt)
1422-
if opset:
1423-
self._set_default_opset(opset, stmt)
1424-
domain = self._opset.domain
1425-
self._current_fn = self.ir_builder.new_function(stmt.name, domain, True)
1426-
analysis.do_liveness_analysis(stmt, self._message)
1427-
fn_ir = self._translate_function_def(stmt)
1428-
fn_ir.debug_print()
1429-
self._opset.add_function_def(fn_ir)
1430-
return fn_ir
1431-
raise ValueError(f"Unsupported top-level statement type {type(stmt)!r}.")
1432-
1433-
def translate_function_signature(self, fn: ast.FunctionDef) -> irbuilder.IRFunction:
1434-
"""Translate a (top-level) function signature."""
1435-
assert self._opset is not None
1436-
domain = self._opset.domain
1437-
self._current_fn = self.ir_builder.new_function(fn.name, domain, True)
1438-
return self._translate_function_signature_common(fn)
1401+
def _finalize(self) -> None:
    """Mark the conversion as done so that `convert` reuses its result."""
    self._finalized = True


def convert(self) -> ir.Function:
    """Convert the Python AST to an ONNX IR function.

    The first call runs liveness analysis on the root function definition
    and translates it. The converter is then marked as finalized, so later
    calls return the already-translated function instead of translating
    the same AST again.

    Returns:
        The IR function produced for the root function definition.
    """
    if self._finalized:
        return self._current_fn

    func_def = self._ast_root
    _analysis.do_liveness_analysis(func_def, self._message, self.meta)
    function_ir = self._translate_function_def(func_def)
    # TODO(justinchuby): Handle function registration to the opset
    # self._opset.add_function_def(fn_ir)

    # Without this the `_finalized` guard above could never become true.
    self._finalize()
    return function_ir
14391414

14401415

14411416
def _is_constant_expr(node: ast.AST) -> bool:
@@ -1561,3 +1536,21 @@ def _to_onnx_ref_attr(val: values.AttrRef, info: sourceinfo.SourceInfo | None) -
15611536
fail(info.msg(msg) if info else msg)
15621537
# TODO(justinchuby): What is the ref attr name?
15631538
return ir.RefAttr(attrname, val.value, attrtype)
1539+
1540+
1541+
def _find_onnx_opset(node: ast.AST, globals: dict[str, Any]) -> values.Opset | None:
    """Find the (first) ONNX opset used in the function, if any.

    Searches the tree rooted at ``node`` in pre-order (children visited in
    ``ast.iter_child_nodes`` order) for a call expression of the form
    ``op.OpName(...)`` where ``op`` is a name bound in ``globals`` to a
    `values.Opset` whose domain is the empty string.

    Args:
        node: Root of the AST to search.
        globals: Mapping of global names to their values. The parameter name
            shadows the builtin; it is kept so existing callers keep working.

    Returns:
        The first matching opset, or None when no such call exists.
    """
    # An explicit stack replaces recursion so that deeply nested expressions
    # cannot exhaust the interpreter's recursion limit.
    pending = [node]
    while pending:
        current = pending.pop()
        if isinstance(current, ast.Call) and isinstance(current.func, ast.Attribute):
            opset_expr = current.func.value
            if isinstance(opset_expr, ast.Name) and opset_expr.id in globals:
                opset = globals[opset_expr.id]
                if isinstance(opset, values.Opset) and opset.domain == "":
                    return opset
        # Push children reversed so they are popped in their original order,
        # which keeps the "first match" identical to a recursive descent.
        pending.extend(reversed(list(ast.iter_child_nodes(current))))
    return None
Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,26 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
3+
"""Analysis utilities for Python AST."""
34
from __future__ import annotations
45

56
import ast
6-
from typing import Any, Optional, Sequence, Set
7+
from typing import Any, Optional, Sequence, TYPE_CHECKING
8+
from collections import defaultdict
79

810
from onnxscript import sourceinfo
911
from onnxscript._internal import ast_utils
1012

13+
if TYPE_CHECKING:
14+
from onnxscript import _converter
15+
1116

1217
def _get_loop_var(for_stmt: ast.For, formatter: sourceinfo.Formatter) -> str:
    """Return the name of the variable a ``for`` statement iterates with.

    Raises:
        TypeError: If the loop target is anything other than a plain name.
            The message is produced by ``formatter`` for ``for_stmt``.
    """
    target = for_stmt.target
    if isinstance(target, ast.Name):
        return target.id
    message = formatter(for_stmt, "For loop target must be a single variable.")
    raise TypeError(message)
1621

1722

18-
def _used_vars(expr: Optional[ast.expr]) -> Set[str]:
23+
def _used_vars(expr: Optional[ast.expr]) -> set[str]:
1924
"""Return set of all variables used, including function names, in an expression."""
2025
if expr is None:
2126
return set()
@@ -35,7 +40,7 @@ def _used_vars(expr: Optional[ast.expr]) -> Set[str]:
3540
return result
3641

3742

38-
def _lhs_vars(lhs: ast.expr) -> Set[str]:
43+
def _lhs_vars(lhs: ast.expr) -> set[str]:
3944
"""Return set of assigned variables in the lhs of an assignment statement."""
4045

4146
def get_id(e):
@@ -49,12 +54,12 @@ def get_id(e):
4954

5055
def assigned_vars(
5156
stmt: ast.stmt | list[ast.stmt], formatter: sourceinfo.Formatter
52-
) -> Set[str]:
57+
) -> set[str]:
5358
"""Return the set of all variables that may be assigned to in an execution of input stmt
5459
or sequence of statements.
5560
"""
5661

57-
def assigned_in_block(block: Sequence[ast.stmt]) -> Set[str]:
62+
def assigned_in_block(block: Sequence[ast.stmt]) -> set[str]:
5863
result: set[Any] = set()
5964
for s in block:
6065
result = result | assigned_vars(s, formatter)
@@ -84,20 +89,26 @@ def assigned_in_block(block: Sequence[ast.stmt]) -> Set[str]:
8489
raise ValueError(error_message)
8590

8691

87-
def do_liveness_analysis(fun: ast.FunctionDef, formatter: sourceinfo.Formatter):
88-
"""Perform liveness analysis of the given function-ast. The results of the
89-
analysis are stored directly with each statement-ast `s` as attributes `s.live_in`
90-
and `s.live_out`.
92+
def do_liveness_analysis(
93+
fun: ast.FunctionDef,
94+
formatter: sourceinfo.Formatter,
95+
meta: defaultdict[ast.AST, _converter.ASTMeta],
96+
):
97+
"""Perform liveness analysis of the given function-ast.
98+
99+
The results of the analysis are stored in the `meta` dictionary, which maps
100+
each AST node to its metadata. The metadata includes the set of live variables
101+
at the entry and exit of each node.
91102
"""
92103

93-
def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]:
94-
stmt.live_out = live_out # type: ignore[attr-defined]
104+
def visit(stmt: ast.stmt, live_out: set[str]) -> set[str]:
105+
meta[stmt].live_out = live_out
95106
live = do_visit(stmt, live_out)
96-
stmt.live_in = live # type: ignore[attr-defined]
107+
meta[stmt].live_in = live
97108
return live
98109

99-
def do_visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]:
100-
def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]:
110+
def do_visit(stmt: ast.stmt, live_out: set[str]) -> set[str]:
111+
def visitBlock(block: Sequence[ast.stmt], live_out: set[str]) -> set[str]:
101112
for s in reversed(block):
102113
live_out = visit(s, live_out)
103114
return live_out
@@ -165,12 +176,12 @@ def exposed_uses(stmts: Sequence[ast.stmt], formatter: sourceinfo.Formatter):
165176
(in the first statement). Hence x is included in the exposed_uses.
166177
"""
167178

168-
def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]:
179+
def visitBlock(block: Sequence[ast.stmt], live_out: set[str]) -> set[str]:
169180
for stmt in reversed(block):
170181
live_out = visit(stmt, live_out)
171182
return live_out
172183

173-
def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]:
184+
def visit(stmt: ast.stmt, live_out: set[str]) -> set[str]:
174185
if isinstance(stmt, ast.Assign):
175186
return live_out.difference(_lhs_vars(stmt.targets[0])) | _used_vars(stmt.value)
176187
if isinstance(stmt, ast.AnnAssign):

0 commit comments

Comments
 (0)