Skip to content

Commit 2dc71e4

Browse files
committed
RF - add copy of ast file for copied codegen
1 parent 9a04815 commit 2dc71e4

File tree

2 files changed

+391
-46
lines changed

2 files changed

+391
-46
lines changed

nisext/ast.py

Lines changed: 387 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,387 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
ast
4+
~~~
5+
6+
The `ast` module helps Python applications to process trees of the Python
7+
abstract syntax grammar. The abstract syntax itself might change with
8+
each Python release; this module helps to find out programmatically what
9+
the current grammar looks like and allows modifications of it.
10+
11+
An abstract syntax tree can be generated by passing `ast.PyCF_ONLY_AST` as
12+
a flag to the `compile()` builtin function or by using the `parse()`
13+
function from this module. The result will be a tree of objects whose
14+
classes all inherit from `ast.AST`.
15+
16+
A modified abstract syntax tree can be compiled into a Python code object
17+
using the built-in `compile()` function.
18+
19+
Additionally various helper functions are provided that make working with
20+
the trees simpler. The main intention of the helper functions and this
21+
module in general is to provide an easy to use interface for libraries
22+
that work tightly with the python syntax (template engines for example).
23+
24+
25+
:copyright: Copyright 2008 by Armin Ronacher.
26+
:license: Python License.
27+
28+
From: http://dev.pocoo.org/hg/sandbox
29+
"""
30+
from _ast import *
31+
32+
33+
# Source-text symbol for every boolean operator node class.
BOOLOP_SYMBOLS = {
    And: 'and',
    Or: 'or',
}

# Source-text symbol for every binary operator node class.
BINOP_SYMBOLS = {
    Add: '+',
    Sub: '-',
    Mult: '*',
    Div: '/',
    FloorDiv: '//',
    Mod: '%',
    # Exponentiation was missing, so `a ** b` had no known symbol.
    Pow: '**',
    LShift: '<<',
    RShift: '>>',
    BitOr: '|',
    BitAnd: '&',
    BitXor: '^',
}

# Source-text symbol for every comparison operator node class.
CMPOP_SYMBOLS = {
    Eq: '==',
    Gt: '>',
    GtE: '>=',
    In: 'in',
    Is: 'is',
    IsNot: 'is not',
    Lt: '<',
    LtE: '<=',
    NotEq: '!=',
    NotIn: 'not in',
}

# Source-text symbol for every unary operator node class.
UNARYOP_SYMBOLS = {
    Invert: '~',
    Not: 'not',
    UAdd: '+',
    USub: '-',
}

# All of the tables above merged into a single lookup table.  The node
# classes are distinct per operator kind (e.g. `Add` vs. `UAdd`), so no
# entry overwrites another.
ALL_SYMBOLS = {}
ALL_SYMBOLS.update(BOOLOP_SYMBOLS)
ALL_SYMBOLS.update(BINOP_SYMBOLS)
ALL_SYMBOLS.update(CMPOP_SYMBOLS)
ALL_SYMBOLS.update(UNARYOP_SYMBOLS)
77+
78+
79+
def parse(expr, filename='<unknown>', mode='exec'):
    """Compile `expr` into an abstract syntax tree and return its root node.

    `filename` and `mode` are handed straight to the built-in `compile()`;
    the `PyCF_ONLY_AST` flag makes `compile()` return the tree instead of a
    code object.
    """
    ast_only_flag = PyCF_ONLY_AST
    tree = compile(expr, filename, mode, ast_only_flag)
    return tree
82+
83+
84+
def literal_eval(node_or_string):
    """Safely evaluate a literal.

    The string or node provided may only consist of the following Python
    structures: strings, numbers, tuples, lists, dicts, booleans and `None`.

    :param node_or_string: the source text of an expression, an
        `Expression` node, or an already unwrapped expression node.
    :return: the Python object described by the literal.
    :raises ValueError: if any other kind of node is encountered.
    """
    # Local import: which literal node classes exist depends on the running
    # interpreter, so they are looked up by name instead of being assumed.
    import _ast

    try:
        string_types = basestring  # Python 2
    except NameError:
        string_types = str  # Python 3 has no `basestring`

    safe_names = {'None': None, 'True': True, 'False': False}

    # Leaf literal node classes paired with the attribute holding the value.
    # Only the classes the running interpreter actually provides are kept.
    leaf_attrs = [
        (getattr(_ast, cls_name), attr)
        for cls_name, attr in (('Str', 's'), ('Bytes', 's'), ('Num', 'n'),
                               ('NameConstant', 'value'),
                               ('Constant', 'value'))
        if hasattr(_ast, cls_name)
    ]

    if isinstance(node_or_string, string_types):
        node_or_string = parse(node_or_string, mode='eval')
    if isinstance(node_or_string, Expression):
        node_or_string = node_or_string.body

    def _convert(node):
        for node_class, attr in leaf_attrs:
            if isinstance(node, node_class):
                value = getattr(node, attr)
                if value is Ellipsis:
                    # `...` is not one of the supported literals.
                    raise ValueError('malformed string')
                return value
        if isinstance(node, Tuple):
            return tuple(map(_convert, node.elts))
        if isinstance(node, List):
            return list(map(_convert, node.elts))
        if isinstance(node, Dict):
            return dict((_convert(k), _convert(v))
                        for k, v in zip(node.keys, node.values))
        if isinstance(node, Name) and node.id in safe_names:
            return safe_names[node.id]
        raise ValueError('malformed string')

    return _convert(node_or_string)
111+
112+
113+
def dump(node, annotate_fields=True, include_attributes=False):
    """Return a very verbose string representation of `node`.

    This is mainly useful for debugging.  By default the returned string
    shows the names and the values of the fields, which makes it impossible
    to evaluate; if evaluation is wanted, `annotate_fields` must be set to
    `False`.  Attributes such as line numbers and column offsets are *not*
    dumped by default.  If they are wanted, set `include_attributes` to
    `True`.

    :raises TypeError: if `node` is not an `AST` instance.
    """
    def _format(node):
        if isinstance(node, AST):
            fields = [(name, _format(value))
                      for name, value in iter_fields(node)]
            if annotate_fields:
                parts = ['%s=%s' % field for field in fields]
            else:
                parts = [text for _, text in fields]
            rv = '%s(%s' % (node.__class__.__name__, ', '.join(parts))
            if include_attributes and node._attributes:
                # Separator between the fields (if any) and the attributes;
                # the single space for field-less nodes is kept so existing
                # output stays unchanged.
                rv += ', ' if fields else ' '
                rv += ', '.join('%s=%s' % (a, _format(getattr(node, a)))
                                for a in node._attributes)
            return rv + ')'
        elif isinstance(node, list):
            return '[%s]' % ', '.join(_format(x) for x in node)
        return repr(node)

    if not isinstance(node, AST):
        raise TypeError('expected AST, got %r' % node.__class__.__name__)
    return _format(node)
140+
141+
142+
def copy_location(new_node, old_node):
    """Transfer `lineno` and `col_offset` from `old_node` to `new_node`.

    An attribute is only copied when both node types declare it in
    `_attributes` and `old_node` actually carries a value for it.
    `new_node` is returned.
    """
    for attr in ('lineno', 'col_offset'):
        declared = (attr in old_node._attributes
                    and attr in new_node._attributes)
        if not (declared and hasattr(old_node, attr)):
            continue
        setattr(new_node, attr, getattr(old_node, attr))
    return new_node
151+
152+
153+
def fix_missing_locations(node):
    """Fill in missing `lineno` / `col_offset` values throughout a tree.

    The compiler needs a line number and column offset on some nodes and
    aborts without them.  This helper walks the tree recursively and gives
    every node lacking that information the values of its parent, starting
    from line 1, column 0 at the root.  Nodes that already have a location
    are left untouched and pass their own values on to their children.
    The root `node` is returned.
    """
    def _resolve(target, attr, inherited):
        # Effective value of `attr` for `target`; stores the inherited value
        # on the node when the node supports the attribute but lacks it.
        if attr not in target._attributes:
            return inherited
        if hasattr(target, attr):
            return getattr(target, attr)
        setattr(target, attr, inherited)
        return inherited

    def _fix(current, lineno, col_offset):
        lineno = _resolve(current, 'lineno', lineno)
        col_offset = _resolve(current, 'col_offset', col_offset)
        for child in iter_child_nodes(current):
            _fix(child, lineno, col_offset)

    _fix(node, 1, 0)
    return node
179+
180+
181+
def increment_lineno(node, n=1):
    """Increment the line numbers of `node` and all its descendants by `n`.

    Only nodes whose type declares a `lineno` attribute are touched; a node
    that declares it but has no value yet is treated as being on line 0.
    Where the interpreter also tracks `end_lineno`, an existing value is
    shifted by the same amount.  This is useful to "move code" to a
    different location in a file.  `node` is returned.
    """
    # `walk` yields `node` itself first, so the root must not be bumped
    # separately -- doing so shifted it by 2 * n.
    for child in walk(node):
        if 'lineno' in child._attributes:
            child.lineno = getattr(child, 'lineno', 0) + n
        if 'end_lineno' in child._attributes:
            end_lineno = getattr(child, 'end_lineno', None)
            if end_lineno is not None:
                child.end_lineno = end_lineno + n
    return node
192+
193+
194+
def iter_fields(node):
    """Yield a ``(name, value)`` pair for every field of `node`.

    Names come from `node._fields`; fields that are not set on the node
    are silently skipped.
    """
    for name in node._fields:
        try:
            value = getattr(node, name)
        except AttributeError:
            continue
        yield name, value
201+
202+
203+
def get_fields(node):
    """Like `iter_fields` but returns a ``{name: value}`` dict."""
    fields = {}
    for name, value in iter_fields(node):
        fields[name] = value
    return fields
206+
207+
208+
def iter_child_nodes(node):
    """Yield every direct child node of `node`.

    A child is any field value that is an `AST` instance, or any `AST`
    instance found inside a field that holds a list.
    """
    for _, value in iter_fields(node):
        candidates = value if isinstance(value, list) else [value]
        for candidate in candidates:
            if isinstance(candidate, AST):
                yield candidate
217+
218+
219+
def get_child_nodes(node):
    """Return the direct child nodes of `node` as a list.

    This is the eager counterpart of `iter_child_nodes`.
    """
    return [child for child in iter_child_nodes(node)]
222+
223+
224+
def get_docstring(node, trim=True):
    """Return the docstring of `node`, or `None` if it has none.

    :param node: a `FunctionDef`, `ClassDef` or `Module` node.
    :param trim: when true, the docstring is cleaned up with
        `trim_docstring` before it is returned.
    :raises TypeError: if `node` is of a type that cannot have a docstring.
    """
    # Local import: the node class used for string literals depends on the
    # running interpreter (`Str` on older ones, `Constant` on newer ones).
    import _ast

    if not isinstance(node, (FunctionDef, ClassDef, Module)):
        raise TypeError("%r can't have docstrings" % node.__class__.__name__)
    if not (node.body and isinstance(node.body[0], Expr)):
        return None

    value = node.body[0].value
    str_node = getattr(_ast, 'Str', None)
    constant_node = getattr(_ast, 'Constant', None)
    if str_node is not None and isinstance(value, str_node):
        doc = value.s
    elif (constant_node is not None and isinstance(value, constant_node)
          and isinstance(value.value, str)):
        doc = value.value
    else:
        # The first statement is an expression, but not a string literal.
        return None

    if trim:
        doc = trim_docstring(doc)
    return doc
237+
238+
239+
def trim_docstring(docstring):
    """Trim a docstring.  This should probably go into the inspect module.

    Tabs are expanded, the first line is stripped of leading whitespace,
    the common indentation of all later non-blank lines is removed, and
    empty lines at the start and the end are dropped.

    :param docstring: the raw docstring text.
    :return: the trimmed text, with lines joined by ``'\\n'``.
    """
    lines = docstring.expandtabs().splitlines()

    # Find the minimum indentation of any non-blank line after the first.
    # `None` means "no such line"; this replaces the `sys.maxint` sentinel,
    # which does not exist on Python 3.
    margin = None
    for line in lines[1:]:
        content = len(line.lstrip())
        if content:
            indent = len(line) - content
            margin = indent if margin is None else min(margin, indent)

    # Remove indentation.
    if lines:
        lines[0] = lines[0].lstrip()
    if margin is not None:
        lines[1:] = [line[margin:] for line in lines[1:]]

    # Remove any trailing or leading blank lines.
    while lines and not lines[-1]:
        lines.pop()
    while lines and not lines[0]:
        lines.pop(0)
    return '\n'.join(lines)
265+
266+
267+
def get_symbol(operator):
    """Return the source symbol for an operator node or operator node type.

    A node instance is first reduced to its type, which is then looked up
    in `ALL_SYMBOLS`.  A `LookupError` is raised for unknown operators.
    """
    op_type = type(operator) if isinstance(operator, AST) else operator
    if op_type in ALL_SYMBOLS:
        return ALL_SYMBOLS[op_type]
    raise LookupError('no known symbol for %r' % op_type)
275+
276+
277+
def walk(node):
    """Yield `node` and all of its descendants.

    Nodes are produced breadth-first, starting with `node` itself.  This is
    handy when nodes are only modified in place and neither their context
    nor the order they are returned in matters.
    """
    from collections import deque
    queue = deque([node])
    while queue:
        current = queue.popleft()
        # Children are queued before the node is handed to the caller.
        for child in iter_child_nodes(current):
            queue.append(child)
        yield current
288+
289+
290+
class NodeVisitor(object):
    """Walk an abstract syntax tree, calling a visitor function per node.

    Whatever a visitor function returns is forwarded by `visit`.

    By default the visitor function for a node is named ``'visit_'`` plus
    the class name of the node, so a `Name` node is handled by
    `visit_Name`.  Override `get_visitor` to change that lookup.  When no
    visitor function exists for a node (`get_visitor` returns `None`),
    `generic_visit` is used instead.

    Do not use `NodeVisitor` to modify nodes while traversing; the
    `NodeTransformer` subclass exists for that purpose.
    """

    def get_visitor(self, node):
        """Return the visitor function for `node`, or `None` if this
        instance defines none for the node's class.
        """
        return getattr(self, 'visit_' + node.__class__.__name__, None)

    def visit(self, node):
        """Visit `node` and return whatever its visitor returns."""
        visitor = self.get_visitor(node)
        if visitor is None:
            visitor = self.generic_visit
        return visitor(node)

    def generic_visit(self, node):
        """Fallback visitor: visit every child node of `node`."""
        for _, value in iter_fields(node):
            children = value if isinstance(value, list) else [value]
            for child in children:
                if isinstance(child, AST):
                    self.visit(child)
330+
331+
332+
class NodeTransformer(NodeVisitor):
    """Walk an abstract syntax tree and allow nodes to be modified.

    The return value of each visitor function decides what happens to the
    visited node: `None` removes it from its previous location, any other
    value replaces it.  Returning the original node leaves it in place.

    An example transformer that rewrites every `foo` to `data['foo']`::

        class RewriteName(NodeTransformer):

            def visit_Name(self, node):
                return copy_location(Subscript(
                    value=Name(id='data', ctx=Load()),
                    slice=Index(value=Str(s=node.id)),
                    ctx=node.ctx
                ), node)

    If the node you are operating on has child nodes, you must either
    transform them yourself or call the generic visit function for the
    node first.

    For nodes that are part of a collection of statements (which applies
    to all statement nodes) the visitor may also return a list of nodes
    instead of a single node.

    Typical usage::

        node = YourTransformer().visit(node)
    """

    def generic_visit(self, node):
        """Visit all children of `node`, applying the visitors' results.

        List fields are rebuilt in place; single-node fields are replaced,
        or deleted when the visitor returns `None`.  `node` is returned.
        """
        for field, old_value in iter_fields(node):
            if isinstance(old_value, AST):
                replacement = self.visit(old_value)
                if replacement is None:
                    delattr(node, field)
                else:
                    setattr(node, field, replacement)
            elif isinstance(old_value, list):
                kept = []
                for item in old_value:
                    if not isinstance(item, AST):
                        # Non-node list entries are carried over untouched.
                        kept.append(item)
                        continue
                    result = self.visit(item)
                    if result is None:
                        continue
                    if isinstance(result, AST):
                        kept.append(result)
                    else:
                        # The visitor returned several nodes; splice them in.
                        kept.extend(result)
                old_value[:] = kept
        return node

0 commit comments

Comments
 (0)