|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +""" |
| 3 | + ast |
| 4 | + ~~~ |
| 5 | +
|
| 6 | + The `ast` module helps Python applications to process trees of the Python |
| 7 | + abstract syntax grammar. The abstract syntax itself might change with |
| 8 | + each Python release; this module helps to find out programmatically what |
| 9 | + the current grammar looks like and allows modifications of it. |
| 10 | +
|
| 11 | + An abstract syntax tree can be generated by passing `ast.PyCF_ONLY_AST` as |
| 12 | + a flag to the `compile()` builtin function or by using the `parse()` |
| 13 | + function from this module. The result will be a tree of objects whose |
| 14 | + classes all inherit from `ast.AST`. |
| 15 | +
|
| 16 | + A modified abstract syntax tree can be compiled into a Python code object |
| 17 | + using the built-in `compile()` function. |
| 18 | +
|
| 19 | + Additionally various helper functions are provided that make working with |
| 20 | + the trees simpler. The main intention of the helper functions and this |
| 21 | + module in general is to provide an easy to use interface for libraries |
| 22 | + that work tightly with the python syntax (template engines for example). |
| 23 | +
|
| 24 | +
|
| 25 | + :copyright: Copyright 2008 by Armin Ronacher. |
| 26 | + :license: Python License. |
| 27 | +
|
| 28 | + From: http://dev.pocoo.org/hg/sandbox |
| 29 | +""" |
| 30 | +from _ast import * |
| 31 | + |
| 32 | + |
| 33 | +BOOLOP_SYMBOLS = { |
| 34 | + And: 'and', |
| 35 | + Or: 'or' |
| 36 | +} |
| 37 | + |
| 38 | +BINOP_SYMBOLS = { |
| 39 | + Add: '+', |
| 40 | + Sub: '-', |
| 41 | + Mult: '*', |
| 42 | + Div: '/', |
| 43 | + FloorDiv: '//', |
| 44 | + Mod: '%', |
| 45 | + LShift: '<<', |
| 46 | + RShift: '>>', |
| 47 | + BitOr: '|', |
| 48 | + BitAnd: '&', |
| 49 | + BitXor: '^' |
| 50 | +} |
| 51 | + |
| 52 | +CMPOP_SYMBOLS = { |
| 53 | + Eq: '==', |
| 54 | + Gt: '>', |
| 55 | + GtE: '>=', |
| 56 | + In: 'in', |
| 57 | + Is: 'is', |
| 58 | + IsNot: 'is not', |
| 59 | + Lt: '<', |
| 60 | + LtE: '<=', |
| 61 | + NotEq: '!=', |
| 62 | + NotIn: 'not in' |
| 63 | +} |
| 64 | + |
| 65 | +UNARYOP_SYMBOLS = { |
| 66 | + Invert: '~', |
| 67 | + Not: 'not', |
| 68 | + UAdd: '+', |
| 69 | + USub: '-' |
| 70 | +} |
| 71 | + |
| 72 | +ALL_SYMBOLS = {} |
| 73 | +ALL_SYMBOLS.update(BOOLOP_SYMBOLS) |
| 74 | +ALL_SYMBOLS.update(BINOP_SYMBOLS) |
| 75 | +ALL_SYMBOLS.update(CMPOP_SYMBOLS) |
| 76 | +ALL_SYMBOLS.update(UNARYOP_SYMBOLS) |
| 77 | + |
| 78 | + |
| 79 | +def parse(expr, filename='<unknown>', mode='exec'): |
| 80 | + """Parse an expression into an AST node.""" |
| 81 | + return compile(expr, filename, mode, PyCF_ONLY_AST) |
| 82 | + |
| 83 | + |
| 84 | +def literal_eval(node_or_string): |
| 85 | + """Safe evaluate a literal. The string or node provided may include any |
| 86 | + of the following python structures: strings, numbers, tuples, lists, |
| 87 | + dicts, booleans or None. |
| 88 | + """ |
| 89 | + _safe_names = {'None': None, 'True': True, 'False': False} |
| 90 | + if isinstance(node_or_string, basestring): |
| 91 | + node_or_string = parse(node_or_string, mode='eval') |
| 92 | + if isinstance(node_or_string, Expression): |
| 93 | + node_or_string = node_or_string.body |
| 94 | + def _convert(node): |
| 95 | + if isinstance(node, Str): |
| 96 | + return node.s |
| 97 | + elif isinstance(node, Num): |
| 98 | + return node.n |
| 99 | + elif isinstance(node, Tuple): |
| 100 | + return tuple(map(_convert, node.elts)) |
| 101 | + elif isinstance(node, List): |
| 102 | + return list(map(_convert, node.elts)) |
| 103 | + elif isinstance(node, Dict): |
| 104 | + return dict((_convert(k), _convert(v)) for k, v |
| 105 | + in zip(node.keys, node.values)) |
| 106 | + elif isinstance(node, Name): |
| 107 | + if node.id in _safe_names: |
| 108 | + return _safe_names[node.id] |
| 109 | + raise ValueError('malformed string') |
| 110 | + return _convert(node_or_string) |
| 111 | + |
| 112 | + |
| 113 | +def dump(node, annotate_fields=True, include_attributes=False): |
| 114 | + """A very verbose representation of the node passed. This is useful for |
| 115 | + debugging purposes. Per default the returned string will show the names |
| 116 | + and the values for fields. This makes the code impossible to evaluate, |
| 117 | + if evaluation is wanted `annotate_fields` must be set to False. |
| 118 | + Attributes such as line numbers and column offsets are dumped by default. |
| 119 | + If this is wanted, `include_attributes` can be set to `True`. |
| 120 | + """ |
| 121 | + def _format(node): |
| 122 | + if isinstance(node, AST): |
| 123 | + fields = [(a, _format(b)) for a, b in iter_fields(node)] |
| 124 | + rv = '%s(%s' % (node.__class__.__name__, ', '.join( |
| 125 | + ('%s=%s' % field for field in fields) |
| 126 | + if annotate_fields else |
| 127 | + (b for a, b in fields) |
| 128 | + )) |
| 129 | + if include_attributes and node._attributes: |
| 130 | + rv += fields and ', ' or ' ' |
| 131 | + rv += ', '.join('%s=%s' % (a, _format(getattr(node, a))) |
| 132 | + for a in node._attributes) |
| 133 | + return rv + ')' |
| 134 | + elif isinstance(node, list): |
| 135 | + return '[%s]' % ', '.join(_format(x) for x in node) |
| 136 | + return repr(node) |
| 137 | + if not isinstance(node, AST): |
| 138 | + raise TypeError('expected AST, got %r' % node.__class__.__name__) |
| 139 | + return _format(node) |
| 140 | + |
| 141 | + |
| 142 | +def copy_location(new_node, old_node): |
| 143 | + """Copy the source location hint (`lineno` and `col_offset`) from the |
| 144 | + old to the new node if possible and return the new one. |
| 145 | + """ |
| 146 | + for attr in 'lineno', 'col_offset': |
| 147 | + if attr in old_node._attributes and attr in new_node._attributes \ |
| 148 | + and hasattr(old_node, attr): |
| 149 | + setattr(new_node, attr, getattr(old_node, attr)) |
| 150 | + return new_node |
| 151 | + |
| 152 | + |
| 153 | +def fix_missing_locations(node): |
| 154 | + """Some nodes require a line number and the column offset. Without that |
| 155 | + information the compiler will abort the compilation. Because it can be |
| 156 | + a dull task to add appropriate line numbers and column offsets when |
| 157 | + adding new nodes this function can help. It copies the line number and |
| 158 | + column offset of the parent node to the child nodes without this |
| 159 | + information. |
| 160 | +
|
| 161 | + Unlike `copy_location` this works recursive and won't touch nodes that |
| 162 | + already have a location information. |
| 163 | + """ |
| 164 | + def _fix(node, lineno, col_offset): |
| 165 | + if 'lineno' in node._attributes: |
| 166 | + if not hasattr(node, 'lineno'): |
| 167 | + node.lineno = lineno |
| 168 | + else: |
| 169 | + lineno = node.lineno |
| 170 | + if 'col_offset' in node._attributes: |
| 171 | + if not hasattr(node, 'col_offset'): |
| 172 | + node.col_offset = col_offset |
| 173 | + else: |
| 174 | + col_offset = node.col_offset |
| 175 | + for child in iter_child_nodes(node): |
| 176 | + _fix(child, lineno, col_offset) |
| 177 | + _fix(node, 1, 0) |
| 178 | + return node |
| 179 | + |
| 180 | + |
| 181 | +def increment_lineno(node, n=1): |
| 182 | + """Increment the line numbers of all nodes by `n` if they have line number |
| 183 | + attributes. This is useful to "move code" to a different location in a |
| 184 | + file. |
| 185 | + """ |
| 186 | + if 'lineno' in node._attributes: |
| 187 | + node.lineno = getattr(node, 'lineno', 0) + n |
| 188 | + for child in walk(node): |
| 189 | + if 'lineno' in child._attributes: |
| 190 | + child.lineno = getattr(child, 'lineno', 0) + n |
| 191 | + return node |
| 192 | + |
| 193 | + |
| 194 | +def iter_fields(node): |
| 195 | + """Iterate over all fields of a node, only yielding existing fields.""" |
| 196 | + for field in node._fields: |
| 197 | + try: |
| 198 | + yield field, getattr(node, field) |
| 199 | + except AttributeError: |
| 200 | + pass |
| 201 | + |
| 202 | + |
| 203 | +def get_fields(node): |
| 204 | + """Like `iter_fiels` but returns a dict.""" |
| 205 | + return dict(iter_fields(node)) |
| 206 | + |
| 207 | + |
| 208 | +def iter_child_nodes(node): |
| 209 | + """Iterate over all child nodes or a node.""" |
| 210 | + for name, field in iter_fields(node): |
| 211 | + if isinstance(field, AST): |
| 212 | + yield field |
| 213 | + elif isinstance(field, list): |
| 214 | + for item in field: |
| 215 | + if isinstance(item, AST): |
| 216 | + yield item |
| 217 | + |
| 218 | + |
| 219 | +def get_child_nodes(node): |
| 220 | + """Like `iter_child_nodes` but returns a list.""" |
| 221 | + return list(iter_child_nodes(node)) |
| 222 | + |
| 223 | + |
| 224 | +def get_docstring(node, trim=True): |
| 225 | + """Return the docstring for the given node or `None` if no docstring can |
| 226 | + be found. If the node provided does not accept docstrings a `TypeError` |
| 227 | + will be raised. |
| 228 | + """ |
| 229 | + if not isinstance(node, (FunctionDef, ClassDef, Module)): |
| 230 | + raise TypeError("%r can't have docstrings" % node.__class__.__name__) |
| 231 | + if node.body and isinstance(node.body[0], Expr) and \ |
| 232 | + isinstance(node.body[0].value, Str): |
| 233 | + doc = node.body[0].value.s |
| 234 | + if trim: |
| 235 | + doc = trim_docstring(doc) |
| 236 | + return doc |
| 237 | + |
| 238 | + |
| 239 | +def trim_docstring(docstring): |
| 240 | + """Trim a docstring. This should probably go into the inspect module.""" |
| 241 | + lines = docstring.expandtabs().splitlines() |
| 242 | + |
| 243 | + # Find minimum indentation of any non-blank lines after first line. |
| 244 | + from sys import maxint |
| 245 | + margin = maxint |
| 246 | + for line in lines[1:]: |
| 247 | + content = len(line.lstrip()) |
| 248 | + if content: |
| 249 | + indent = len(line) - content |
| 250 | + margin = min(margin, indent) |
| 251 | + |
| 252 | + # Remove indentation. |
| 253 | + if lines: |
| 254 | + lines[0] = lines[0].lstrip() |
| 255 | + if margin < maxint: |
| 256 | + for i in range(1, len(lines)): |
| 257 | + lines[i] = lines[i][margin:] |
| 258 | + |
| 259 | + # Remove any trailing or leading blank lines. |
| 260 | + while lines and not lines[-1]: |
| 261 | + lines.pop() |
| 262 | + while lines and not lines[0]: |
| 263 | + lines.pop(0) |
| 264 | + return '\n'.join(lines) |
| 265 | + |
| 266 | + |
| 267 | +def get_symbol(operator): |
| 268 | + """Return the symbol of the given operator node or node type.""" |
| 269 | + if isinstance(operator, AST): |
| 270 | + operator = type(operator) |
| 271 | + try: |
| 272 | + return ALL_SYMBOLS[operator] |
| 273 | + except KeyError: |
| 274 | + raise LookupError('no known symbol for %r' % operator) |
| 275 | + |
| 276 | + |
| 277 | +def walk(node): |
| 278 | + """Iterate over all nodes. This is useful if you only want to modify nodes |
| 279 | + in place and don't care about the context or the order the nodes are |
| 280 | + returned. |
| 281 | + """ |
| 282 | + from collections import deque |
| 283 | + todo = deque([node]) |
| 284 | + while todo: |
| 285 | + node = todo.popleft() |
| 286 | + todo.extend(iter_child_nodes(node)) |
| 287 | + yield node |
| 288 | + |
| 289 | + |
| 290 | +class NodeVisitor(object): |
| 291 | + """Walks the abstract syntax tree and call visitor functions for every |
| 292 | + node found. The visitor functions may return values which will be |
| 293 | + forwarded by the `visit` method. |
| 294 | +
|
| 295 | + Per default the visitor functions for the nodes are ``'visit_'`` + |
| 296 | + class name of the node. So a `TryFinally` node visit function would |
| 297 | + be `visit_TryFinally`. This behavior can be changed by overriding |
| 298 | + the `get_visitor` function. If no visitor function exists for a node |
| 299 | + (return value `None`) the `generic_visit` visitor is used instead. |
| 300 | +
|
| 301 | + Don't use the `NodeVisitor` if you want to apply changes to nodes during |
| 302 | + traversing. For this a special visitor exists (`NodeTransformer`) that |
| 303 | + allows modifications. |
| 304 | + """ |
| 305 | + |
| 306 | + def get_visitor(self, node): |
| 307 | + """Return the visitor function for this node or `None` if no visitor |
| 308 | + exists for this node. In that case the generic visit function is |
| 309 | + used instead. |
| 310 | + """ |
| 311 | + method = 'visit_' + node.__class__.__name__ |
| 312 | + return getattr(self, method, None) |
| 313 | + |
| 314 | + def visit(self, node): |
| 315 | + """Visit a node.""" |
| 316 | + f = self.get_visitor(node) |
| 317 | + if f is not None: |
| 318 | + return f(node) |
| 319 | + return self.generic_visit(node) |
| 320 | + |
| 321 | + def generic_visit(self, node): |
| 322 | + """Called if no explicit visitor function exists for a node.""" |
| 323 | + for field, value in iter_fields(node): |
| 324 | + if isinstance(value, list): |
| 325 | + for item in value: |
| 326 | + if isinstance(item, AST): |
| 327 | + self.visit(item) |
| 328 | + elif isinstance(value, AST): |
| 329 | + self.visit(value) |
| 330 | + |
| 331 | + |
| 332 | +class NodeTransformer(NodeVisitor): |
| 333 | + """Walks the abstract syntax tree and allows modifications of nodes. |
| 334 | +
|
| 335 | + The `NodeTransformer` will walk the AST and use the return value of the |
| 336 | + visitor functions to replace or remove the old node. If the return |
| 337 | + value of the visitor function is `None` the node will be removed |
| 338 | + from the previous location otherwise it's replaced with the return |
| 339 | + value. The return value may be the original node in which case no |
| 340 | + replacement takes place. |
| 341 | +
|
| 342 | + Here an example transformer that rewrites all `foo` to `data['foo']`:: |
| 343 | +
|
| 344 | + class RewriteName(NodeTransformer): |
| 345 | +
|
| 346 | + def visit_Name(self, node): |
| 347 | + return copy_location(Subscript( |
| 348 | + value=Name(id='data', ctx=Load()), |
| 349 | + slice=Index(value=Str(s=node.id)), |
| 350 | + ctx=node.ctx |
| 351 | + ), node) |
| 352 | +
|
| 353 | + Keep in mind that if the node you're operating on has child nodes |
| 354 | + you must either transform the child nodes yourself or call the generic |
| 355 | + visit function for the node first. |
| 356 | +
|
| 357 | + Nodes that were part of a collection of statements (that applies to |
| 358 | + all statement nodes) may also return a list of nodes rather than just |
| 359 | + a single node. |
| 360 | +
|
| 361 | + Usually you use the transformer like this:: |
| 362 | +
|
| 363 | + node = YourTransformer().visit(node) |
| 364 | + """ |
| 365 | + |
| 366 | + def generic_visit(self, node): |
| 367 | + for field, old_value in iter_fields(node): |
| 368 | + old_value = getattr(node, field, None) |
| 369 | + if isinstance(old_value, list): |
| 370 | + new_values = [] |
| 371 | + for value in old_value: |
| 372 | + if isinstance(value, AST): |
| 373 | + value = self.visit(value) |
| 374 | + if value is None: |
| 375 | + continue |
| 376 | + elif not isinstance(value, AST): |
| 377 | + new_values.extend(value) |
| 378 | + continue |
| 379 | + new_values.append(value) |
| 380 | + old_value[:] = new_values |
| 381 | + elif isinstance(old_value, AST): |
| 382 | + new_node = self.visit(old_value) |
| 383 | + if new_node is None: |
| 384 | + delattr(node, field) |
| 385 | + else: |
| 386 | + setattr(node, field, new_node) |
| 387 | + return node |
0 commit comments