Skip to content

Commit e184645

Browse files
committed
rewrite ifs to be withs instead ftw
1 parent a1f449d commit e184645

File tree

7 files changed

+4327
-1728
lines changed

7 files changed

+4327
-1728
lines changed

mlir_utils/ast/canonicalize.py

Lines changed: 27 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -5,75 +5,41 @@
55
import logging
66
import types
77
from abc import ABC
8+
from dis import findlinestarts
89
from opcode import opmap
910
from types import CodeType
1011

11-
import libcst as cst
1212
from bytecode import ConcreteBytecode
13-
from libcst._position import CodeRange, CodePosition
14-
from libcst.matchers import MatcherDecoratableTransformer
15-
from libcst.metadata import (
16-
PositionProvider,
17-
ParentNodeProvider,
18-
QualifiedNameProvider,
19-
ExperimentalReentrantCodegenProvider,
20-
)
2113

2214
from mlir_utils.ast.util import get_module_cst, copy_func
2315

2416
logger = logging.getLogger(__name__)
2517

2618

27-
class Transformer(MatcherDecoratableTransformer):
28-
METADATA_DEPENDENCIES = (
29-
PositionProvider,
30-
ParentNodeProvider,
31-
QualifiedNameProvider,
32-
ExperimentalReentrantCodegenProvider,
33-
)
34-
19+
class Transformer(ast.NodeTransformer):
3520
def __init__(self, context, first_lineno):
3621
super().__init__()
3722
self.context = context
3823
self.first_lineno = first_lineno
3924

40-
def get_pos(self, node):
41-
pos = self.get_metadata(PositionProvider, node)
42-
return CodeRange(
43-
CodePosition(pos.start.line + self.first_lineno, pos.start.column),
44-
CodePosition(pos.end.line + self.first_lineno, pos.end.column),
45-
)
46-
47-
def get_parent(self, node):
48-
# NB: can only call this on "original nodes"
49-
return self.get_metadata(ParentNodeProvider, node)
50-
51-
def get_code_snippet(self, node):
52-
return self.get_metadata(
53-
ExperimentalReentrantCodegenProvider, node
54-
).get_original_statement_code()
55-
5625

5726
class StrictTransformer(Transformer):
58-
def visit_FunctionDef(self, node: cst.FunctionDef):
59-
return False
27+
def visit_FunctionDef(self, node: ast.FunctionDef):
28+
return node
6029

6130

62-
def transform_func(f, *transformer_ctors):
63-
module_cst = get_module_cst(f)
31+
def transform_func(f, *transformer_ctors: type(Transformer)):
32+
module = get_module_cst(f)
6433
context = types.SimpleNamespace()
6534
for transformer_ctor in transformer_ctors:
66-
orig_code = module_cst.code
67-
wrapper = cst.MetadataWrapper(module_cst)
68-
func_node = wrapper.module.body[0]
35+
orig_code = ast.unparse(module)
36+
func_node = module.body[0]
6937
replace = transformer_ctor(
7038
context=context, first_lineno=f.__code__.co_firstlineno - 1
7139
)
7240
logger.debug("[transformer] %s", replace.__class__.__name__)
73-
with replace.resolve(wrapper):
74-
new_func = func_node._visit_and_replace_children(replace)
75-
module_cst = wrapper.module.deep_replace(func_node, new_func)
76-
new_code = module_cst.code
41+
func_node = replace.generic_visit(func_node)
42+
new_code = ast.unparse(func_node)
7743

7844
diff = list(
7945
difflib.unified_diff(
@@ -82,28 +48,35 @@ def transform_func(f, *transformer_ctors):
8248
lineterm="",
8349
)
8450
)
85-
logger.debug("[transformed code diff]\n\n%s", "\n" + "\n".join(diff))
86-
logger.debug("[final transformed code]\n\n%s", module_cst.code)
51+
logger.debug("[transformed code diff]\n%s", "\n" + "\n".join(diff))
52+
logger.debug("[transformed code]\n%s", new_code)
53+
module.body[0] = func_node
54+
55+
logger.debug("[final transformed code]\n\n%s", new_code)
8756

88-
return module_cst
57+
return module
8958

9059

91-
def transform_cst(
60+
def transform_ast(
9261
f, transformers: list[type(Transformer) | type(StrictTransformer)] = None
9362
):
9463
if transformers is None:
9564
return f
9665

97-
module_cst = transform_func(f, *transformers)
98-
99-
code = "\n" * (f.__code__.co_firstlineno - 1) + module_cst.code
100-
module_code_o = compile(code, f.__code__.co_filename, "exec")
66+
module = transform_func(f, *transformers)
67+
module = ast.fix_missing_locations(module)
68+
module = ast.increment_lineno(module, f.__code__.co_firstlineno - 1)
69+
module_code_o = compile(module, f.__code__.co_filename, "exec")
10170
new_f_code_o = next(
10271
c
10372
for c in module_code_o.co_consts
10473
if type(c) is CodeType and c.co_name == f.__name__
10574
)
106-
75+
n_lines = len(inspect.getsource(f).splitlines())
76+
line_starts = list(findlinestarts(new_f_code_o))
77+
assert (
78+
line_starts[-1][1] - line_starts[0][1] == n_lines - 1
79+
), f"something went wrong with the line numbers for the rewritten/canonicalized function"
10780
return copy_func(f, new_f_code_o)
10881

10982

@@ -156,7 +129,7 @@ def bytecode_patchers(self) -> list[BytecodePatcher]:
156129

157130
def canonicalize(*, using: Canonicalizer):
158131
def wrapper(f):
159-
f = transform_cst(f, using.cst_transformers)
132+
f = transform_ast(f, using.cst_transformers)
160133
f = patch_bytecode(f, using.bytecode_patchers)
161134
return f
162135

mlir_utils/ast/util.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,36 @@
1+
import ast
12
import functools
23
import inspect
34
import types
45
from textwrap import dedent
56

6-
import libcst as cst
7+
8+
def set_lineno(node, n=1):
9+
for child in ast.walk(node):
10+
child.lineno = n
11+
child.end_lineno = n
12+
return node
713

814

915
def ast_call(name, args=None, keywords=None):
1016
if keywords is None:
1117
keywords = []
1218
if args is None:
1319
args = []
14-
call = cst.Call(
15-
func=cst.Name(value=name),
16-
args=args + keywords,
20+
call = ast.Call(
21+
func=ast.Name(name, ctx=ast.Load()),
22+
args=args,
23+
keywords=keywords,
1724
)
1825
return call
1926

2027

2128
def get_module_cst(f):
2229
f_src = dedent(inspect.getsource(f))
23-
tree = cst.parse_module(f_src)
30+
# tree = cst.parse_module(f_src)
31+
tree = ast.parse(f_src)
2432
assert isinstance(
25-
tree.body[0], cst.FunctionDef
33+
tree.body[0], ast.FunctionDef
2634
), f"unexpected ast node {tree.body[0]}"
2735
return tree
2836

0 commit comments

Comments
 (0)