Skip to content

Commit 5e6088e

Browse files
committed
Fix the optimizer's understanding of exponentiation
1 parent 43d8a3d commit 5e6088e

File tree

3 files changed

+146
-24
lines changed

3 files changed

+146
-24
lines changed

Python/optimizer_bytecodes.c

Lines changed: 69 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -167,24 +167,81 @@ dummy_func(void) {
167167
}
168168

169169
op(_BINARY_OP, (left, right -- res)) {
170-
PyTypeObject *ltype = sym_get_type(left);
171-
PyTypeObject *rtype = sym_get_type(right);
172-
if (ltype != NULL && (ltype == &PyLong_Type || ltype == &PyFloat_Type) &&
173-
rtype != NULL && (rtype == &PyLong_Type || rtype == &PyFloat_Type))
174-
{
175-
if (oparg != NB_TRUE_DIVIDE && oparg != NB_INPLACE_TRUE_DIVIDE &&
176-
ltype == &PyLong_Type && rtype == &PyLong_Type) {
177-
/* If both inputs are ints and the op is not division the result is an int */
178-
res = sym_new_type(ctx, &PyLong_Type);
170+
bool lhs_int = sym_matches_type(left, &PyLong_Type);
171+
bool rhs_int = sym_matches_type(right, &PyLong_Type);
172+
bool lhs_float = sym_matches_type(left, &PyFloat_Type);
173+
bool rhs_float = sym_matches_type(right, &PyFloat_Type);
174+
if ((!lhs_int && !lhs_float) || (!rhs_int && !rhs_float)) {
175+
res = sym_new_unknown(ctx);
176+
goto binary_op_done;
177+
}
178+
if (oparg == NB_POWER || oparg == NB_INPLACE_POWER) {
179+
// This one's fun: the *type* of the result depends on the *values*
180+
// being exponentiated. But exponents with one constant part are
181+
// reasonably common, so it's probably worth trying to be precise:
182+
PyObject *lhs_const = sym_get_const(left);
183+
PyObject *rhs_const = sym_get_const(right);
184+
if (lhs_int && rhs_int) {
185+
if (rhs_const == NULL) {
186+
// Unknown RHS means either int or float:
187+
res = sym_new_unknown(ctx);
188+
goto binary_op_done;
189+
}
190+
if (!_PyLong_IsNegative((PyLongObject *)rhs_const)) {
191+
// Non-negative RHS means int:
192+
res = sym_new_type(ctx, &PyLong_Type);
193+
goto binary_op_done;
194+
}
195+
// Negative RHS uses float_pow...
179196
}
180-
else {
181-
/* For any other op combining ints/floats the result is a float */
197+
// Negative LHS *and* non-integral RHS means complex. So we need to
198+
// disprove at least one to prove a float result:
199+
if (rhs_int) {
200+
// Integral RHS means float:
182201
res = sym_new_type(ctx, &PyFloat_Type);
202+
goto binary_op_done;
203+
}
204+
if (rhs_const) {
205+
double rhs_double = PyFloat_AS_DOUBLE(rhs_const);
206+
if (rhs_double == floor(rhs_double)) {
207+
// Integral RHS means float:
208+
res = sym_new_type(ctx, &PyFloat_Type);
209+
goto binary_op_done;
210+
}
211+
}
212+
if (lhs_const) {
213+
if (lhs_int) {
214+
if (!_PyLong_IsNegative((PyLongObject *)lhs_const)) {
215+
// Non-negative LHS means float:
216+
res = sym_new_type(ctx, &PyFloat_Type);
217+
goto binary_op_done;
218+
}
219+
}
220+
else if (0.0 <= PyFloat_AS_DOUBLE(lhs_const)) {
221+
// Non-negative LHS means float:
222+
res = sym_new_type(ctx, &PyFloat_Type);
223+
goto binary_op_done;
224+
}
225+
if (rhs_const) {
226+
// If we have two constants and failed to disprove that it's
227+
// complex, then it's complex:
228+
res = sym_new_type(ctx, &PyComplex_Type);
229+
goto binary_op_done;
230+
}
183231
}
232+
// Couldn't prove anything. It's either float or complex:
233+
res = sym_new_unknown(ctx);
234+
}
235+
else if (oparg == NB_TRUE_DIVIDE || oparg == NB_INPLACE_TRUE_DIVIDE) {
236+
res = sym_new_type(ctx, &PyFloat_Type);
237+
}
238+
else if (lhs_int && rhs_int) {
239+
res = sym_new_type(ctx, &PyLong_Type);
184240
}
185241
else {
186-
res = sym_new_unknown(ctx);
242+
res = sym_new_type(ctx, &PyFloat_Type);
187243
}
244+
binary_op_done:
188245
}
189246

190247
op(_BINARY_OP_ADD_INT, (left, right -- res)) {

Python/optimizer_cases.c.h

Lines changed: 75 additions & 12 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Tools/cases_generator/analyzer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,7 @@ def has_error_without_pop(op: parser.InstDef) -> bool:
596596
"_PyLong_CompactValue",
597597
"_PyLong_DigitCount",
598598
"_PyLong_IsCompact",
599+
"_PyLong_IsNegative",
599600
"_PyLong_IsNonNegativeCompact",
600601
"_PyLong_IsZero",
601602
"_PyLong_Multiply",
@@ -634,6 +635,7 @@ def has_error_without_pop(op: parser.InstDef) -> bool:
634635
"advance_backoff_counter",
635636
"assert",
636637
"backoff_counter_triggers",
638+
"floor",
637639
"initial_temperature_backoff_counter",
638640
"maybe_lltrace_resume_frame",
639641
"restart_backoff_counter",

0 commit comments

Comments
 (0)