Skip to content

Commit b2f6602

Browse files
committed
[GR-23280] Various fixes for the pow builtin
PullRequest: graalpython/1016
2 parents b380d5d + 59315b6 commit b2f6602

File tree

8 files changed

+411
-142
lines changed

8 files changed

+411
-142
lines changed
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
2+
# DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
3+
#
4+
# The Universal Permissive License (UPL), Version 1.0
5+
#
6+
# Subject to the condition set forth below, permission is hereby granted to any
7+
# person obtaining a copy of this software, associated documentation and/or
8+
# data (collectively the "Software"), free of charge and under any and all
9+
# copyright rights in the Software, and any and all patent rights owned or
10+
# freely licensable by each licensor hereunder covering either (i) the
11+
# unmodified Software as contributed to or provided by such licensor, or (ii)
12+
# the Larger Works (as defined below), to deal in both
13+
#
14+
# (a) the Software, and
15+
#
16+
# (b) any piece of software and/or hardware listed in the lrgrwrks.txt file if
17+
# one is included with the Software each a "Larger Work" to which the Software
18+
# is contributed by such licensors),
19+
#
20+
# without restriction, including without limitation the rights to copy, create
21+
# derivative works of, display, perform, and distribute the Software and make,
22+
# use, sell, offer for sale, import, export, have made, and have sold the
23+
# Software and the Larger Work(s), and to sublicense the foregoing rights on
24+
# either these or other terms.
25+
#
26+
# This license is subject to the following condition:
27+
#
28+
# The above copyright notice and either this complete permission notice or at a
29+
# minimum a reference to the UPL must be included in all copies or substantial
30+
# portions of the Software.
31+
#
32+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
33+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
34+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
35+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
36+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
37+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
38+
# SOFTWARE.
39+
40+
import sys
41+
42+
43+
def test_pow():
44+
if sys.implementation.name == "graalpython":
45+
try:
46+
2 ** (2**128)
47+
except ArithmeticError:
48+
assert True
49+
else:
50+
assert False
51+
52+
assert 2 ** -(2**128) == 0.0
53+
54+
class X(float):
55+
def __rpow__(self, other):
56+
return 42
57+
58+
assert 2 ** X() == 42
59+
60+
try:
61+
2.0 .__pow__(2, 2)
62+
except TypeError as e:
63+
assert True
64+
else:
65+
assert False
66+
67+
try:
68+
2.0 .__pow__(2.0, 2.0)
69+
except TypeError as e:
70+
assert True
71+
else:
72+
assert False
73+
74+
assert 2 ** 2.0 == 4.0
75+
76+
assert 2 .__pow__("a") == NotImplemented
77+
assert 2 .__pow__("a", 2) == NotImplemented
78+
assert 2 .__pow__(2**30, 2**128) == 0
79+
80+
# crafted to try specializations
81+
def mypow(a, b, c):
82+
return a.__pow__(b, c)
83+
84+
values = [
85+
[1, 2, None, 1], # long
86+
[1, 128, None, 1], # BigInteger
87+
[1, -2, None, 1.0], # double result
88+
[1, 0xffffffffffffffffffffffffffffffff & 0x80, None, 1], # narrow to long
89+
[2, 0xffffffffffffffffffffffffffffffff & 0x80, None, 340282366920938463463374607431768211456], # cannot narrow
90+
[2, -(0xffffffffffffffffffffffffffffffff & 0x80), None, 2.938735877055719e-39], # double result
91+
[2**128, 0, None, 1], # narrow to long
92+
[2**128, 1, None, 340282366920938463463374607431768211456], # cannot narrow
93+
[2**128, -2, None, 8.636168555094445e-78], # double
94+
[2**128, 0xffffffffffffffffffffffffffffffff & 0x2, None, 115792089237316195423570985008687907853269984665640564039457584007913129639936], # large
95+
[2**128, -(0xffffffffffffffffffffffffffffffff & 0x8), None, 5.562684646268003e-309], # double result
96+
[1, 2, 3, 1], # fast path
97+
[2, 2**30, 2**128, 0], # generic
98+
] + []
99+
100+
if sys.version_info.minor >= 8:
101+
values += [
102+
[1, -2, 3, 1], # fast path double
103+
[1, 2, -3, -2], # negative mod
104+
[1, -2, -3, -2], # negative mod and negative right
105+
[1, -2**128, 3, 1], # mod and large negative right
106+
[1, -2**128, -3, -2], # negative mod and large negative right
107+
[1, -2**128, -2**64, -18446744073709551615], # large negative mod and large negative right
108+
]
109+
110+
for args in values:
111+
assert mypow(*args[:-1]) == args[-1], "%r -> %r == %r" % (args, mypow(*args[:-1]), args[-1])
112+
113+
def mypow_rev(a, b, c):
114+
return a.__pow__(b, c)
115+
116+
for args in reversed(values):
117+
assert mypow_rev(*args[:-1]) == args[-1], "%r -> %r == %r" % (args, mypow(*args[:-1]), args[-1])
118+
119+
assert 2**1.0 == 2.0
120+
121+
try:
122+
pow(12,-2,100)
123+
except ValueError as e:
124+
assert "base is not invertible for the given modulus" in str(e)
125+
else:
126+
assert False
127+
128+
assert pow(1234567, -2, 100) == 9

graalpython/com.oracle.graal.python.test/src/tests/test_binary_arithmetic.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
# AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
2222
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
2323
# OF THE POSSIBILITY OF SUCH DAMAGE.
24+
25+
import sys
26+
2427
def test_add_long_overflow():
2528
# max long value is written as long primitive
2629
val = 0x7fffffffffffffff
@@ -295,3 +298,8 @@ def test_lshift():
295298
def test_pow():
296299
# (0xffffffffffffffff >> 63) is used to produce a non-narrowed int
297300
assert 2**(0xffffffffffffffff >> 63) == 2
301+
302+
if sys.version_info.minor >= 8:
303+
# for some reason this hangs CPython on the CI even if it's just parsed
304+
from pow_tests import test_pow
305+
test_pow()

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/BuiltinFunctions.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1693,7 +1693,7 @@ private int getDebuggerSessionCount() {
16931693
}
16941694
}
16951695

1696-
@Builtin(name = POW, minNumOfPositionalArgs = 2, parameterNames = {"x", "y", "z"})
1696+
@Builtin(name = POW, minNumOfPositionalArgs = 2, parameterNames = {"base", "exp", "mod"})
16971697
@GenerateNodeFactory
16981698
public abstract static class PowNode extends PythonBuiltinNode {
16991699
@Child private LookupAndCallTernaryNode powNode = TernaryArithmetic.Pow.create();

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/floats/FloatBuiltins.java

Lines changed: 107 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@
8181
import com.oracle.graal.python.builtins.objects.type.LazyPythonClass;
8282
import com.oracle.graal.python.nodes.ErrorMessages;
8383
import com.oracle.graal.python.nodes.SpecialMethodNames;
84+
import com.oracle.graal.python.nodes.call.special.LookupAndCallTernaryNode;
8485
import com.oracle.graal.python.nodes.call.special.LookupAndCallVarargsNode;
8586
import com.oracle.graal.python.nodes.classes.IsSubtypeNode;
8687
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
@@ -95,17 +96,22 @@
9596
import com.oracle.graal.python.runtime.formatting.FloatFormatter;
9697
import com.oracle.graal.python.runtime.formatting.InternalFormat;
9798
import com.oracle.graal.python.runtime.formatting.InternalFormat.Formatter;
99+
import com.oracle.truffle.api.CompilerDirectives;
98100
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
99101
import com.oracle.truffle.api.dsl.Cached;
102+
import com.oracle.truffle.api.dsl.Cached.Shared;
100103
import com.oracle.truffle.api.dsl.CachedContext;
101104
import com.oracle.truffle.api.dsl.Fallback;
102105
import com.oracle.truffle.api.dsl.GenerateNodeFactory;
103106
import com.oracle.truffle.api.dsl.ImportStatic;
104107
import com.oracle.truffle.api.dsl.NodeFactory;
108+
import com.oracle.truffle.api.dsl.ReportPolymorphism;
105109
import com.oracle.truffle.api.dsl.Specialization;
106110
import com.oracle.truffle.api.dsl.TypeSystemReference;
107111
import com.oracle.truffle.api.frame.VirtualFrame;
108112
import com.oracle.truffle.api.library.CachedLibrary;
113+
import com.oracle.truffle.api.nodes.UnexpectedResultException;
114+
import com.oracle.truffle.api.profiles.BranchProfile;
109115
import com.oracle.truffle.api.profiles.ConditionProfile;
110116

111117
@CoreFunctions(extendClasses = PythonBuiltinClassType.PFloat)
@@ -451,71 +457,131 @@ PNotImplemented doGeneric(Object left, Object right) {
451457
@Builtin(name = __POW__, minNumOfPositionalArgs = 2, maxNumOfPositionalArgs = 3)
452458
@TypeSystemReference(PythonArithmeticTypes.class)
453459
@GenerateNodeFactory
460+
@ReportPolymorphism
454461
abstract static class PowerNode extends PythonTernaryBuiltinNode {
455462
@Specialization
456-
double doDL(double left, long right, @SuppressWarnings("unused") PNone none) {
457-
return Math.pow(left, right);
463+
double doDL(double left, long right, @SuppressWarnings("unused") PNone none,
464+
@Shared("negativeRaise") @Cached BranchProfile negativeRaise) {
465+
return doOperation(left, right, negativeRaise);
458466
}
459467

460468
@Specialization
461-
double doDPi(double left, PInt right, @SuppressWarnings("unused") PNone none) {
462-
return Math.pow(left, right.doubleValue());
463-
}
464-
465-
@Specialization
466-
double doDD(double left, double right, @SuppressWarnings("unused") PNone none) {
467-
return Math.pow(left, right);
469+
double doDPi(double left, PInt right, @SuppressWarnings("unused") PNone none,
470+
@Shared("negativeRaise") @Cached BranchProfile negativeRaise) {
471+
return doOperation(left, right.doubleValue(), negativeRaise);
468472
}
469473

470-
@Specialization
471-
double doDL(double left, long right, long mod) {
472-
return Math.pow(left, right) % mod;
474+
/**
475+
* The special cases we need to deal with always return 1, so 0 means no special case, not a
476+
* result.
477+
*/
478+
private double doSpecialCases(double left, double right, BranchProfile negativeRaise) {
479+
// see cpython://Objects/floatobject.c#float_pow for special cases
480+
if (Double.isNaN(right) && left == 1) {
481+
// 1**nan = 1, unlike on Java
482+
return 1;
483+
}
484+
if (Double.isInfinite(right) && (left == 1 || left == -1)) {
485+
// v**(+/-)inf is 1.0 if abs(v) == 1, unlike on Java
486+
return 1;
487+
}
488+
if (left == 0 && right < 0) {
489+
negativeRaise.enter();
490+
// 0**w is an error if w is negative, unlike Java
491+
throw raise(PythonBuiltinClassType.ZeroDivisionError, ErrorMessages.POW_ZERO_CANNOT_RAISE_TO_NEGATIVE_POWER);
492+
}
493+
return 0;
473494
}
474495

475-
@Specialization
476-
double doDPi(double left, PInt right, long mod) {
477-
return Math.pow(left, right.doubleValue()) % mod;
496+
private double doOperation(double left, double right, BranchProfile negativeRaise) {
497+
if (doSpecialCases(left, right, negativeRaise) == 1) {
498+
return 1.0;
499+
}
500+
return Math.pow(left, right);
478501
}
479502

480-
@Specialization
481-
double doDD(double left, double right, long mod) {
482-
return Math.pow(left, right) % mod;
503+
@Specialization(rewriteOn = UnexpectedResultException.class)
504+
double doDD(VirtualFrame frame, double left, double right, @SuppressWarnings("unused") PNone none,
505+
@Shared("powCall") @Cached("create(__POW__)") LookupAndCallTernaryNode callPow,
506+
@Shared("negativeRaise") @Cached BranchProfile negativeRaise) throws UnexpectedResultException {
507+
if (doSpecialCases(left, right, negativeRaise) == 1) {
508+
return 1.0;
509+
}
510+
if (left < 0 && (right % 1 != 0)) {
511+
CompilerDirectives.transferToInterpreterAndInvalidate();
512+
// Negative numbers raised to fractional powers become complex.
513+
throw new UnexpectedResultException(callPow.execute(frame, factory().createComplex(left, 0), factory().createComplex(right, 0), none));
514+
}
515+
return Math.pow(left, right);
483516
}
484517

485-
@Specialization
486-
double doDL(double left, long right, PInt mod) {
487-
return Math.pow(left, right) % mod.doubleValue();
518+
@Specialization(replaces = "doDD")
519+
Object doDDToComplex(VirtualFrame frame, double left, double right, PNone none,
520+
@Shared("powCall") @Cached("create(__POW__)") LookupAndCallTernaryNode callPow,
521+
@Shared("negativeRaise") @Cached BranchProfile negativeRaise) {
522+
if (doSpecialCases(left, right, negativeRaise) == 1) {
523+
return 1.0;
524+
}
525+
if (left < 0 && (right % 1 != 0)) {
526+
// Negative numbers raised to fractional powers become complex.
527+
return callPow.execute(frame, factory().createComplex(left, 0), factory().createComplex(right, 0), none);
528+
}
529+
return Math.pow(left, right);
488530
}
489531

490-
@Specialization
491-
double doDPi(double left, PInt right, PInt mod) {
492-
return Math.pow(left, right.doubleValue()) % mod.doubleValue();
532+
@Specialization(rewriteOn = UnexpectedResultException.class)
533+
double doDL(VirtualFrame frame, long left, double right, PNone none,
534+
@Shared("powCall") @Cached("create(__POW__)") LookupAndCallTernaryNode callPow,
535+
@Shared("negativeRaise") @Cached BranchProfile negativeRaise) throws UnexpectedResultException {
536+
return doDD(frame, left, right, none, callPow, negativeRaise);
493537
}
494538

495-
@Specialization
496-
double doDD(double left, double right, PInt mod) {
497-
return Math.pow(left, right) % mod.doubleValue();
539+
@Specialization(replaces = "doDL")
540+
Object doDLComplex(VirtualFrame frame, long left, double right, PNone none,
541+
@Shared("powCall") @Cached("create(__POW__)") LookupAndCallTernaryNode callPow,
542+
@Shared("negativeRaise") @Cached BranchProfile negativeRaise) {
543+
return doDDToComplex(frame, left, right, none, callPow, negativeRaise);
498544
}
499545

500-
@Specialization
501-
double doDL(double left, long right, double mod) {
502-
return Math.pow(left, right) % mod;
546+
@Specialization(rewriteOn = UnexpectedResultException.class)
547+
double doDPi(VirtualFrame frame, PInt left, double right, @SuppressWarnings("unused") PNone none,
548+
@Shared("powCall") @Cached("create(__POW__)") LookupAndCallTernaryNode callPow,
549+
@Shared("negativeRaise") @Cached BranchProfile negativeRaise) throws UnexpectedResultException {
550+
return doDD(frame, left.doubleValue(), right, none, callPow, negativeRaise);
503551
}
504552

505-
@Specialization
506-
double doDPi(double left, PInt right, double mod) {
507-
return Math.pow(left, right.doubleValue()) % mod;
553+
@Specialization(replaces = "doDPi")
554+
Object doDPiToComplex(VirtualFrame frame, PInt left, double right, @SuppressWarnings("unused") PNone none,
555+
@Shared("powCall") @Cached("create(__POW__)") LookupAndCallTernaryNode callPow,
556+
@Shared("negativeRaise") @Cached BranchProfile negativeRaise) {
557+
return doDDToComplex(frame, left.doubleValue(), right, none, callPow, negativeRaise);
508558
}
509559

510560
@Specialization
511-
double doDD(double left, double right, double mod) {
512-
return Math.pow(left, right) % mod;
513-
}
514-
515-
@SuppressWarnings("unused")
516-
@Fallback
517-
PNotImplemented doGeneric(Object left, Object right, Object none) {
518-
return PNotImplemented.NOT_IMPLEMENTED;
561+
Object doGeneric(VirtualFrame frame, Object left, Object right, Object mod,
562+
@CachedLibrary(limit = "5") PythonObjectLibrary lib,
563+
@Shared("powCall") @Cached("create(__POW__)") LookupAndCallTernaryNode callPow,
564+
@Shared("negativeRaise") @Cached BranchProfile negativeRaise) {
565+
if (!(mod instanceof PNone)) {
566+
throw raise(PythonBuiltinClassType.TypeError, "pow() 3rd argument not allowed unless all arguments are integers");
567+
}
568+
double leftDouble;
569+
double rightDouble;
570+
if (lib.canBeJavaDouble(left)) {
571+
leftDouble = lib.asJavaDouble(left);
572+
} else if (left instanceof PInt) {
573+
leftDouble = ((PInt) left).doubleValue();
574+
} else {
575+
return PNotImplemented.NOT_IMPLEMENTED;
576+
}
577+
if (lib.canBeJavaDouble(right)) {
578+
rightDouble = lib.asJavaDouble(right);
579+
} else if (right instanceof PInt) {
580+
rightDouble = ((PInt) right).doubleValue();
581+
} else {
582+
return PNotImplemented.NOT_IMPLEMENTED;
583+
}
584+
return doDDToComplex(frame, leftDouble, rightDouble, PNone.NONE, callPow, negativeRaise);
519585
}
520586
}
521587

0 commit comments

Comments
 (0)