Skip to content

Commit 0b7a20a

Browse files
committed
[GR-24025] [GR-23328] pow() with 2 arguments (or None as third) should call __pow__ with 2 args
PullRequest: graalpython/1035
2 parents a3c0de6 + d154003 commit 0b7a20a

File tree

4 files changed

+52
-8
lines changed

4 files changed

+52
-8
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_binary_arithmetic.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,19 @@ def test_pow():
299299
# (0xffffffffffffffff >> 63) is used to produce a non-narrowed int
300300
assert 2**(0xffffffffffffffff >> 63) == 2
301301

302+
# test that two-argument pow functions work
303+
class M:
304+
def __pow__(self, power):
305+
return (type(self), power)
306+
assert pow(M(), 2) == (M, 2)
307+
assert pow(M(), 2, None) == (M, 2)
308+
try:
309+
pow(M(), 2, 3)
310+
except TypeError:
311+
assert True
312+
else:
313+
assert False
314+
302315
if sys.version_info.minor >= 8:
303316
# for some reason this hangs CPython on the CI even if it's just parsed
304317
from pow_tests import test_pow
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
*graalpython.lib-python.3.test.test_pow.PowTest.test_bug643260
22
*graalpython.lib-python.3.test.test_pow.PowTest.test_bug705231
3+
*graalpython.lib-python.3.test.test_pow.PowTest.test_negative_exponent
34
*graalpython.lib-python.3.test.test_pow.PowTest.test_other
45
*graalpython.lib-python.3.test.test_pow.PowTest.test_powfloat
6+
*graalpython.lib-python.3.test.test_pow.PowTest.test_powint
7+

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/BuiltinFunctions.java

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1754,11 +1754,24 @@ private int getDebuggerSessionCount() {
17541754

17551755
@Builtin(name = POW, minNumOfPositionalArgs = 2, parameterNames = {"base", "exp", "mod"})
17561756
@GenerateNodeFactory
1757-
public abstract static class PowNode extends PythonBuiltinNode {
1758-
@Child private LookupAndCallTernaryNode powNode = TernaryArithmetic.Pow.create();
1757+
public abstract static class PowNode extends PythonTernaryBuiltinNode {
1758+
static LookupAndCallBinaryNode binaryPow() {
1759+
return BinaryArithmetic.Pow.create();
1760+
}
1761+
1762+
static LookupAndCallTernaryNode ternaryPow() {
1763+
return TernaryArithmetic.Pow.create();
1764+
}
17591765

17601766
@Specialization
1761-
Object doIt(VirtualFrame frame, Object x, Object y, Object z) {
1767+
Object binary(VirtualFrame frame, Object x, Object y, @SuppressWarnings("unused") PNone z,
1768+
@Cached("binaryPow()") LookupAndCallBinaryNode powNode) {
1769+
return powNode.executeObject(frame, x, y);
1770+
}
1771+
1772+
@Specialization(guards = "!isPNone(z)")
1773+
Object ternary(VirtualFrame frame, Object x, Object y, Object z,
1774+
@Cached("ternaryPow()") LookupAndCallTernaryNode powNode) {
17621775
return powNode.execute(frame, x, y, z);
17631776
}
17641777
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/ints/IntBuiltins.java

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -577,7 +577,11 @@ static BigInteger op(BigInteger a, BigInteger b) {
577577

578578
@TruffleBoundary
579579
static BigInteger opNeg(BigInteger a, BigInteger b) {
580-
if (a.equals(BigInteger.ZERO)) {
580+
if (a.signum() == 0) {
581+
return BigInteger.ZERO;
582+
}
583+
BigInteger mod = a.mod(b.negate());
584+
if (mod.signum() == 0) {
581585
return BigInteger.ZERO;
582586
}
583587
return a.mod(b.negate()).subtract(b.negate());
@@ -705,17 +709,22 @@ PInt doLLPos(long left, long right, @SuppressWarnings("unused") PNone none) {
705709
}
706710

707711
@Specialization(guards = "right < 0")
708-
static double doLLNeg(long left, long right, @SuppressWarnings("unused") PNone none) {
712+
double doLLNeg(long left, long right, @SuppressWarnings("unused") PNone none,
713+
@Shared("leftIsZero") @Cached ConditionProfile leftIsZero) {
714+
if (leftIsZero.profile(left == 0)) {
715+
throw raise(PythonBuiltinClassType.ZeroDivisionError, ErrorMessages.POW_ZERO_CANNOT_RAISE_TO_NEGATIVE_POWER);
716+
}
709717
return Math.pow(left, right);
710718
}
711719

712720
@Specialization(rewriteOn = ArithmeticException.class)
713-
static Object doLPNarrow(long left, PInt right, @SuppressWarnings("unused") PNone none) {
721+
Object doLPNarrow(long left, PInt right, @SuppressWarnings("unused") PNone none,
722+
@Shared("leftIsZero") @Cached ConditionProfile leftIsZero) {
714723
long lright = right.longValueExact();
715724
if (lright >= 0) {
716725
return doLLFast(left, lright, none);
717726
}
718-
return doLLNeg(left, lright, none);
727+
return doLLNeg(left, lright, none, leftIsZero);
719728
}
720729

721730
@Specialization(replaces = "doLPNarrow")
@@ -739,7 +748,11 @@ PInt doPLPos(PInt left, long right, @SuppressWarnings("unused") PNone none) {
739748
}
740749

741750
@Specialization(guards = "right < 0")
742-
double doPLNeg(PInt left, long right, @SuppressWarnings("unused") PNone none) {
751+
double doPLNeg(PInt left, long right, @SuppressWarnings("unused") PNone none,
752+
@Shared("leftIsZero") @Cached ConditionProfile leftIsZero) {
753+
if (leftIsZero.profile(left.isZero())) {
754+
throw raise(PythonBuiltinClassType.ZeroDivisionError, ErrorMessages.POW_ZERO_CANNOT_RAISE_TO_NEGATIVE_POWER);
755+
}
743756
return TrueDivNode.op(BigInteger.ONE, op(left.getValue(), -right), getRaiseNode());
744757
}
745758

@@ -886,6 +899,8 @@ private Object op(BigInteger left, BigInteger right) {
886899
// we'll raise unless left is one of the shortcut values
887900
return op(left, Long.MAX_VALUE);
888901
}
902+
} else if (left.signum() == 0) {
903+
throw raise(PythonBuiltinClassType.ZeroDivisionError, ErrorMessages.POW_ZERO_CANNOT_RAISE_TO_NEGATIVE_POWER);
889904
} else {
890905
try {
891906
return Math.pow(left.longValueExact(), right.longValueExact());

0 commit comments

Comments
 (0)