1
1
import ast
2
+ import functools
2
3
from collections import deque
3
4
from contextlib import contextmanager
4
5
from typing import Deque , Iterable , List , Optional , Set
@@ -17,6 +18,93 @@ def _filter_dead_code(nodes: Iterable[ast.AST]) -> List[ast.AST]:
17
18
return new_nodes
18
19
19
20
21
+ @functools .singledispatch
22
+ def _optimize_operator_call ( # pylint: disable=unused-argument
23
+ fn : ast .AST , node : ast .Call
24
+ ) -> ast .AST :
25
+ return node
26
+
27
+
28
+ @_optimize_operator_call .register (ast .Attribute )
29
+ def _optimize_operator_call_attr ( # pylint: disable=too-many-return-statements
30
+ fn : ast .Attribute , node : ast .Call
31
+ ) -> ast .AST :
32
+ """Optimize calls to the Python `operator` module down to use the raw Python
33
+ operators.
34
+
35
+ Using Python operators directly will allow for more direct bytecode to be
36
+ emitted by the Python compiler and take advantage of any additional performance
37
+ improvements in future versions of Python."""
38
+ if isinstance (fn .value , ast .Name ) and fn .value .id == "operator" :
39
+ binop = {
40
+ "add" : ast .Add ,
41
+ "and_" : ast .BitAnd ,
42
+ "floordiv" : ast .FloorDiv ,
43
+ "lshift" : ast .LShift ,
44
+ "mod" : ast .Mod ,
45
+ "mul" : ast .Mult ,
46
+ "matmul" : ast .MatMult ,
47
+ "or_" : ast .BitOr ,
48
+ "pow" : ast .Pow ,
49
+ "rshift" : ast .RShift ,
50
+ "sub" : ast .Sub ,
51
+ "truediv" : ast .Div ,
52
+ "xor" : ast .BitXor ,
53
+ }.get (fn .attr )
54
+ if binop is not None :
55
+ arg1 , arg2 = node .args
56
+ assert len (node .args ) == 2
57
+ return ast .BinOp (arg1 , binop (), arg2 )
58
+
59
+ unaryop = {"not_" : ast .Not , "inv" : ast .Invert , "invert" : ast .Invert }.get (
60
+ fn .attr
61
+ )
62
+ if unaryop is not None :
63
+ arg = node .args [0 ]
64
+ assert len (node .args ) == 1
65
+ return ast .UnaryOp (unaryop (), arg )
66
+
67
+ compareop = {
68
+ "lt" : ast .Lt ,
69
+ "le" : ast .LtE ,
70
+ "eq" : ast .Eq ,
71
+ "ne" : ast .NotEq ,
72
+ "gt" : ast .Gt ,
73
+ "ge" : ast .GtE ,
74
+ "is_" : ast .Is ,
75
+ "is_not" : ast .IsNot ,
76
+ }.get (fn .attr )
77
+ if compareop is not None :
78
+ arg1 , arg2 = node .args
79
+ assert len (node .args ) == 2
80
+ return ast .Compare (arg1 , [compareop ()], [arg2 ])
81
+
82
+ if fn .attr == "contains" :
83
+ arg1 , arg2 = node .args
84
+ assert len (node .args ) == 2
85
+ return ast .Compare (arg2 , [ast .In ()], [arg1 ])
86
+
87
+ if fn .attr == "delitem" :
88
+ target , index = node .args
89
+ assert len (node .args ) == 2
90
+ return ast .Delete (
91
+ targets = [
92
+ ast .Subscript (
93
+ value = target , slice = ast .Index (value = index ), ctx = ast .Del ()
94
+ )
95
+ ]
96
+ )
97
+
98
+ if fn .attr == "getitem" :
99
+ target , index = node .args
100
+ assert len (node .args ) == 2
101
+ return ast .Subscript (
102
+ value = target , slice = ast .Index (value = index ), ctx = ast .Load ()
103
+ )
104
+
105
+ return node
106
+
107
+
20
108
class PythonASTOptimizer (ast .NodeTransformer ):
21
109
__slots__ = ("_global_ctx" ,)
22
110
@@ -37,6 +125,16 @@ def _global_context(self) -> Set[str]:
37
125
"""Return the current Python `global` context."""
38
126
return self ._global_ctx [- 1 ]
39
127
128
+ def visit_Call (self , node : ast .Call ) -> ast .AST :
129
+ """Eliminate most calls to Python's `operator` module in favor of using native
130
+ operators."""
131
+ new_node = self .generic_visit (node )
132
+ if isinstance (new_node , ast .Call ):
133
+ return ast .copy_location (
134
+ _optimize_operator_call (node .func , new_node ), new_node
135
+ )
136
+ return new_node
137
+
40
138
def visit_ExceptHandler (self , node : ast .ExceptHandler ) -> Optional [ast .AST ]:
41
139
"""Eliminate dead code from except handler bodies."""
42
140
new_node = self .generic_visit (node )
0 commit comments