Skip to content

Commit 1294bbe

Browse files
committed
use "explicit" if stack
1 parent e3cea52 commit 1294bbe

File tree

6 files changed

+2211
-665
lines changed

6 files changed

+2211
-665
lines changed

mlir_utils/ast/canonicalize.py

Lines changed: 63 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import ast
2+
import difflib
23
import enum
34
import inspect
45
import logging
@@ -9,39 +10,91 @@
910

1011
import libcst as cst
1112
from bytecode import ConcreteBytecode
13+
from libcst._position import CodeRange, CodePosition
1214
from libcst.matchers import MatcherDecoratableTransformer
15+
from libcst.metadata import (
16+
PositionProvider,
17+
ParentNodeProvider,
18+
QualifiedNameProvider,
19+
ExperimentalReentrantCodegenProvider,
20+
)
1321

1422
from mlir_utils.ast.util import get_module_cst, copy_func
1523

1624
logger = logging.getLogger(__name__)
1725

1826

1927
class Transformer(MatcherDecoratableTransformer):
20-
def __init__(self, context):
28+
METADATA_DEPENDENCIES = (
29+
PositionProvider,
30+
ParentNodeProvider,
31+
QualifiedNameProvider,
32+
ExperimentalReentrantCodegenProvider,
33+
)
34+
35+
def __init__(self, context, first_lineno):
2136
super().__init__()
2237
self.context = context
38+
self.first_lineno = first_lineno
39+
40+
def get_pos(self, node):
41+
pos = self.get_metadata(PositionProvider, node)
42+
return CodeRange(
43+
CodePosition(pos.start.line + self.first_lineno, pos.start.column),
44+
CodePosition(pos.end.line + self.first_lineno, pos.end.column),
45+
)
46+
47+
def get_parent(self, node):
48+
# NB: can only call this on "original nodes"
49+
return self.get_metadata(ParentNodeProvider, node)
50+
51+
def get_code_snippet(self, node):
52+
return self.get_metadata(
53+
ExperimentalReentrantCodegenProvider, node
54+
).get_original_statement_code()
2355

2456

2557
class StrictTransformer(Transformer):
2658
def visit_FunctionDef(self, node: cst.FunctionDef):
2759
return False
2860

2961

62+
def transform_func(f, *transformer_ctors):
63+
module_cst = get_module_cst(f)
64+
context = types.SimpleNamespace()
65+
for transformer_ctor in transformer_ctors:
66+
orig_code = module_cst.code
67+
wrapper = cst.MetadataWrapper(module_cst)
68+
func_node = wrapper.module.body[0]
69+
replace = transformer_ctor(
70+
context=context, first_lineno=f.__code__.co_firstlineno - 1
71+
)
72+
logger.debug("[transformer] %s", replace.__class__.__name__)
73+
with replace.resolve(wrapper):
74+
new_func = func_node._visit_and_replace_children(replace)
75+
module_cst = wrapper.module.deep_replace(func_node, new_func)
76+
new_code = module_cst.code
77+
78+
diff = list(
79+
difflib.unified_diff(
80+
orig_code.splitlines(), # to this
81+
new_code.splitlines(), # delta from this
82+
lineterm="",
83+
)
84+
)
85+
logger.debug("[transformed code diff]\n\n%s", "\n" + "\n".join(diff))
86+
logger.debug("[final transformed code]\n\n%s", module_cst.code)
87+
88+
return module_cst
89+
90+
3091
def transform_cst(
3192
f, transformers: list[type(Transformer) | type(StrictTransformer)] = None
3293
):
3394
if transformers is None:
3495
return f
3596

36-
module_cst = get_module_cst(f)
37-
context = types.SimpleNamespace()
38-
for transformer in transformers:
39-
func_node = module_cst.body[0]
40-
replace = transformer(context)
41-
new_func = func_node._visit_and_replace_children(replace)
42-
module_cst = module_cst.deep_replace(func_node, new_func)
43-
44-
logger.debug("[transformed code]\n\n%s", module_cst.code)
97+
module_cst = transform_func(f, *transformers)
4598

4699
code = "\n" * (f.__code__.co_firstlineno - 1) + module_cst.code
47100
module_code_o = compile(code, f.__code__.co_filename, "exec")

0 commit comments

Comments
 (0)