Skip to content

Commit 8b5fcff

Browse files
committed
Remove formatter argument
Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent e00358f commit 8b5fcff

File tree

3 files changed

+26
-26
lines changed

3 files changed

+26
-26
lines changed

onnxscript/_internal/analysis.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def __init__(
5858
self._constant_if_condition: dict[ast.If, bool] = {}
5959
if globals:
6060
self._compute_constant_if_conditions(fun, globals)
61-
self.do_liveness_analysis(fun, formatter)
61+
self.do_liveness_analysis(fun)
6262

6363
def _compute_constant_if_conditions(
6464
self, fun: ast.FunctionDef, globals: dict[str, Any]
@@ -70,7 +70,7 @@ def _compute_constant_if_conditions(
7070
conditions. The value of such conditions is determined from the outer-scope.
7171
"""
7272

73-
assigned_vars = self.assigned_vars(fun.body, self._formatter)
73+
assigned_vars = self.assigned_vars(fun.body)
7474
for node in ast.walk(fun):
7575
if isinstance(node, ast.If):
7676
if isinstance(node.test, ast.Name):
@@ -90,17 +90,15 @@ def constant_if_condition(self, if_stmt: ast.If) -> Optional[bool]:
9090
"""
9191
return self._constant_if_condition.get(if_stmt, None) # type: ignore[return-value]
9292

93-
def assigned_vars(
94-
self, stmt: ast.stmt | list[ast.stmt], formatter: sourceinfo.Formatter
95-
) -> Set[str]:
93+
def assigned_vars(self, stmt: ast.stmt | list[ast.stmt]) -> Set[str]:
9694
"""Return the set of all variables that may be assigned to in an execution of input stmt
9795
or sequence of statements.
9896
"""
9997

10098
def assigned_in_block(block: Sequence[ast.stmt]) -> Set[str]:
10199
result: set[Any] = set()
102100
for s in block:
103-
result = result | self.assigned_vars(s, formatter)
101+
result = result | self.assigned_vars(s)
104102
return result
105103

106104
if isinstance(stmt, ast.Assign):
@@ -118,7 +116,7 @@ def assigned_in_block(block: Sequence[ast.stmt]) -> Set[str]:
118116
else:
119117
return assigned_in_block(stmt.orelse)
120118
if isinstance(stmt, ast.For):
121-
return assigned_in_block(stmt.body) | {_get_loop_var(stmt, formatter)}
119+
return assigned_in_block(stmt.body) | {_get_loop_var(stmt, self._formatter)}
122120
if isinstance(stmt, ast.While):
123121
return assigned_in_block(stmt.body)
124122
if isinstance(stmt, list):
@@ -133,10 +131,10 @@ def assigned_in_block(block: Sequence[ast.stmt]) -> Set[str]:
133131
return set()
134132
if ast_utils.is_doc_string(stmt):
135133
return set()
136-
error_message = formatter(stmt, f"Unsupported statement type {type(stmt)!r}.")
134+
error_message = self._formatter(stmt, f"Unsupported statement type {type(stmt)!r}.")
137135
raise ValueError(error_message)
138136

139-
def do_liveness_analysis(self, fun: ast.FunctionDef, formatter: sourceinfo.Formatter):
137+
def do_liveness_analysis(self, fun: ast.FunctionDef):
140138
"""Perform liveness analysis of the given function-ast. The results of the
141139
analysis are stored directly with each statement-ast `s` as attributes `s.live_in`
142140
and `s.live_out`.
@@ -171,7 +169,7 @@ def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]:
171169
else:
172170
return visitBlock(stmt.orelse, live_out)
173171
if isinstance(stmt, ast.For):
174-
p_loop_var = _get_loop_var(stmt, formatter)
172+
p_loop_var = _get_loop_var(stmt, self._formatter)
175173
prev = None
176174
curr = live_out
177175
while curr != prev:
@@ -198,14 +196,16 @@ def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]:
198196
return live_out
199197
if ast_utils.is_print_call(stmt):
200198
return live_out
201-
raise ValueError(formatter(stmt, f"Unsupported statement type {type(stmt)!r}."))
199+
raise ValueError(
200+
self._formatter(stmt, f"Unsupported statement type {type(stmt)!r}.")
201+
)
202202

203203
assert isinstance(fun, ast.FunctionDef)
204204
live: set[Any] = set()
205205
for s in reversed(fun.body):
206206
live = visit(s, live)
207207

208-
def exposed_uses(self, stmts: Sequence[ast.stmt], formatter: sourceinfo.Formatter):
208+
def exposed_uses(self, stmts: Sequence[ast.stmt]):
209209
"""Return the set of variables that are used before being defined by given block.
210210
In essence, this identifies the "inputs" to a given code-block.
211211
For example, consider the following code-block:
@@ -251,7 +251,7 @@ def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]:
251251
if isinstance(stmt, ast.For):
252252
# Analysis assumes loop may execute zero times. Results can be improved
253253
# for loops that execute at least once.
254-
loop_var_set = {_get_loop_var(stmt, formatter)}
254+
loop_var_set = {_get_loop_var(stmt, self._formatter)}
255255
used_after_loop = live_out.difference(loop_var_set)
256256
used_inside_loop = visitBlock(stmt.body, set()).difference(loop_var_set)
257257
used_in_loop_header = _used_vars(stmt.iter)
@@ -269,13 +269,15 @@ def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]:
269269
if isinstance(stmt, ast.FunctionDef):
270270
if stmt.name in live_out:
271271
live_out.remove(stmt.name)
272-
live_out = live_out | self.outer_scope_variables(stmt, formatter)
272+
live_out = live_out | self.outer_scope_variables(stmt)
273273
return live_out
274-
raise ValueError(formatter(stmt, f"Unsupported statement type {type(stmt)!r}."))
274+
raise ValueError(
275+
self._formatter(stmt, f"Unsupported statement type {type(stmt)!r}.")
276+
)
275277

276278
return visitBlock(stmts, set())
277279

278-
def outer_scope_variables(self, fun: ast.FunctionDef, formatter: sourceinfo.Formatter):
280+
def outer_scope_variables(self, fun: ast.FunctionDef):
279281
"""Return the set of outer-scope variables used in a nested function.
280282
281283
Args:
@@ -286,6 +288,6 @@ def outer_scope_variables(self, fun: ast.FunctionDef, formatter: sourceinfo.Form
286288
A set of variable names (strings).
287289
"""
288290
assert isinstance(fun, ast.FunctionDef)
289-
used_vars_ = self.exposed_uses(fun.body, formatter)
291+
used_vars_ = self.exposed_uses(fun.body)
290292
inputs = [x.arg for x in fun.args.args]
291293
return used_vars_.difference(inputs)

onnxscript/_internal/analysis_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ class TestExposedUses(unittest.TestCase):
114114
def assertUses(self, f, expected):
115115
source, parse_tree = ast_utils.get_src_and_ast(f)
116116
analyzer = analysis.AstAnalyzer(parse_tree, formatter(source))
117-
result = analyzer.exposed_uses(parse_tree.body, formatter(source))
117+
result = analyzer.exposed_uses(parse_tree.body)
118118
self.assertEqual(result, set(expected))
119119

120120
def test_basic(self):
@@ -192,7 +192,7 @@ class TestAssignedVarAnalysis(unittest.TestCase):
192192
def assert_assigned_vars(self, f, expected: set[str]):
193193
source, parse_tree = ast_utils.get_src_and_ast(f)
194194
analyzer = analysis.AstAnalyzer(parse_tree, formatter(source))
195-
result = analyzer.assigned_vars(parse_tree.body, formatter(source))
195+
result = analyzer.assigned_vars(parse_tree.body)
196196
self.assertEqual(result, expected)
197197

198198
def test_basic_defs(self):

onnxscript/converter.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1108,11 +1108,9 @@ def _translate_if_stmt(self, stmt: ast.If) -> None:
11081108
self._translate_stmt(s)
11091109
return
11101110
if hasattr(stmt, "live_out"):
1111-
live_defs = list(
1112-
stmt.live_out.intersection(self.analyzer.assigned_vars(stmt, self._message))
1113-
)
1111+
live_defs = list(stmt.live_out.intersection(self.analyzer.assigned_vars(stmt)))
11141112
else:
1115-
live_defs = list(self.analyzer.assigned_vars(stmt, self._message))
1113+
live_defs = list(self.analyzer.assigned_vars(stmt))
11161114
test = self._translate_expr(stmt.test, "cond").name
11171115
lineno = self._source_of(stmt).lineno
11181116
thenGraph, sub_fct_then = self._translate_block(
@@ -1192,8 +1190,8 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
11921190
else:
11931191
self.fail(loop_stmt, f"Unexpected loop type {type(loop_stmt)!r}.")
11941192
# analyze loop body
1195-
exposed_uses = self.analyzer.exposed_uses(loop_stmt.body, self._message)
1196-
vars_def_in_loop = self.analyzer.assigned_vars(loop_stmt.body, self._message)
1193+
exposed_uses = self.analyzer.exposed_uses(loop_stmt.body)
1194+
vars_def_in_loop = self.analyzer.assigned_vars(loop_stmt.body)
11971195
loop_state_vars = vars_def_in_loop.intersection(exposed_uses | loop_stmt.live_out)
11981196
scan_outputs = set() # TODO
11991197
outputs = list(loop_state_vars | scan_outputs)
@@ -1380,7 +1378,7 @@ def _translate_nested_function_def(self, fn: ast.FunctionDef) -> None:
13801378
self._enter_scope(fn.name, fn)
13811379
self._translate_function_def_common(fn)
13821380
function_ir = self._exit_scope()
1383-
outer_scope_vars = self.analyzer.outer_scope_variables(fn, self._message)
1381+
outer_scope_vars = self.analyzer.outer_scope_variables(fn)
13841382
function_ir.outer_scope_variables = [
13851383
(var, self._lookup(var, self._source_of(fn))) for var in outer_scope_vars
13861384
]

0 commit comments

Comments
 (0)