|
81 | 81 | import com.oracle.graal.python.builtins.objects.type.LazyPythonClass;
|
82 | 82 | import com.oracle.graal.python.nodes.ErrorMessages;
|
83 | 83 | import com.oracle.graal.python.nodes.SpecialMethodNames;
|
| 84 | +import com.oracle.graal.python.nodes.call.special.LookupAndCallTernaryNode; |
84 | 85 | import com.oracle.graal.python.nodes.call.special.LookupAndCallVarargsNode;
|
85 | 86 | import com.oracle.graal.python.nodes.classes.IsSubtypeNode;
|
86 | 87 | import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
|
|
95 | 96 | import com.oracle.graal.python.runtime.formatting.FloatFormatter;
|
96 | 97 | import com.oracle.graal.python.runtime.formatting.InternalFormat;
|
97 | 98 | import com.oracle.graal.python.runtime.formatting.InternalFormat.Formatter;
|
| 99 | +import com.oracle.truffle.api.CompilerDirectives; |
98 | 100 | import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
|
99 | 101 | import com.oracle.truffle.api.dsl.Cached;
|
| 102 | +import com.oracle.truffle.api.dsl.Cached.Shared; |
100 | 103 | import com.oracle.truffle.api.dsl.CachedContext;
|
101 | 104 | import com.oracle.truffle.api.dsl.Fallback;
|
102 | 105 | import com.oracle.truffle.api.dsl.GenerateNodeFactory;
|
103 | 106 | import com.oracle.truffle.api.dsl.ImportStatic;
|
104 | 107 | import com.oracle.truffle.api.dsl.NodeFactory;
|
| 108 | +import com.oracle.truffle.api.dsl.ReportPolymorphism; |
105 | 109 | import com.oracle.truffle.api.dsl.Specialization;
|
106 | 110 | import com.oracle.truffle.api.dsl.TypeSystemReference;
|
107 | 111 | import com.oracle.truffle.api.frame.VirtualFrame;
|
108 | 112 | import com.oracle.truffle.api.library.CachedLibrary;
|
| 113 | +import com.oracle.truffle.api.nodes.UnexpectedResultException; |
| 114 | +import com.oracle.truffle.api.profiles.BranchProfile; |
109 | 115 | import com.oracle.truffle.api.profiles.ConditionProfile;
|
110 | 116 |
|
111 | 117 | @CoreFunctions(extendClasses = PythonBuiltinClassType.PFloat)
|
@@ -451,71 +457,131 @@ PNotImplemented doGeneric(Object left, Object right) {
|
451 | 457 | @Builtin(name = __POW__, minNumOfPositionalArgs = 2, maxNumOfPositionalArgs = 3)
|
452 | 458 | @TypeSystemReference(PythonArithmeticTypes.class)
|
453 | 459 | @GenerateNodeFactory
|
| 460 | + @ReportPolymorphism |
454 | 461 | abstract static class PowerNode extends PythonTernaryBuiltinNode {
|
455 | 462 | @Specialization
|
456 |
| - double doDL(double left, long right, @SuppressWarnings("unused") PNone none) { |
457 |
| - return Math.pow(left, right); |
| 463 | + double doDL(double left, long right, @SuppressWarnings("unused") PNone none, |
| 464 | + @Shared("negativeRaise") @Cached BranchProfile negativeRaise) { |
| 465 | + return doOperation(left, right, negativeRaise); |
458 | 466 | }
|
459 | 467 |
|
460 | 468 | @Specialization
|
461 |
| - double doDPi(double left, PInt right, @SuppressWarnings("unused") PNone none) { |
462 |
| - return Math.pow(left, right.doubleValue()); |
463 |
| - } |
464 |
| - |
465 |
| - @Specialization |
466 |
| - double doDD(double left, double right, @SuppressWarnings("unused") PNone none) { |
467 |
| - return Math.pow(left, right); |
| 469 | + double doDPi(double left, PInt right, @SuppressWarnings("unused") PNone none, |
| 470 | + @Shared("negativeRaise") @Cached BranchProfile negativeRaise) { |
| 471 | + return doOperation(left, right.doubleValue(), negativeRaise); |
468 | 472 | }
|
469 | 473 |
|
470 |
| - @Specialization |
471 |
| - double doDL(double left, long right, long mod) { |
472 |
| - return Math.pow(left, right) % mod; |
| 474 | + /** |
| 475 | + * The special cases we need to deal with always return 1, so 0 means no special case, not a |
| 476 | + * result. |
| 477 | + */ |
| 478 | + private double doSpecialCases(double left, double right, BranchProfile negativeRaise) { |
| 479 | + // see cpython://Objects/floatobject.c#float_pow for special cases |
| 480 | + if (Double.isNaN(right) && left == 1) { |
| 481 | + // 1**nan = 1, unlike on Java |
| 482 | + return 1; |
| 483 | + } |
| 484 | + if (Double.isInfinite(right) && (left == 1 || left == -1)) { |
| 485 | + // v**(+/-)inf is 1.0 if abs(v) == 1, unlike on Java |
| 486 | + return 1; |
| 487 | + } |
| 488 | + if (left == 0 && right < 0) { |
| 489 | + negativeRaise.enter(); |
| 490 | + // 0**w is an error if w is negative, unlike Java |
| 491 | + throw raise(PythonBuiltinClassType.ZeroDivisionError, ErrorMessages.POW_ZERO_CANNOT_RAISE_TO_NEGATIVE_POWER); |
| 492 | + } |
| 493 | + return 0; |
473 | 494 | }
|
474 | 495 |
|
475 |
| - @Specialization |
476 |
| - double doDPi(double left, PInt right, long mod) { |
477 |
| - return Math.pow(left, right.doubleValue()) % mod; |
| 496 | + private double doOperation(double left, double right, BranchProfile negativeRaise) { |
| 497 | + if (doSpecialCases(left, right, negativeRaise) == 1) { |
| 498 | + return 1.0; |
| 499 | + } |
| 500 | + return Math.pow(left, right); |
478 | 501 | }
|
479 | 502 |
|
480 |
| - @Specialization |
481 |
| - double doDD(double left, double right, long mod) { |
482 |
| - return Math.pow(left, right) % mod; |
| 503 | + @Specialization(rewriteOn = UnexpectedResultException.class) |
| 504 | + double doDD(VirtualFrame frame, double left, double right, @SuppressWarnings("unused") PNone none, |
| 505 | + @Shared("powCall") @Cached("create(__POW__)") LookupAndCallTernaryNode callPow, |
| 506 | + @Shared("negativeRaise") @Cached BranchProfile negativeRaise) throws UnexpectedResultException { |
| 507 | + if (doSpecialCases(left, right, negativeRaise) == 1) { |
| 508 | + return 1.0; |
| 509 | + } |
| 510 | + if (left < 0 && (right % 1 != 0)) { |
| 511 | + CompilerDirectives.transferToInterpreterAndInvalidate(); |
| 512 | + // Negative numbers raised to fractional powers become complex. |
| 513 | + throw new UnexpectedResultException(callPow.execute(frame, factory().createComplex(left, 0), factory().createComplex(right, 0), none)); |
| 514 | + } |
| 515 | + return Math.pow(left, right); |
483 | 516 | }
|
484 | 517 |
|
485 |
| - @Specialization |
486 |
| - double doDL(double left, long right, PInt mod) { |
487 |
| - return Math.pow(left, right) % mod.doubleValue(); |
| 518 | + @Specialization(replaces = "doDD") |
| 519 | + Object doDDToComplex(VirtualFrame frame, double left, double right, PNone none, |
| 520 | + @Shared("powCall") @Cached("create(__POW__)") LookupAndCallTernaryNode callPow, |
| 521 | + @Shared("negativeRaise") @Cached BranchProfile negativeRaise) { |
| 522 | + if (doSpecialCases(left, right, negativeRaise) == 1) { |
| 523 | + return 1.0; |
| 524 | + } |
| 525 | + if (left < 0 && (right % 1 != 0)) { |
| 526 | + // Negative numbers raised to fractional powers become complex. |
| 527 | + return callPow.execute(frame, factory().createComplex(left, 0), factory().createComplex(right, 0), none); |
| 528 | + } |
| 529 | + return Math.pow(left, right); |
488 | 530 | }
|
489 | 531 |
|
490 |
| - @Specialization |
491 |
| - double doDPi(double left, PInt right, PInt mod) { |
492 |
| - return Math.pow(left, right.doubleValue()) % mod.doubleValue(); |
| 532 | + @Specialization(rewriteOn = UnexpectedResultException.class) |
| 533 | + double doDL(VirtualFrame frame, long left, double right, PNone none, |
| 534 | + @Shared("powCall") @Cached("create(__POW__)") LookupAndCallTernaryNode callPow, |
| 535 | + @Shared("negativeRaise") @Cached BranchProfile negativeRaise) throws UnexpectedResultException { |
| 536 | + return doDD(frame, left, right, none, callPow, negativeRaise); |
493 | 537 | }
|
494 | 538 |
|
495 |
| - @Specialization |
496 |
| - double doDD(double left, double right, PInt mod) { |
497 |
| - return Math.pow(left, right) % mod.doubleValue(); |
| 539 | + @Specialization(replaces = "doDL") |
| 540 | + Object doDLComplex(VirtualFrame frame, long left, double right, PNone none, |
| 541 | + @Shared("powCall") @Cached("create(__POW__)") LookupAndCallTernaryNode callPow, |
| 542 | + @Shared("negativeRaise") @Cached BranchProfile negativeRaise) { |
| 543 | + return doDDToComplex(frame, left, right, none, callPow, negativeRaise); |
498 | 544 | }
|
499 | 545 |
|
500 |
| - @Specialization |
501 |
| - double doDL(double left, long right, double mod) { |
502 |
| - return Math.pow(left, right) % mod; |
| 546 | + @Specialization(rewriteOn = UnexpectedResultException.class) |
| 547 | + double doDPi(VirtualFrame frame, PInt left, double right, @SuppressWarnings("unused") PNone none, |
| 548 | + @Shared("powCall") @Cached("create(__POW__)") LookupAndCallTernaryNode callPow, |
| 549 | + @Shared("negativeRaise") @Cached BranchProfile negativeRaise) throws UnexpectedResultException { |
| 550 | + return doDD(frame, left.doubleValue(), right, none, callPow, negativeRaise); |
503 | 551 | }
|
504 | 552 |
|
505 |
| - @Specialization |
506 |
| - double doDPi(double left, PInt right, double mod) { |
507 |
| - return Math.pow(left, right.doubleValue()) % mod; |
| 553 | + @Specialization(replaces = "doDPi") |
| 554 | + Object doDPiToComplex(VirtualFrame frame, PInt left, double right, @SuppressWarnings("unused") PNone none, |
| 555 | + @Shared("powCall") @Cached("create(__POW__)") LookupAndCallTernaryNode callPow, |
| 556 | + @Shared("negativeRaise") @Cached BranchProfile negativeRaise) { |
| 557 | + return doDDToComplex(frame, left.doubleValue(), right, none, callPow, negativeRaise); |
508 | 558 | }
|
509 | 559 |
|
510 | 560 | @Specialization
|
511 |
| - double doDD(double left, double right, double mod) { |
512 |
| - return Math.pow(left, right) % mod; |
513 |
| - } |
514 |
| - |
515 |
| - @SuppressWarnings("unused") |
516 |
| - @Fallback |
517 |
| - PNotImplemented doGeneric(Object left, Object right, Object none) { |
518 |
| - return PNotImplemented.NOT_IMPLEMENTED; |
| 561 | + Object doGeneric(VirtualFrame frame, Object left, Object right, Object mod, |
| 562 | + @CachedLibrary(limit = "5") PythonObjectLibrary lib, |
| 563 | + @Shared("powCall") @Cached("create(__POW__)") LookupAndCallTernaryNode callPow, |
| 564 | + @Shared("negativeRaise") @Cached BranchProfile negativeRaise) { |
| 565 | + if (!(mod instanceof PNone)) { |
| 566 | + throw raise(PythonBuiltinClassType.TypeError, "pow() 3rd argument not allowed unless all arguments are integers"); |
| 567 | + } |
| 568 | + double leftDouble; |
| 569 | + double rightDouble; |
| 570 | + if (lib.canBeJavaDouble(left)) { |
| 571 | + leftDouble = lib.asJavaDouble(left); |
| 572 | + } else if (left instanceof PInt) { |
| 573 | + leftDouble = ((PInt) left).doubleValue(); |
| 574 | + } else { |
| 575 | + return PNotImplemented.NOT_IMPLEMENTED; |
| 576 | + } |
| 577 | + if (lib.canBeJavaDouble(right)) { |
| 578 | + rightDouble = lib.asJavaDouble(right); |
| 579 | + } else if (right instanceof PInt) { |
| 580 | + rightDouble = ((PInt) right).doubleValue(); |
| 581 | + } else { |
| 582 | + return PNotImplemented.NOT_IMPLEMENTED; |
| 583 | + } |
| 584 | + return doDDToComplex(frame, leftDouble, rightDouble, PNone.NONE, callPow, negativeRaise); |
519 | 585 | }
|
520 | 586 | }
|
521 | 587 |
|
|
0 commit comments