Skip to content

Commit fc4fba8

Browse files
authored
Optimize operator calls to their Python native operator equivalents (#748)
Fixes #754
1 parent 2e864d9 commit fc4fba8

File tree

2 files changed

+103
-2
lines changed

2 files changed

+103
-2
lines changed

CHANGELOG.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
66

77
## [Unreleased]
88

9+
### Changed
10+
* Optimize calls to Python's `operator` module into their corresponding native operators (#754)
11+
912
### Fixed
10-
* Fix issue with `(count nil)` throwing an exception (#759).
13+
* Fix issue with `(count nil)` throwing an exception (#759)
1114

1215
## [v0.1.0b0]
1316
### Added
@@ -16,7 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1619
* Added support for Python 3.12 (#734)
1720
* Added a default reader conditional for the current platform (`windows`, `darwin`, `linux`, etc.) (#692)
1821
* Added support for `bencode` binary encoding (part of #412)
19-
* Ported nbb's nrepl-server module to basilisp (#412).
22+
* Ported nbb's nrepl-server module to basilisp (#412)
2023

2124
### Changed
2225
* Basilisp now supports PyTest 7.0+ (#660)

src/basilisp/lang/compiler/optimizer.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import ast
2+
import functools
23
from collections import deque
34
from contextlib import contextmanager
45
from typing import Deque, Iterable, List, Optional, Set
@@ -17,6 +18,93 @@ def _filter_dead_code(nodes: Iterable[ast.AST]) -> List[ast.AST]:
1718
return new_nodes
1819

1920

21+
@functools.singledispatch
22+
def _optimize_operator_call( # pylint: disable=unused-argument
23+
fn: ast.AST, node: ast.Call
24+
) -> ast.AST:
25+
return node
26+
27+
28+
@_optimize_operator_call.register(ast.Attribute)
29+
def _optimize_operator_call_attr( # pylint: disable=too-many-return-statements
30+
fn: ast.Attribute, node: ast.Call
31+
) -> ast.AST:
32+
"""Optimize calls to the Python `operator` module down to use the raw Python
33+
operators.
34+
35+
Using Python operators directly will allow for more direct bytecode to be
36+
emitted by the Python compiler and take advantage of any additional performance
37+
improvements in future versions of Python."""
38+
if isinstance(fn.value, ast.Name) and fn.value.id == "operator":
39+
binop = {
40+
"add": ast.Add,
41+
"and_": ast.BitAnd,
42+
"floordiv": ast.FloorDiv,
43+
"lshift": ast.LShift,
44+
"mod": ast.Mod,
45+
"mul": ast.Mult,
46+
"matmul": ast.MatMult,
47+
"or_": ast.BitOr,
48+
"pow": ast.Pow,
49+
"rshift": ast.RShift,
50+
"sub": ast.Sub,
51+
"truediv": ast.Div,
52+
"xor": ast.BitXor,
53+
}.get(fn.attr)
54+
if binop is not None:
55+
arg1, arg2 = node.args
56+
assert len(node.args) == 2
57+
return ast.BinOp(arg1, binop(), arg2)
58+
59+
unaryop = {"not_": ast.Not, "inv": ast.Invert, "invert": ast.Invert}.get(
60+
fn.attr
61+
)
62+
if unaryop is not None:
63+
arg = node.args[0]
64+
assert len(node.args) == 1
65+
return ast.UnaryOp(unaryop(), arg)
66+
67+
compareop = {
68+
"lt": ast.Lt,
69+
"le": ast.LtE,
70+
"eq": ast.Eq,
71+
"ne": ast.NotEq,
72+
"gt": ast.Gt,
73+
"ge": ast.GtE,
74+
"is_": ast.Is,
75+
"is_not": ast.IsNot,
76+
}.get(fn.attr)
77+
if compareop is not None:
78+
arg1, arg2 = node.args
79+
assert len(node.args) == 2
80+
return ast.Compare(arg1, [compareop()], [arg2])
81+
82+
if fn.attr == "contains":
83+
arg1, arg2 = node.args
84+
assert len(node.args) == 2
85+
return ast.Compare(arg2, [ast.In()], [arg1])
86+
87+
if fn.attr == "delitem":
88+
target, index = node.args
89+
assert len(node.args) == 2
90+
return ast.Delete(
91+
targets=[
92+
ast.Subscript(
93+
value=target, slice=ast.Index(value=index), ctx=ast.Del()
94+
)
95+
]
96+
)
97+
98+
if fn.attr == "getitem":
99+
target, index = node.args
100+
assert len(node.args) == 2
101+
return ast.Subscript(
102+
value=target, slice=ast.Index(value=index), ctx=ast.Load()
103+
)
104+
105+
return node
106+
107+
20108
class PythonASTOptimizer(ast.NodeTransformer):
21109
__slots__ = ("_global_ctx",)
22110

@@ -37,6 +125,16 @@ def _global_context(self) -> Set[str]:
37125
"""Return the current Python `global` context."""
38126
return self._global_ctx[-1]
39127

128+
def visit_Call(self, node: ast.Call) -> ast.AST:
129+
"""Eliminate most calls to Python's `operator` module in favor of using native
130+
operators."""
131+
new_node = self.generic_visit(node)
132+
if isinstance(new_node, ast.Call):
133+
return ast.copy_location(
134+
_optimize_operator_call(node.func, new_node), new_node
135+
)
136+
return new_node
137+
40138
def visit_ExceptHandler(self, node: ast.ExceptHandler) -> Optional[ast.AST]:
41139
"""Eliminate dead code from except handler bodies."""
42140
new_node = self.generic_visit(node)

0 commit comments

Comments
 (0)