Skip to content

Commit 1e5f351

Browse files
committed
deal with float.__pow__ special cases
1 parent 5b6dd2b commit 1e5f351

File tree

2 files changed

+90
-11
lines changed

2 files changed

+90
-11
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/floats/FloatBuiltins.java

Lines changed: 89 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@
8181
import com.oracle.graal.python.builtins.objects.type.LazyPythonClass;
8282
import com.oracle.graal.python.nodes.ErrorMessages;
8383
import com.oracle.graal.python.nodes.SpecialMethodNames;
84+
import com.oracle.graal.python.nodes.call.special.LookupAndCallTernaryNode;
8485
import com.oracle.graal.python.nodes.call.special.LookupAndCallVarargsNode;
8586
import com.oracle.graal.python.nodes.classes.IsSubtypeNode;
8687
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
@@ -95,8 +96,10 @@
9596
import com.oracle.graal.python.runtime.formatting.FloatFormatter;
9697
import com.oracle.graal.python.runtime.formatting.InternalFormat;
9798
import com.oracle.graal.python.runtime.formatting.InternalFormat.Formatter;
99+
import com.oracle.truffle.api.CompilerDirectives;
98100
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
99101
import com.oracle.truffle.api.dsl.Cached;
102+
import com.oracle.truffle.api.dsl.Cached.Shared;
100103
import com.oracle.truffle.api.dsl.CachedContext;
101104
import com.oracle.truffle.api.dsl.Fallback;
102105
import com.oracle.truffle.api.dsl.GenerateNodeFactory;
@@ -106,6 +109,8 @@
106109
import com.oracle.truffle.api.dsl.TypeSystemReference;
107110
import com.oracle.truffle.api.frame.VirtualFrame;
108111
import com.oracle.truffle.api.library.CachedLibrary;
112+
import com.oracle.truffle.api.nodes.UnexpectedResultException;
113+
import com.oracle.truffle.api.profiles.BranchProfile;
109114
import com.oracle.truffle.api.profiles.ConditionProfile;
110115

111116
@CoreFunctions(extendClasses = PythonBuiltinClassType.PFloat)
@@ -453,28 +458,101 @@ PNotImplemented doGeneric(Object left, Object right) {
453458
@GenerateNodeFactory
454459
abstract static class PowerNode extends PythonTernaryBuiltinNode {
455460
@Specialization
456-
double doDL(double left, long right, @SuppressWarnings("unused") PNone none) {
457-
return Math.pow(left, right);
461+
double doDL(double left, long right, @SuppressWarnings("unused") PNone none,
462+
@Shared("negativeRaise") @Cached BranchProfile negativeRaise) {
463+
return doOperation(left, right, negativeRaise);
458464
}
459465

460466
@Specialization
461-
double doDPi(double left, PInt right, @SuppressWarnings("unused") PNone none) {
462-
return Math.pow(left, right.doubleValue());
467+
double doDPi(double left, PInt right, @SuppressWarnings("unused") PNone none,
468+
@Shared("negativeRaise") @Cached BranchProfile negativeRaise) {
469+
return doOperation(left, right.doubleValue(), negativeRaise);
463470
}
464471

465-
@Specialization
466-
double doDD(double left, double right, @SuppressWarnings("unused") PNone none) {
472+
/**
473+
* The special cases we need to deal with always return 1, so 0 means no special case, not a
474+
* result.
475+
*/
476+
private double doSpecialCases(double left, double right, BranchProfile negativeRaise) {
477+
// see cpython://Objects/floatobject.c#float_pow for special cases
478+
if (Double.isNaN(right) && left == 1) {
479+
// 1**nan = 1, unlike on Java
480+
return 1;
481+
}
482+
if (Double.isInfinite(right) && (left == 1 || left == -1)) {
483+
// v**(+/-)inf is 1.0 if abs(v) == 1, unlike on Java
484+
return 1;
485+
}
486+
if (left == 0 && right < 0) {
487+
negativeRaise.enter();
488+
// 0**w is an error if w is negative, unlike Java
489+
throw raise(PythonBuiltinClassType.ZeroDivisionError, ErrorMessages.POW_ZERO_CANNOT_RAISE_TO_NEGATIVE_POWER);
490+
}
491+
return 0;
492+
}
493+
494+
private double doOperation(double left, double right, BranchProfile negativeRaise) {
495+
if (doSpecialCases(left, right, negativeRaise) == 1) {
496+
return 1.0;
497+
}
467498
return Math.pow(left, right);
468499
}
469500

470-
@Specialization
471-
double doDL(long left, double right, @SuppressWarnings("unused") PNone none) {
501+
@Specialization(rewriteOn = UnexpectedResultException.class)
502+
double doDD(VirtualFrame frame, double left, double right, @SuppressWarnings("unused") PNone none,
503+
@Shared("powCall") @Cached("create(__POW__)") LookupAndCallTernaryNode callPow,
504+
@Shared("negativeRaise") @Cached BranchProfile negativeRaise) throws UnexpectedResultException {
505+
if (doSpecialCases(left, right, negativeRaise) == 1) {
506+
return 1.0;
507+
}
508+
if (left < 0 && (right % 1 != 0)) {
509+
CompilerDirectives.transferToInterpreterAndInvalidate();
510+
// Negative numbers raised to fractional powers become complex.
511+
throw new UnexpectedResultException(callPow.execute(frame, factory().createComplex(left, 0), factory().createComplex(right, 0), none));
512+
}
472513
return Math.pow(left, right);
473514
}
474515

475-
@Specialization
476-
double doDPi(PInt left, double right, @SuppressWarnings("unused") PNone none) {
477-
return Math.pow(left.doubleValue(), right);
516+
@Specialization(replaces = "doDD")
517+
Object doDDToComplex(VirtualFrame frame, double left, double right, PNone none,
518+
@Shared("powCall") @Cached("create(__POW__)") LookupAndCallTernaryNode callPow,
519+
@Shared("negativeRaise") @Cached BranchProfile negativeRaise) {
520+
if (doSpecialCases(left, right, negativeRaise) == 1) {
521+
return 1.0;
522+
}
523+
if (left < 0 && (right % 1 != 0)) {
524+
// Negative numbers raised to fractional powers become complex.
525+
return callPow.execute(frame, factory().createComplex(left, 0), factory().createComplex(right, 0), none);
526+
}
527+
return Math.pow(left, right);
528+
}
529+
530+
@Specialization(rewriteOn = UnexpectedResultException.class)
531+
double doDL(VirtualFrame frame, long left, double right, PNone none,
532+
@Shared("powCall") @Cached("create(__POW__)") LookupAndCallTernaryNode callPow,
533+
@Shared("negativeRaise") @Cached BranchProfile negativeRaise) throws UnexpectedResultException {
534+
return doDD(frame, left, right, none, callPow, negativeRaise);
535+
}
536+
537+
@Specialization(replaces = "doDL")
538+
Object doDLComplex(VirtualFrame frame, long left, double right, PNone none,
539+
@Shared("powCall") @Cached("create(__POW__)") LookupAndCallTernaryNode callPow,
540+
@Shared("negativeRaise") @Cached BranchProfile negativeRaise) {
541+
return doDDToComplex(frame, left, right, none, callPow, negativeRaise);
542+
}
543+
544+
@Specialization(rewriteOn = UnexpectedResultException.class)
545+
double doDPi(VirtualFrame frame, PInt left, double right, @SuppressWarnings("unused") PNone none,
546+
@Shared("powCall") @Cached("create(__POW__)") LookupAndCallTernaryNode callPow,
547+
@Shared("negativeRaise") @Cached BranchProfile negativeRaise) throws UnexpectedResultException {
548+
return doDD(frame, left.doubleValue(), right, none, callPow, negativeRaise);
549+
}
550+
551+
@Specialization(replaces = "doDPi")
552+
Object doDPiToComplex(VirtualFrame frame, PInt left, double right, @SuppressWarnings("unused") PNone none,
553+
@Shared("powCall") @Cached("create(__POW__)") LookupAndCallTernaryNode callPow,
554+
@Shared("negativeRaise") @Cached BranchProfile negativeRaise) {
555+
return doDDToComplex(frame, left.doubleValue(), right, none, callPow, negativeRaise);
478556
}
479557

480558
@Fallback

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/ErrorMessages.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,5 +490,6 @@ public abstract class ErrorMessages {
490490
public static final String ZIPIMPORT_WRONG_CACHED_FILE_POS = "zipimport: wrong cached file position";
491491
public static final String ACCESS_TO_INTERNAL_LANG_NOT_PERMITTED = "access to internal language %s is not permitted";
492492
public static final String POW_BASE_NOT_INVERTIBLE = "base is not invertible for the given modulus";
493+
public static final String POW_ZERO_CANNOT_RAISE_TO_NEGATIVE_POWER = "0.0 cannot be raised to a negative power";
493494
public static final String POW_THIRD_ARG_CANNOT_BE_ZERO = "pow() 3rd argument cannot be 0";
494495
}

0 commit comments

Comments
 (0)