55
55
import static com .oracle .graal .python .nodes .SpecialMethodNames .__NEG__ ;
56
56
import static com .oracle .graal .python .nodes .SpecialMethodNames .__NE__ ;
57
57
import static com .oracle .graal .python .nodes .SpecialMethodNames .__POS__ ;
58
+ import static com .oracle .graal .python .nodes .SpecialMethodNames .__POW__ ;
58
59
import static com .oracle .graal .python .nodes .SpecialMethodNames .__RADD__ ;
59
60
import static com .oracle .graal .python .nodes .SpecialMethodNames .__REPR__ ;
60
61
import static com .oracle .graal .python .nodes .SpecialMethodNames .__RMUL__ ;
62
+ import static com .oracle .graal .python .nodes .SpecialMethodNames .__RPOW__ ;
61
63
import static com .oracle .graal .python .nodes .SpecialMethodNames .__RSUB__ ;
62
64
import static com .oracle .graal .python .nodes .SpecialMethodNames .__RTRUEDIV__ ;
63
65
import static com .oracle .graal .python .nodes .SpecialMethodNames .__STR__ ;
64
66
import static com .oracle .graal .python .nodes .SpecialMethodNames .__SUB__ ;
65
67
import static com .oracle .graal .python .nodes .SpecialMethodNames .__TRUEDIV__ ;
68
+ import static com .oracle .graal .python .runtime .exception .PythonErrorType .OverflowError ;
69
+ import static com .oracle .graal .python .runtime .exception .PythonErrorType .ValueError ;
70
+ import static com .oracle .graal .python .runtime .exception .PythonErrorType .ZeroDivisionError ;
66
71
67
72
import java .util .List ;
68
73
69
74
import com .oracle .graal .python .builtins .Builtin ;
70
75
import com .oracle .graal .python .builtins .CoreFunctions ;
71
76
import com .oracle .graal .python .builtins .PythonBuiltinClassType ;
72
77
import com .oracle .graal .python .builtins .PythonBuiltins ;
78
+ import com .oracle .graal .python .builtins .objects .PNone ;
73
79
import com .oracle .graal .python .builtins .objects .PNotImplemented ;
74
80
import com .oracle .graal .python .builtins .objects .ints .PInt ;
75
81
import com .oracle .graal .python .builtins .objects .tuple .PTuple ;
76
82
import com .oracle .graal .python .nodes .ErrorMessages ;
77
83
import com .oracle .graal .python .nodes .function .PythonBuiltinBaseNode ;
78
84
import com .oracle .graal .python .nodes .function .PythonBuiltinNode ;
79
85
import com .oracle .graal .python .nodes .function .builtins .PythonBinaryBuiltinNode ;
86
+ import com .oracle .graal .python .nodes .function .builtins .PythonTernaryBuiltinNode ;
80
87
import com .oracle .graal .python .nodes .function .builtins .PythonUnaryBuiltinNode ;
81
88
import com .oracle .graal .python .nodes .truffle .PythonArithmeticTypes ;
89
+ import com .oracle .graal .python .nodes .util .CoerceToComplexNode ;
82
90
import com .oracle .graal .python .runtime .exception .PythonErrorType ;
91
+ import com .oracle .graal .python .runtime .object .PythonObjectFactory ;
83
92
import com .oracle .truffle .api .CompilerDirectives .TruffleBoundary ;
84
93
import com .oracle .truffle .api .dsl .Cached ;
85
94
import com .oracle .truffle .api .dsl .Fallback ;
86
95
import com .oracle .truffle .api .dsl .GenerateNodeFactory ;
87
96
import com .oracle .truffle .api .dsl .NodeFactory ;
97
+ import com .oracle .truffle .api .dsl .ReportPolymorphism ;
88
98
import com .oracle .truffle .api .dsl .Specialization ;
89
99
import com .oracle .truffle .api .dsl .TypeSystemReference ;
90
100
import com .oracle .truffle .api .frame .VirtualFrame ;
@@ -304,8 +314,8 @@ PComplex doComplexInt(PComplex left, long right) {
304
314
}
305
315
306
316
@ Specialization
307
- PComplex doComplexPInt (PComplex right , PInt left ) {
308
- return doComplexDouble (right , left .doubleValue ());
317
+ PComplex doComplexPInt (PComplex left , PInt right ) {
318
+ return doComplexDouble (left , right .doubleValue ());
309
319
}
310
320
311
321
@ Specialization
@@ -338,34 +348,38 @@ PComplex doComplex(PComplex left, PComplex right,
338
348
339
349
@ Specialization
340
350
PComplex doComplexDouble (double left , PComplex right ) {
341
- double oprealSq = right .getReal () * right .getReal ();
342
- double opimagSq = right .getImag () * right .getImag ();
343
- double twice = 2 * right .getImag () * right .getReal ();
344
- double realPart = right .getReal () * left ;
345
- double imagPart = right .getImag () * left ;
346
- return factory ().createComplex (realPart / (oprealSq + opimagSq ), -imagPart / twice );
351
+ return doubleDivComplex (left , right , factory ());
347
352
}
348
353
349
354
@ Specialization
350
355
PComplex doComplexInt (long left , PComplex right ) {
351
356
double oprealSq = right .getReal () * right .getReal ();
352
357
double opimagSq = right .getImag () * right .getImag ();
353
- double twice = 2 * right .getImag () * right .getReal ();
354
358
double realPart = right .getReal () * left ;
355
359
double imagPart = right .getImag () * left ;
356
- return factory ().createComplex (realPart / (oprealSq + opimagSq ), -imagPart / twice );
360
+ double denom = oprealSq + opimagSq ;
361
+ return factory ().createComplex (realPart / denom , -imagPart / denom );
357
362
}
358
363
359
364
@ Specialization
360
365
PComplex doComplexPInt (PInt left , PComplex right ) {
361
- return doComplexDouble (right , left .doubleValue ());
366
+ return doComplexDouble (left .doubleValue (), right );
362
367
}
363
368
364
369
@ SuppressWarnings ("unused" )
365
370
@ Fallback
366
371
PNotImplemented doComplex (Object left , Object right ) {
367
372
return PNotImplemented .NOT_IMPLEMENTED ;
368
373
}
374
+
375
+ static PComplex doubleDivComplex (double left , PComplex right , PythonObjectFactory factory ) {
376
+ double oprealSq = right .getReal () * right .getReal ();
377
+ double opimagSq = right .getImag () * right .getImag ();
378
+ double realPart = right .getReal () * left ;
379
+ double imagPart = right .getImag () * left ;
380
+ double denom = oprealSq + opimagSq ;
381
+ return factory .createComplex (realPart / denom , -imagPart / denom );
382
+ }
369
383
}
370
384
371
385
@ GenerateNodeFactory
@@ -401,9 +415,7 @@ PComplex doComplexDouble(PComplex left, double right) {
401
415
402
416
@ Specialization
403
417
PComplex doComplex (PComplex left , PComplex right ) {
404
- double newReal = left .getReal () * right .getReal () - left .getImag () * right .getImag ();
405
- double newImag = left .getReal () * right .getImag () + left .getImag () * right .getReal ();
406
- return factory ().createComplex (newReal , newImag );
418
+ return multiply (left , right , factory ());
407
419
}
408
420
409
421
@ Specialization
@@ -421,6 +433,12 @@ PComplex doComplexPInt(PComplex left, PInt right) {
421
433
PNotImplemented doGeneric (Object left , Object right ) {
422
434
return PNotImplemented .NOT_IMPLEMENTED ;
423
435
}
436
+
437
+ static PComplex multiply (PComplex left , PComplex right , PythonObjectFactory factory ) {
438
+ double newReal = left .getReal () * right .getReal () - left .getImag () * right .getImag ();
439
+ double newImag = left .getReal () * right .getImag () + left .getImag () * right .getReal ();
440
+ return factory .createComplex (newReal , newImag );
441
+ }
424
442
}
425
443
426
444
@ GenerateNodeFactory
@@ -460,6 +478,98 @@ PNotImplemented doComplex(Object left, Object right) {
460
478
}
461
479
}
462
480
481
+ @ Builtin (name = __RPOW__ , minNumOfPositionalArgs = 2 , maxNumOfPositionalArgs = 3 , reverseOperation = true )
482
+ @ Builtin (name = __POW__ , minNumOfPositionalArgs = 2 , maxNumOfPositionalArgs = 3 )
483
+ @ TypeSystemReference (PythonArithmeticTypes .class )
484
+ @ GenerateNodeFactory
485
+ @ ReportPolymorphism
486
+ abstract static class PowerNode extends PythonTernaryBuiltinNode {
487
+
488
+ static boolean isSmallPositive (long l ) {
489
+ return l > 0 && l <= 100 ;
490
+ }
491
+
492
+ static boolean isSmallNegative (long l ) {
493
+ return l <= 0 && l >= -100 ;
494
+ }
495
+
496
+ @ Specialization (guards = "isSmallPositive(right)" )
497
+ PComplex doComplexLongSmallPos (PComplex left , long right , @ SuppressWarnings ("unused" ) PNone mod ) {
498
+ return checkOverflow (complexToSmallPositiveIntPower (left , right ));
499
+ }
500
+
501
+ @ Specialization (guards = "isSmallNegative(right)" )
502
+ PComplex doComplexLongSmallNeg (PComplex left , long right , @ SuppressWarnings ("unused" ) PNone mod ) {
503
+ return checkOverflow (DivNode .doubleDivComplex (1.0 , complexToSmallPositiveIntPower (left , -right ), factory ()));
504
+ }
505
+
506
+ @ Specialization (guards = "!isSmallPositive(right) || !isSmallNegative(right)" )
507
+ PComplex doComplexLong (PComplex left , long right , @ SuppressWarnings ("unused" ) PNone mod ) {
508
+ return checkOverflow (complexToComplexPower (left , factory ().createComplex (right , 0.0 )));
509
+ }
510
+
511
+ @ Specialization
512
+ PComplex doComplexComplex (PComplex left , PComplex right , @ SuppressWarnings ("unused" ) PNone mod ) {
513
+ return checkOverflow (complexToComplexPower (left , right ));
514
+ }
515
+
516
+ @ Specialization
517
+ PComplex doGeneric (VirtualFrame frame , Object left , Object right , @ SuppressWarnings ("unused" ) PNone mod ,
518
+ @ Cached CoerceToComplexNode coerceLeft ,
519
+ @ Cached CoerceToComplexNode coerceRight ) {
520
+ return checkOverflow (complexToComplexPower (coerceLeft .execute (frame , left ), coerceRight .execute (frame , right )));
521
+ }
522
+
523
+ @ Specialization (guards = "!isPNone(mod)" )
524
+ @ SuppressWarnings ("unused" )
525
+ Object doGeneric (Object left , Object right , Object mod ) {
526
+ throw raise (ValueError , ErrorMessages .COMPLEX_MODULO );
527
+ }
528
+
529
+ private PComplex complexToSmallPositiveIntPower (PComplex x , long n ) {
530
+ long mask = 1 ;
531
+ PComplex r = factory ().createComplex (1.0 , 0.0 );
532
+ PComplex p = x ;
533
+ while (mask > 0 && n >= mask ) {
534
+ if ((n & mask ) != 0 ) {
535
+ r = MulNode .multiply (r , p , factory ());
536
+ }
537
+ mask <<= 1 ;
538
+ p = MulNode .multiply (p , p , factory ());
539
+ }
540
+ return r ;
541
+ }
542
+
543
+ @ TruffleBoundary
544
+ private PComplex complexToComplexPower (PComplex a , PComplex b ) {
545
+ if (b .getReal () == 0.0 && b .getImag () == 0.0 ) {
546
+ return factory ().createComplex (1.0 , 0.0 );
547
+ }
548
+ if (a .getReal () == 0.0 && a .getImag () == 0.0 ) {
549
+ if (b .getImag () != 0.0 || b .getReal () < 0.0 ) {
550
+ throw raise (ZeroDivisionError , ErrorMessages .COMPLEX_ZERO_TO_NEGATIVE_POWER );
551
+ }
552
+ return factory ().createComplex (0.0 , 0.0 );
553
+ }
554
+ double vabs = Math .hypot (a .getReal (), a .getImag ());
555
+ double len = Math .pow (vabs , b .getReal ());
556
+ double at = Math .atan2 (a .getImag (), a .getReal ());
557
+ double phase = at * b .getReal ();
558
+ if (b .getImag () != 0.0 ) {
559
+ len /= Math .exp (at * b .getImag ());
560
+ phase += b .getImag () * Math .log (vabs );
561
+ }
562
+ return factory ().createComplex (len * Math .cos (phase ), len * Math .sin (phase ));
563
+ }
564
+
565
+ private PComplex checkOverflow (PComplex result ) {
566
+ if (Double .isInfinite (result .getReal ()) || Double .isInfinite (result .getImag ())) {
567
+ throw raise (OverflowError , ErrorMessages .COMPLEX_EXPONENTIATION );
568
+ }
569
+ return result ;
570
+ }
571
+ }
572
+
463
573
@ GenerateNodeFactory
464
574
@ Builtin (name = __EQ__ , minNumOfPositionalArgs = 2 )
465
575
@ TypeSystemReference (PythonArithmeticTypes .class )
0 commit comments