Skip to content

Commit f903eb4

Browse files
committed
fold unary & binary complex constant expressions in codegen
1 parent d7672e5 commit f903eb4

File tree

7 files changed

+169
-16
lines changed

7 files changed

+169
-16
lines changed

Include/cpython/compile.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#define PyCF_ALLOW_TOP_LEVEL_AWAIT 0x2000
2121
#define PyCF_ALLOW_INCOMPLETE_INPUT 0x4000
2222
#define PyCF_OPTIMIZED_AST (0x8000 | PyCF_ONLY_AST)
23+
#define PyCF_DONT_OPTIMIZE_AST 0x10000
2324
#define PyCF_COMPILE_MASK (PyCF_ONLY_AST | PyCF_ALLOW_TOP_LEVEL_AWAIT | \
2425
PyCF_TYPE_COMMENTS | PyCF_DONT_IMPLY_DEDENT | \
2526
PyCF_ALLOW_INCOMPLETE_INPUT | PyCF_OPTIMIZED_AST)

Lib/test/support/bytecode_helper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,8 @@ def check_instructions(self, insts):
138138
@unittest.skipIf(_testinternalcapi is None, "requires _testinternalcapi")
139139
class CodegenTestCase(CompilationStepTestCase):
140140

141-
def generate_code(self, ast):
142-
insts, _ = _testinternalcapi.compiler_codegen(ast, "my_file.py", 0)
141+
def generate_code(self, ast, optimize_ast=True):
142+
insts, _ = _testinternalcapi.compiler_codegen(ast, "my_file.py", 0, 0, optimize_ast)
143143
return insts
144144

145145

Lib/test/test_compiler_codegen.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@ def assertInstructionsMatch_recursive(self, insts, expected_insts):
1515
for n_insts, n_expected in zip(insts.get_nested(), expected_nested):
1616
self.assertInstructionsMatch_recursive(n_insts, n_expected)
1717

18-
def codegen_test(self, snippet, expected_insts):
18+
def codegen_test(self, snippet, expected_insts, optimize_ast=True):
1919
import ast
2020
a = ast.parse(snippet, "my_file.py", "exec")
21-
insts = self.generate_code(a)
21+
insts = self.generate_code(a, optimize_ast=optimize_ast)
2222
self.assertInstructionsMatch_recursive(insts, expected_insts)
2323

2424
def test_if_expression(self):
@@ -157,3 +157,41 @@ def test_syntax_error__return_not_in_function(self):
157157
self.assertIsNone(cm.exception.text)
158158
self.assertEqual(cm.exception.offset, 1)
159159
self.assertEqual(cm.exception.end_offset, 10)
160+
161+
def test_dont_optimize_ast_before_codegen(self):
162+
snippet = "1+2"
163+
unoptimized = [
164+
('RESUME', 0, 0),
165+
('LOAD_SMALL_INT', 1, 0),
166+
('LOAD_SMALL_INT', 2, 0),
167+
('BINARY_OP', 0, 0),
168+
('POP_TOP', None, 0),
169+
('LOAD_CONST', 0, 0),
170+
('RETURN_VALUE', None, 0),
171+
]
172+
self.codegen_test(snippet, unoptimized, optimize_ast=False)
173+
174+
optimized = [
175+
('RESUME', 0, 0),
176+
('NOP', None, 0),
177+
('LOAD_CONST', 0, 0),
178+
('RETURN_VALUE', None, 0),
179+
]
180+
self.codegen_test(snippet, optimized, optimize_ast=True)
181+
182+
def test_match_case_fold_codegen(self):
183+
snippet = textwrap.dedent("""
184+
match 0:
185+
case -0: pass # match unary const int
186+
case -0.1: pass # match unary const float
187+
case -0j: pass # match unary const complex
188+
case 1 + 2j: pass # match const int + const complex
189+
case 1 - 2j: pass # match const int - const complex
190+
case 1.1 + 2.1j: pass # match const float + const complex
191+
case 1.1 - 2.1j: pass # match const float - const complex
192+
case -0 + 1j: pass # match unary const int + complex
193+
case -0 - 1j: pass # match unary const int - complex
194+
case -0.1 + 1.1j: pass # match unary const float + complex
195+
case -0.1 - 1.1j: pass # match unary const float - complex
196+
""")
197+
self.codegen_test(snippet, [], optimize_ast=False)

Modules/_testinternalcapi.c

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -718,11 +718,13 @@ Apply compiler code generation to an AST.
718718
static PyObject *
719719
_testinternalcapi_compiler_codegen_impl(PyObject *module, PyObject *ast,
720720
PyObject *filename, int optimize,
721-
int compile_mode)
721+
int compile_mode, int optimize_ast)
722722
/*[clinic end generated code: output=40a68f6e13951cc8 input=a0e00784f1517cd7]*/
723723
{
724-
PyCompilerFlags *flags = NULL;
725-
return _PyCompile_CodeGen(ast, filename, flags, optimize, compile_mode);
724+
PyCompilerFlags flags = _PyCompilerFlags_INIT;
725+
if (!optimize_ast)
726+
flags.cf_flags = PyCF_DONT_OPTIMIZE_AST;
727+
return _PyCompile_CodeGen(ast, filename, &flags, optimize, compile_mode);
726728
}
727729

728730

Modules/clinic/_testinternalcapi.c.h

Lines changed: 8 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Python/codegen.c

Lines changed: 111 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5355,9 +5355,114 @@ codegen_slice(compiler *c, expr_ty s)
53555355
#define WILDCARD_STAR_CHECK(N) \
53565356
((N)->kind == MatchStar_kind && !(N)->v.MatchStar.name)
53575357

5358-
// Limit permitted subexpressions, even if the parser & AST validator let them through
5359-
#define MATCH_VALUE_EXPR(N) \
5360-
((N)->kind == Constant_kind || (N)->kind == Attribute_kind)
5358+
static bool
5359+
is_unary_or_complex_expr(expr_ty e)
5360+
{
5361+
if (e->kind != UnaryOp_kind) {
5362+
return false;
5363+
}
5364+
if (e->v.UnaryOp.op != USub) {
5365+
return false;
5366+
}
5367+
if (e->v.UnaryOp.operand->kind != Constant_kind) {
5368+
return false;
5369+
}
5370+
PyObject *constant = e->v.UnaryOp.operand->v.Constant.value;
5371+
return PyLong_CheckExact(constant) || PyFloat_CheckExact(constant) || PyComplex_CheckExact(constant);
5372+
}
5373+
5374+
static bool
5375+
is_complex_binop_expr(expr_ty e)
5376+
{
5377+
if (e->kind != BinOp_kind) {
5378+
return false;
5379+
}
5380+
if (e->v.BinOp.op != Add && e->v.BinOp.op != Sub) {
5381+
return false;
5382+
}
5383+
if (e->v.BinOp.right->kind != Constant_kind) {
5384+
return false;
5385+
}
5386+
if (e->v.BinOp.left->kind != Constant_kind && e->v.BinOp.left->kind != UnaryOp_kind) {
5387+
return false;
5388+
}
5389+
PyObject *leftconst;
5390+
if (e->v.BinOp.left->kind == UnaryOp_kind) {
5391+
if (e->v.BinOp.left->v.UnaryOp.operand->kind != Constant_kind) {
5392+
return false;
5393+
}
5394+
if (e->v.BinOp.left->v.UnaryOp.op != USub) {
5395+
return false;
5396+
}
5397+
leftconst = e->v.BinOp.left->v.UnaryOp.operand->v.Constant.value;
5398+
}
5399+
else {
5400+
leftconst = e->v.BinOp.left->v.Constant.value;
5401+
}
5402+
PyObject *rightconst = e->v.BinOp.right->v.Constant.value;
5403+
return (PyLong_CheckExact(leftconst) || PyFloat_CheckExact(leftconst)) && PyComplex_CheckExact(rightconst);
5404+
}
5405+
5406+
static void
5407+
fold_node(expr_ty node, PyObject *folded)
5408+
{
5409+
assert(node->kind != Constant_kind);
5410+
node->kind = Constant_kind;
5411+
node->v.Constant.kind = NULL;
5412+
node->v.Constant.value = folded;
5413+
}
5414+
5415+
static int
5416+
fold_unary_or_complex_expr(expr_ty e)
5417+
{
5418+
assert(e->kind == UnaryOp_kind);
5419+
assert(e->v.UnaryOp.op == USub);
5420+
assert(e->v.UnaryOp.operand->kind == Constant_kind);
5421+
PyObject *operand = e->v.UnaryOp.operand->v.Constant.value;
5422+
assert(PyLong_CheckExact(operand) || PyFloat_CheckExact(operand) || PyComplex_CheckExact(operand));
5423+
PyObject* folded = PyNumber_Negative(operand);
5424+
if (folded == NULL) {
5425+
return ERROR;
5426+
}
5427+
fold_node(e, folded);
5428+
return SUCCESS;
5429+
}
5430+
5431+
static int
5432+
fold_binary_complex_expr(expr_ty e)
5433+
{
5434+
assert(e->kind == BinOp_kind);
5435+
assert(e->v.BinOp.right->kind == Constant_kind);
5436+
assert(e->v.BinOp.left->kind == UnaryOp_kind || e->v.BinOp.left->kind == Constant_kind);
5437+
if (e->v.BinOp.left->kind == UnaryOp_kind) {
5438+
RETURN_IF_ERROR(fold_unary_or_complex_expr(e->v.BinOp.left));
5439+
}
5440+
assert(e->v.BinOp.left->kind == Constant_kind);
5441+
operator_ty op = e->v.BinOp.op;
5442+
PyObject *left = e->v.BinOp.left->v.Constant.value;
5443+
PyObject *right = e->v.BinOp.right->v.Constant.value;
5444+
assert(op == Add || op == Sub);
5445+
assert(PyLong_CheckExact(left) || PyFloat_CheckExact(left));
5446+
assert(PyComplex_CheckExact(right));
5447+
PyObject *folded = op == Add ? PyNumber_Add(left, right) : PyNumber_Subtract(left, right);
5448+
if (folded == NULL) {
5449+
return ERROR;
5450+
}
5451+
fold_node(e, folded);
5452+
return SUCCESS;
5453+
}
5454+
5455+
static int
5456+
try_fold_unary_or_binary_complex_const_expr(expr_ty key)
5457+
{
5458+
if (is_unary_or_complex_expr(key)) {
5459+
return fold_unary_or_complex_expr(key);
5460+
}
5461+
if (is_complex_binop_expr(key)) {
5462+
return fold_binary_complex_expr(key);
5463+
}
5464+
return SUCCESS;
5465+
}
53615466

53625467
// Allocate or resize pc->fail_pop to allow for n items to be popped on failure.
53635468
static int
@@ -5688,7 +5793,7 @@ codegen_pattern_mapping_key(compiler *c, PyObject *seen, pattern_ty p, Py_ssize_
56885793
location loc = LOC((pattern_ty) asdl_seq_GET(patterns, i));
56895794
return _PyCompile_Error(c, loc, e);
56905795
}
5691-
5796+
RETURN_IF_ERROR(try_fold_unary_or_binary_complex_const_expr(key));
56925797
if (key->kind == Constant_kind) {
56935798
int in_seen = PySet_Contains(seen, key->v.Constant.value);
56945799
RETURN_IF_ERROR(in_seen);
@@ -6022,7 +6127,8 @@ codegen_pattern_value(compiler *c, pattern_ty p, pattern_context *pc)
60226127
{
60236128
assert(p->kind == MatchValue_kind);
60246129
expr_ty value = p->v.MatchValue.value;
6025-
if (!MATCH_VALUE_EXPR(value)) {
6130+
RETURN_IF_ERROR(try_fold_unary_or_binary_complex_const_expr(value));
6131+
if (value->kind != Constant_kind && value->kind != Attribute_kind) {
60266132
const char *e = "patterns may only match literals and attribute lookups";
60276133
return _PyCompile_Error(c, LOC(p), e);
60286134
}

Python/compile.c

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,8 @@ compiler_setup(compiler *c, mod_ty mod, PyObject *filename,
126126
c->c_optimize = (optimize == -1) ? _Py_GetConfig()->optimization_level : optimize;
127127
c->c_save_nested_seqs = false;
128128

129-
if (!_PyAST_Optimize(mod, arena, c->c_optimize, merged)) {
129+
int ast_opt = !(flags->cf_flags & PyCF_DONT_OPTIMIZE_AST);
130+
if (ast_opt && !_PyAST_Optimize(mod, arena, c->c_optimize, merged)) {
130131
return ERROR;
131132
}
132133
c->c_st = _PySymtable_Build(mod, filename, &c->c_future);

0 commit comments

Comments
 (0)