Skip to content

Commit bf47dec

Browse files
committed
Implementation of complex.__pow__ and fixes to __truediv__
1 parent fa42168 commit bf47dec

File tree

3 files changed

+128
-15
lines changed

3 files changed

+128
-15
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/complex/ComplexBuiltins.java

Lines changed: 124 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -55,36 +55,46 @@
5555
import static com.oracle.graal.python.nodes.SpecialMethodNames.__NEG__;
5656
import static com.oracle.graal.python.nodes.SpecialMethodNames.__NE__;
5757
import static com.oracle.graal.python.nodes.SpecialMethodNames.__POS__;
58+
import static com.oracle.graal.python.nodes.SpecialMethodNames.__POW__;
5859
import static com.oracle.graal.python.nodes.SpecialMethodNames.__RADD__;
5960
import static com.oracle.graal.python.nodes.SpecialMethodNames.__REPR__;
6061
import static com.oracle.graal.python.nodes.SpecialMethodNames.__RMUL__;
62+
import static com.oracle.graal.python.nodes.SpecialMethodNames.__RPOW__;
6163
import static com.oracle.graal.python.nodes.SpecialMethodNames.__RSUB__;
6264
import static com.oracle.graal.python.nodes.SpecialMethodNames.__RTRUEDIV__;
6365
import static com.oracle.graal.python.nodes.SpecialMethodNames.__STR__;
6466
import static com.oracle.graal.python.nodes.SpecialMethodNames.__SUB__;
6567
import static com.oracle.graal.python.nodes.SpecialMethodNames.__TRUEDIV__;
68+
import static com.oracle.graal.python.runtime.exception.PythonErrorType.OverflowError;
69+
import static com.oracle.graal.python.runtime.exception.PythonErrorType.ValueError;
70+
import static com.oracle.graal.python.runtime.exception.PythonErrorType.ZeroDivisionError;
6671

6772
import java.util.List;
6873

6974
import com.oracle.graal.python.builtins.Builtin;
7075
import com.oracle.graal.python.builtins.CoreFunctions;
7176
import com.oracle.graal.python.builtins.PythonBuiltinClassType;
7277
import com.oracle.graal.python.builtins.PythonBuiltins;
78+
import com.oracle.graal.python.builtins.objects.PNone;
7379
import com.oracle.graal.python.builtins.objects.PNotImplemented;
7480
import com.oracle.graal.python.builtins.objects.ints.PInt;
7581
import com.oracle.graal.python.builtins.objects.tuple.PTuple;
7682
import com.oracle.graal.python.nodes.ErrorMessages;
7783
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
7884
import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
7985
import com.oracle.graal.python.nodes.function.builtins.PythonBinaryBuiltinNode;
86+
import com.oracle.graal.python.nodes.function.builtins.PythonTernaryBuiltinNode;
8087
import com.oracle.graal.python.nodes.function.builtins.PythonUnaryBuiltinNode;
8188
import com.oracle.graal.python.nodes.truffle.PythonArithmeticTypes;
89+
import com.oracle.graal.python.nodes.util.CoerceToComplexNode;
8290
import com.oracle.graal.python.runtime.exception.PythonErrorType;
91+
import com.oracle.graal.python.runtime.object.PythonObjectFactory;
8392
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
8493
import com.oracle.truffle.api.dsl.Cached;
8594
import com.oracle.truffle.api.dsl.Fallback;
8695
import com.oracle.truffle.api.dsl.GenerateNodeFactory;
8796
import com.oracle.truffle.api.dsl.NodeFactory;
97+
import com.oracle.truffle.api.dsl.ReportPolymorphism;
8898
import com.oracle.truffle.api.dsl.Specialization;
8999
import com.oracle.truffle.api.dsl.TypeSystemReference;
90100
import com.oracle.truffle.api.frame.VirtualFrame;
@@ -304,8 +314,8 @@ PComplex doComplexInt(PComplex left, long right) {
304314
}
305315

306316
@Specialization
307-
PComplex doComplexPInt(PComplex right, PInt left) {
308-
return doComplexDouble(right, left.doubleValue());
317+
PComplex doComplexPInt(PComplex left, PInt right) {
318+
return doComplexDouble(left, right.doubleValue());
309319
}
310320

311321
@Specialization
@@ -338,34 +348,38 @@ PComplex doComplex(PComplex left, PComplex right,
338348

339349
@Specialization
340350
PComplex doComplexDouble(double left, PComplex right) {
341-
double oprealSq = right.getReal() * right.getReal();
342-
double opimagSq = right.getImag() * right.getImag();
343-
double twice = 2 * right.getImag() * right.getReal();
344-
double realPart = right.getReal() * left;
345-
double imagPart = right.getImag() * left;
346-
return factory().createComplex(realPart / (oprealSq + opimagSq), -imagPart / twice);
351+
return doubleDivComplex(left, right, factory());
347352
}
348353

349354
@Specialization
350355
PComplex doComplexInt(long left, PComplex right) {
351356
double oprealSq = right.getReal() * right.getReal();
352357
double opimagSq = right.getImag() * right.getImag();
353-
double twice = 2 * right.getImag() * right.getReal();
354358
double realPart = right.getReal() * left;
355359
double imagPart = right.getImag() * left;
356-
return factory().createComplex(realPart / (oprealSq + opimagSq), -imagPart / twice);
360+
double denom = oprealSq + opimagSq;
361+
return factory().createComplex(realPart / denom, -imagPart / denom);
357362
}
358363

359364
@Specialization
360365
PComplex doComplexPInt(PInt left, PComplex right) {
361-
return doComplexDouble(right, left.doubleValue());
366+
return doComplexDouble(left.doubleValue(), right);
362367
}
363368

364369
@SuppressWarnings("unused")
365370
@Fallback
366371
PNotImplemented doComplex(Object left, Object right) {
367372
return PNotImplemented.NOT_IMPLEMENTED;
368373
}
374+
375+
static PComplex doubleDivComplex(double left, PComplex right, PythonObjectFactory factory) {
376+
double oprealSq = right.getReal() * right.getReal();
377+
double opimagSq = right.getImag() * right.getImag();
378+
double realPart = right.getReal() * left;
379+
double imagPart = right.getImag() * left;
380+
double denom = oprealSq + opimagSq;
381+
return factory.createComplex(realPart / denom, -imagPart / denom);
382+
}
369383
}
370384

371385
@GenerateNodeFactory
@@ -401,9 +415,7 @@ PComplex doComplexDouble(PComplex left, double right) {
401415

402416
@Specialization
403417
PComplex doComplex(PComplex left, PComplex right) {
404-
double newReal = left.getReal() * right.getReal() - left.getImag() * right.getImag();
405-
double newImag = left.getReal() * right.getImag() + left.getImag() * right.getReal();
406-
return factory().createComplex(newReal, newImag);
418+
return multiply(left, right, factory());
407419
}
408420

409421
@Specialization
@@ -421,6 +433,12 @@ PComplex doComplexPInt(PComplex left, PInt right) {
421433
PNotImplemented doGeneric(Object left, Object right) {
422434
return PNotImplemented.NOT_IMPLEMENTED;
423435
}
436+
437+
static PComplex multiply(PComplex left, PComplex right, PythonObjectFactory factory) {
438+
double newReal = left.getReal() * right.getReal() - left.getImag() * right.getImag();
439+
double newImag = left.getReal() * right.getImag() + left.getImag() * right.getReal();
440+
return factory.createComplex(newReal, newImag);
441+
}
424442
}
425443

426444
@GenerateNodeFactory
@@ -460,6 +478,98 @@ PNotImplemented doComplex(Object left, Object right) {
460478
}
461479
}
462480

481+
@Builtin(name = __RPOW__, minNumOfPositionalArgs = 2, maxNumOfPositionalArgs = 3, reverseOperation = true)
482+
@Builtin(name = __POW__, minNumOfPositionalArgs = 2, maxNumOfPositionalArgs = 3)
483+
@TypeSystemReference(PythonArithmeticTypes.class)
484+
@GenerateNodeFactory
485+
@ReportPolymorphism
486+
abstract static class PowerNode extends PythonTernaryBuiltinNode {
487+
488+
static boolean isSmallPositive(long l) {
489+
return l > 0 && l <= 100;
490+
}
491+
492+
static boolean isSmallNegative(long l) {
493+
return l <= 0 && l >= -100;
494+
}
495+
496+
@Specialization(guards = "isSmallPositive(right)")
497+
PComplex doComplexLongSmallPos(PComplex left, long right, @SuppressWarnings("unused") PNone mod) {
498+
return checkOverflow(complexToSmallPositiveIntPower(left, right));
499+
}
500+
501+
@Specialization(guards = "isSmallNegative(right)")
502+
PComplex doComplexLongSmallNeg(PComplex left, long right, @SuppressWarnings("unused") PNone mod) {
503+
return checkOverflow(DivNode.doubleDivComplex(1.0, complexToSmallPositiveIntPower(left, -right), factory()));
504+
}
505+
506+
@Specialization(guards = "!isSmallPositive(right) || !isSmallNegative(right)")
507+
PComplex doComplexLong(PComplex left, long right, @SuppressWarnings("unused") PNone mod) {
508+
return checkOverflow(complexToComplexPower(left, factory().createComplex(right, 0.0)));
509+
}
510+
511+
@Specialization
512+
PComplex doComplexComplex(PComplex left, PComplex right, @SuppressWarnings("unused") PNone mod) {
513+
return checkOverflow(complexToComplexPower(left, right));
514+
}
515+
516+
@Specialization
517+
PComplex doGeneric(VirtualFrame frame, Object left, Object right, @SuppressWarnings("unused") PNone mod,
518+
@Cached CoerceToComplexNode coerceLeft,
519+
@Cached CoerceToComplexNode coerceRight) {
520+
return checkOverflow(complexToComplexPower(coerceLeft.execute(frame, left), coerceRight.execute(frame, right)));
521+
}
522+
523+
@Specialization(guards = "!isPNone(mod)")
524+
@SuppressWarnings("unused")
525+
Object doGeneric(Object left, Object right, Object mod) {
526+
throw raise(ValueError, ErrorMessages.COMPLEX_MODULO);
527+
}
528+
529+
private PComplex complexToSmallPositiveIntPower(PComplex x, long n) {
530+
long mask = 1;
531+
PComplex r = factory().createComplex(1.0, 0.0);
532+
PComplex p = x;
533+
while (mask > 0 && n >= mask) {
534+
if ((n & mask) != 0) {
535+
r = MulNode.multiply(r, p, factory());
536+
}
537+
mask <<= 1;
538+
p = MulNode.multiply(p, p, factory());
539+
}
540+
return r;
541+
}
542+
543+
@TruffleBoundary
544+
private PComplex complexToComplexPower(PComplex a, PComplex b) {
545+
if (b.getReal() == 0.0 && b.getImag() == 0.0) {
546+
return factory().createComplex(1.0, 0.0);
547+
}
548+
if (a.getReal() == 0.0 && a.getImag() == 0.0) {
549+
if (b.getImag() != 0.0 || b.getReal() < 0.0) {
550+
throw raise(ZeroDivisionError, ErrorMessages.COMPLEX_ZERO_TO_NEGATIVE_POWER);
551+
}
552+
return factory().createComplex(0.0, 0.0);
553+
}
554+
double vabs = Math.hypot(a.getReal(), a.getImag());
555+
double len = Math.pow(vabs, b.getReal());
556+
double at = Math.atan2(a.getImag(), a.getReal());
557+
double phase = at * b.getReal();
558+
if (b.getImag() != 0.0) {
559+
len /= Math.exp(at * b.getImag());
560+
phase += b.getImag() * Math.log(vabs);
561+
}
562+
return factory().createComplex(len * Math.cos(phase), len * Math.sin(phase));
563+
}
564+
565+
private PComplex checkOverflow(PComplex result) {
566+
if (Double.isInfinite(result.getReal()) || Double.isInfinite(result.getImag())) {
567+
throw raise(OverflowError, ErrorMessages.COMPLEX_EXPONENTIATION);
568+
}
569+
return result;
570+
}
571+
}
572+
463573
@GenerateNodeFactory
464574
@Builtin(name = __EQ__, minNumOfPositionalArgs = 2)
465575
@TypeSystemReference(PythonArithmeticTypes.class)

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/ErrorMessages.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,9 @@ public abstract class ErrorMessages {
153153
public static final String CODE_OBJ_NO_FREE_VARIABLES = "code object passed to %s may not contain free variables";
154154
public static final String COMPILE_MUST_BE = "compile() mode must be 'exec', 'eval' or 'single'";
155155
public static final String COMPLEX_CANT_TAKE_ARG = "complex() can't take second arg if first is a string";
156+
public static final String COMPLEX_EXPONENTIATION = "complex exponentiation";
157+
public static final String COMPLEX_ZERO_TO_NEGATIVE_POWER = "0.0 to a negative or complex power";
158+
public static final String COMPLEX_MODULO = "complex modulo";
156159
public static final String COMPLEX_RETURNED_NON_COMPLEX = "__complex__ returned non-complex (type %p)";
157160
public static final String COMPLEX_SHOULD_RETURN_COMPLEX = "__complex__ should return a complex object";
158161
public static final String CONTIGUOUS_BUFFER = "contiguous buffer";

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/function/PythonBuiltinBaseNode.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ protected final PythonObjectFactory factory() {
101101
return objectFactory;
102102
}
103103

104-
private final PRaiseNode getRaiseNode() {
104+
protected final PRaiseNode getRaiseNode() {
105105
if (raiseNode == null) {
106106
CompilerDirectives.transferToInterpreterAndInvalidate();
107107
if (isAdoptable()) {

0 commit comments

Comments
 (0)