55import logging
66import types
77from abc import ABC
8+ from dis import findlinestarts
89from opcode import opmap
910from types import CodeType
1011
11- import libcst as cst
1212from bytecode import ConcreteBytecode
13- from libcst ._position import CodeRange , CodePosition
14- from libcst .matchers import MatcherDecoratableTransformer
15- from libcst .metadata import (
16- PositionProvider ,
17- ParentNodeProvider ,
18- QualifiedNameProvider ,
19- ExperimentalReentrantCodegenProvider ,
20- )
2113
2214from mlir_utils .ast .util import get_module_cst , copy_func
2315
2416logger = logging .getLogger (__name__ )
2517
2618
27- class Transformer (MatcherDecoratableTransformer ):
28- METADATA_DEPENDENCIES = (
29- PositionProvider ,
30- ParentNodeProvider ,
31- QualifiedNameProvider ,
32- ExperimentalReentrantCodegenProvider ,
33- )
34-
19+ class Transformer (ast .NodeTransformer ):
3520 def __init__ (self , context , first_lineno ):
3621 super ().__init__ ()
3722 self .context = context
3823 self .first_lineno = first_lineno
3924
40- def get_pos (self , node ):
41- pos = self .get_metadata (PositionProvider , node )
42- return CodeRange (
43- CodePosition (pos .start .line + self .first_lineno , pos .start .column ),
44- CodePosition (pos .end .line + self .first_lineno , pos .end .column ),
45- )
46-
47- def get_parent (self , node ):
48- # NB: can only call this on "original nodes"
49- return self .get_metadata (ParentNodeProvider , node )
50-
51- def get_code_snippet (self , node ):
52- return self .get_metadata (
53- ExperimentalReentrantCodegenProvider , node
54- ).get_original_statement_code ()
55-
5625
5726class StrictTransformer (Transformer ):
58- def visit_FunctionDef (self , node : cst .FunctionDef ):
59- return False
27+ def visit_FunctionDef (self , node : ast .FunctionDef ):
28+ return node
6029
6130
62- def transform_func (f , * transformer_ctors ):
63- module_cst = get_module_cst (f )
31+ def transform_func (f , * transformer_ctors : type ( Transformer ) ):
32+ module = get_module_cst (f )
6433 context = types .SimpleNamespace ()
6534 for transformer_ctor in transformer_ctors :
66- orig_code = module_cst .code
67- wrapper = cst .MetadataWrapper (module_cst )
68- func_node = wrapper .module .body [0 ]
35+ orig_code = ast .unparse (module )
36+ func_node = module .body [0 ]
6937 replace = transformer_ctor (
7038 context = context , first_lineno = f .__code__ .co_firstlineno - 1
7139 )
7240 logger .debug ("[transformer] %s" , replace .__class__ .__name__ )
73- with replace .resolve (wrapper ):
74- new_func = func_node ._visit_and_replace_children (replace )
75- module_cst = wrapper .module .deep_replace (func_node , new_func )
76- new_code = module_cst .code
41+ func_node = replace .generic_visit (func_node )
42+ new_code = ast .unparse (func_node )
7743
7844 diff = list (
7945 difflib .unified_diff (
@@ -82,28 +48,35 @@ def transform_func(f, *transformer_ctors):
8248 lineterm = "" ,
8349 )
8450 )
85- logger .debug ("[transformed code diff]\n \n %s" , "\n " + "\n " .join (diff ))
86- logger .debug ("[final transformed code]\n \n %s" , module_cst .code )
51+ logger .debug ("[transformed code diff]\n %s" , "\n " + "\n " .join (diff ))
52+ logger .debug ("[transformed code]\n %s" , new_code )
53+ module .body [0 ] = func_node
54+
55+ logger .debug ("[final transformed code]\n \n %s" , new_code )
8756
88- return module_cst
57+ return module
8958
9059
91- def transform_cst (
60+ def transform_ast (
9261 f , transformers : list [type (Transformer ) | type (StrictTransformer )] = None
9362):
9463 if transformers is None :
9564 return f
9665
97- module_cst = transform_func (f , * transformers )
98-
99- code = " \n " * ( f .__code__ .co_firstlineno - 1 ) + module_cst . code
100- module_code_o = compile (code , f .__code__ .co_filename , "exec" )
66+ module = transform_func (f , * transformers )
67+ module = ast . fix_missing_locations ( module )
68+ module = ast . increment_lineno ( module , f .__code__ .co_firstlineno - 1 )
69+ module_code_o = compile (module , f .__code__ .co_filename , "exec" )
10170 new_f_code_o = next (
10271 c
10372 for c in module_code_o .co_consts
10473 if type (c ) is CodeType and c .co_name == f .__name__
10574 )
106-
75+ n_lines = len (inspect .getsource (f ).splitlines ())
76+ line_starts = list (findlinestarts (new_f_code_o ))
77+ assert (
78+ line_starts [- 1 ][1 ] - line_starts [0 ][1 ] == n_lines - 1
79+ ), f"something went wrong with the line numbers for the rewritten/canonicalized function"
10780 return copy_func (f , new_f_code_o )
10881
10982
@@ -156,7 +129,7 @@ def bytecode_patchers(self) -> list[BytecodePatcher]:
156129
157130def canonicalize (* , using : Canonicalizer ):
158131 def wrapper (f ):
159- f = transform_cst (f , using .cst_transformers )
132+ f = transform_ast (f , using .cst_transformers )
160133 f = patch_bytecode (f , using .bytecode_patchers )
161134 return f
162135
0 commit comments