|
81 | 81 | import com.oracle.graal.python.builtins.objects.type.LazyPythonClass;
|
82 | 82 | import com.oracle.graal.python.nodes.ErrorMessages;
|
83 | 83 | import com.oracle.graal.python.nodes.SpecialMethodNames;
|
| 84 | +import com.oracle.graal.python.nodes.call.special.LookupAndCallTernaryNode; |
84 | 85 | import com.oracle.graal.python.nodes.call.special.LookupAndCallVarargsNode;
|
85 | 86 | import com.oracle.graal.python.nodes.classes.IsSubtypeNode;
|
86 | 87 | import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
|
|
95 | 96 | import com.oracle.graal.python.runtime.formatting.FloatFormatter;
|
96 | 97 | import com.oracle.graal.python.runtime.formatting.InternalFormat;
|
97 | 98 | import com.oracle.graal.python.runtime.formatting.InternalFormat.Formatter;
|
| 99 | +import com.oracle.truffle.api.CompilerDirectives; |
98 | 100 | import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
|
99 | 101 | import com.oracle.truffle.api.dsl.Cached;
|
| 102 | +import com.oracle.truffle.api.dsl.Cached.Shared; |
100 | 103 | import com.oracle.truffle.api.dsl.CachedContext;
|
101 | 104 | import com.oracle.truffle.api.dsl.Fallback;
|
102 | 105 | import com.oracle.truffle.api.dsl.GenerateNodeFactory;
|
|
106 | 109 | import com.oracle.truffle.api.dsl.TypeSystemReference;
|
107 | 110 | import com.oracle.truffle.api.frame.VirtualFrame;
|
108 | 111 | import com.oracle.truffle.api.library.CachedLibrary;
|
| 112 | +import com.oracle.truffle.api.nodes.UnexpectedResultException; |
| 113 | +import com.oracle.truffle.api.profiles.BranchProfile; |
109 | 114 | import com.oracle.truffle.api.profiles.ConditionProfile;
|
110 | 115 |
|
111 | 116 | @CoreFunctions(extendClasses = PythonBuiltinClassType.PFloat)
|
@@ -453,28 +458,101 @@ PNotImplemented doGeneric(Object left, Object right) {
|
453 | 458 | @GenerateNodeFactory
|
454 | 459 | abstract static class PowerNode extends PythonTernaryBuiltinNode {
|
455 | 460 | @Specialization
|
456 |
| - double doDL(double left, long right, @SuppressWarnings("unused") PNone none) { |
457 |
| - return Math.pow(left, right); |
| 461 | + double doDL(double left, long right, @SuppressWarnings("unused") PNone none, |
| 462 | + @Shared("negativeRaise") @Cached BranchProfile negativeRaise) { |
| 463 | + return doOperation(left, right, negativeRaise); |
458 | 464 | }
|
459 | 465 |
|
460 | 466 | @Specialization
|
461 |
| - double doDPi(double left, PInt right, @SuppressWarnings("unused") PNone none) { |
462 |
| - return Math.pow(left, right.doubleValue()); |
| 467 | + double doDPi(double left, PInt right, @SuppressWarnings("unused") PNone none, |
| 468 | + @Shared("negativeRaise") @Cached BranchProfile negativeRaise) { |
| 469 | + return doOperation(left, right.doubleValue(), negativeRaise); |
463 | 470 | }
|
464 | 471 |
|
465 |
| - @Specialization |
466 |
| - double doDD(double left, double right, @SuppressWarnings("unused") PNone none) { |
| 472 | + /** |
| 473 | + * The special cases we need to deal with always return 1, so 0 means no special case, not a |
| 474 | + * result. |
| 475 | + */ |
| 476 | + private double doSpecialCases(double left, double right, BranchProfile negativeRaise) { |
| 477 | + // see cpython://Objects/floatobject.c#float_pow for special cases |
| 478 | + if (Double.isNaN(right) && left == 1) { |
| 479 | + // 1**nan = 1, unlike on Java |
| 480 | + return 1; |
| 481 | + } |
| 482 | + if (Double.isInfinite(right) && (left == 1 || left == -1)) { |
| 483 | + // v**(+/-)inf is 1.0 if abs(v) == 1, unlike on Java |
| 484 | + return 1; |
| 485 | + } |
| 486 | + if (left == 0 && right < 0) { |
| 487 | + negativeRaise.enter(); |
| 488 | + // 0**w is an error if w is negative, unlike Java |
| 489 | + throw raise(PythonBuiltinClassType.ZeroDivisionError, ErrorMessages.POW_ZERO_CANNOT_RAISE_TO_NEGATIVE_POWER); |
| 490 | + } |
| 491 | + return 0; |
| 492 | + } |
| 493 | + |
| 494 | + private double doOperation(double left, double right, BranchProfile negativeRaise) { |
| 495 | + if (doSpecialCases(left, right, negativeRaise) == 1) { |
| 496 | + return 1.0; |
| 497 | + } |
467 | 498 | return Math.pow(left, right);
|
468 | 499 | }
|
469 | 500 |
|
470 |
| - @Specialization |
471 |
| - double doDL(long left, double right, @SuppressWarnings("unused") PNone none) { |
| 501 | + @Specialization(rewriteOn = UnexpectedResultException.class) |
| 502 | + double doDD(VirtualFrame frame, double left, double right, @SuppressWarnings("unused") PNone none, |
| 503 | + @Shared("powCall") @Cached("create(__POW__)") LookupAndCallTernaryNode callPow, |
| 504 | + @Shared("negativeRaise") @Cached BranchProfile negativeRaise) throws UnexpectedResultException { |
| 505 | + if (doSpecialCases(left, right, negativeRaise) == 1) { |
| 506 | + return 1.0; |
| 507 | + } |
| 508 | + if (left < 0 && (right % 1 != 0)) { |
| 509 | + CompilerDirectives.transferToInterpreterAndInvalidate(); |
| 510 | + // Negative numbers raised to fractional powers become complex. |
| 511 | + throw new UnexpectedResultException(callPow.execute(frame, factory().createComplex(left, 0), factory().createComplex(right, 0), none)); |
| 512 | + } |
472 | 513 | return Math.pow(left, right);
|
473 | 514 | }
|
474 | 515 |
|
475 |
| - @Specialization |
476 |
| - double doDPi(PInt left, double right, @SuppressWarnings("unused") PNone none) { |
477 |
| - return Math.pow(left.doubleValue(), right); |
| 516 | + @Specialization(replaces = "doDD") |
| 517 | + Object doDDToComplex(VirtualFrame frame, double left, double right, PNone none, |
| 518 | + @Shared("powCall") @Cached("create(__POW__)") LookupAndCallTernaryNode callPow, |
| 519 | + @Shared("negativeRaise") @Cached BranchProfile negativeRaise) { |
| 520 | + if (doSpecialCases(left, right, negativeRaise) == 1) { |
| 521 | + return 1.0; |
| 522 | + } |
| 523 | + if (left < 0 && (right % 1 != 0)) { |
| 524 | + // Negative numbers raised to fractional powers become complex. |
| 525 | + return callPow.execute(frame, factory().createComplex(left, 0), factory().createComplex(right, 0), none); |
| 526 | + } |
| 527 | + return Math.pow(left, right); |
| 528 | + } |
| 529 | + |
| 530 | + @Specialization(rewriteOn = UnexpectedResultException.class) |
| 531 | + double doDL(VirtualFrame frame, long left, double right, PNone none, |
| 532 | + @Shared("powCall") @Cached("create(__POW__)") LookupAndCallTernaryNode callPow, |
| 533 | + @Shared("negativeRaise") @Cached BranchProfile negativeRaise) throws UnexpectedResultException { |
| 534 | + return doDD(frame, left, right, none, callPow, negativeRaise); |
| 535 | + } |
| 536 | + |
| 537 | + @Specialization(replaces = "doDL") |
| 538 | + Object doDLComplex(VirtualFrame frame, long left, double right, PNone none, |
| 539 | + @Shared("powCall") @Cached("create(__POW__)") LookupAndCallTernaryNode callPow, |
| 540 | + @Shared("negativeRaise") @Cached BranchProfile negativeRaise) { |
| 541 | + return doDDToComplex(frame, left, right, none, callPow, negativeRaise); |
| 542 | + } |
| 543 | + |
| 544 | + @Specialization(rewriteOn = UnexpectedResultException.class) |
| 545 | + double doDPi(VirtualFrame frame, PInt left, double right, @SuppressWarnings("unused") PNone none, |
| 546 | + @Shared("powCall") @Cached("create(__POW__)") LookupAndCallTernaryNode callPow, |
| 547 | + @Shared("negativeRaise") @Cached BranchProfile negativeRaise) throws UnexpectedResultException { |
| 548 | + return doDD(frame, left.doubleValue(), right, none, callPow, negativeRaise); |
| 549 | + } |
| 550 | + |
| 551 | + @Specialization(replaces = "doDPi") |
| 552 | + Object doDPiToComplex(VirtualFrame frame, PInt left, double right, @SuppressWarnings("unused") PNone none, |
| 553 | + @Shared("powCall") @Cached("create(__POW__)") LookupAndCallTernaryNode callPow, |
| 554 | + @Shared("negativeRaise") @Cached BranchProfile negativeRaise) { |
| 555 | + return doDDToComplex(frame, left.doubleValue(), right, none, callPow, negativeRaise); |
478 | 556 | }
|
479 | 557 |
|
480 | 558 | @Fallback
|
|
0 commit comments