|
1 | 1 | import ast |
| 2 | +import difflib |
2 | 3 | import enum |
3 | 4 | import inspect |
4 | 5 | import logging |
|
9 | 10 |
|
10 | 11 | import libcst as cst |
11 | 12 | from bytecode import ConcreteBytecode |
| 13 | +from libcst._position import CodeRange, CodePosition |
12 | 14 | from libcst.matchers import MatcherDecoratableTransformer |
| 15 | +from libcst.metadata import ( |
| 16 | + PositionProvider, |
| 17 | + ParentNodeProvider, |
| 18 | + QualifiedNameProvider, |
| 19 | + ExperimentalReentrantCodegenProvider, |
| 20 | +) |
13 | 21 |
|
14 | 22 | from mlir_utils.ast.util import get_module_cst, copy_func |
15 | 23 |
|
16 | 24 | logger = logging.getLogger(__name__) |
17 | 25 |
|
18 | 26 |
|
19 | 27 | class Transformer(MatcherDecoratableTransformer): |
20 | | - def __init__(self, context): |
| 28 | + METADATA_DEPENDENCIES = ( |
| 29 | + PositionProvider, |
| 30 | + ParentNodeProvider, |
| 31 | + QualifiedNameProvider, |
| 32 | + ExperimentalReentrantCodegenProvider, |
| 33 | + ) |
| 34 | + |
| 35 | + def __init__(self, context, first_lineno): |
21 | 36 | super().__init__() |
22 | 37 | self.context = context |
| 38 | + self.first_lineno = first_lineno |
| 39 | + |
| 40 | + def get_pos(self, node): |
| 41 | + pos = self.get_metadata(PositionProvider, node) |
| 42 | + return CodeRange( |
| 43 | + CodePosition(pos.start.line + self.first_lineno, pos.start.column), |
| 44 | + CodePosition(pos.end.line + self.first_lineno, pos.end.column), |
| 45 | + ) |
| 46 | + |
| 47 | + def get_parent(self, node): |
| 48 | + # NB: can only call this on "original nodes" |
| 49 | + return self.get_metadata(ParentNodeProvider, node) |
| 50 | + |
| 51 | + def get_code_snippet(self, node): |
| 52 | + return self.get_metadata( |
| 53 | + ExperimentalReentrantCodegenProvider, node |
| 54 | + ).get_original_statement_code() |
23 | 55 |
|
24 | 56 |
|
25 | 57 | class StrictTransformer(Transformer): |
26 | 58 | def visit_FunctionDef(self, node: cst.FunctionDef): |
27 | 59 | return False |
28 | 60 |
|
29 | 61 |
|
| 62 | +def transform_func(f, *transformer_ctors): |
| 63 | + module_cst = get_module_cst(f) |
| 64 | + context = types.SimpleNamespace() |
| 65 | + for transformer_ctor in transformer_ctors: |
| 66 | + orig_code = module_cst.code |
| 67 | + wrapper = cst.MetadataWrapper(module_cst) |
| 68 | + func_node = wrapper.module.body[0] |
| 69 | + replace = transformer_ctor( |
| 70 | + context=context, first_lineno=f.__code__.co_firstlineno - 1 |
| 71 | + ) |
| 72 | + logger.debug("[transformer] %s", replace.__class__.__name__) |
| 73 | + with replace.resolve(wrapper): |
| 74 | + new_func = func_node._visit_and_replace_children(replace) |
| 75 | + module_cst = wrapper.module.deep_replace(func_node, new_func) |
| 76 | + new_code = module_cst.code |
| 77 | + |
| 78 | + diff = list( |
| 79 | + difflib.unified_diff( |
| 80 | + orig_code.splitlines(), # to this |
| 81 | + new_code.splitlines(), # delta from this |
| 82 | + lineterm="", |
| 83 | + ) |
| 84 | + ) |
| 85 | + logger.debug("[transformed code diff]\n\n%s", "\n" + "\n".join(diff)) |
| 86 | + logger.debug("[final transformed code]\n\n%s", module_cst.code) |
| 87 | + |
| 88 | + return module_cst |
| 89 | + |
| 90 | + |
30 | 91 | def transform_cst( |
31 | 92 | f, transformers: list[type(Transformer) | type(StrictTransformer)] = None |
32 | 93 | ): |
33 | 94 | if transformers is None: |
34 | 95 | return f |
35 | 96 |
|
36 | | - module_cst = get_module_cst(f) |
37 | | - context = types.SimpleNamespace() |
38 | | - for transformer in transformers: |
39 | | - func_node = module_cst.body[0] |
40 | | - replace = transformer(context) |
41 | | - new_func = func_node._visit_and_replace_children(replace) |
42 | | - module_cst = module_cst.deep_replace(func_node, new_func) |
43 | | - |
44 | | - logger.debug("[transformed code]\n\n%s", module_cst.code) |
| 97 | + module_cst = transform_func(f, *transformers) |
45 | 98 |
|
46 | 99 | code = "\n" * (f.__code__.co_firstlineno - 1) + module_cst.code |
47 | 100 | module_code_o = compile(code, f.__code__.co_filename, "exec") |
|
0 commit comments