Skip to content

Commit 0b43b09

Browse files
committed
Improve float binops compatibility
1 parent b4050eb commit 0b43b09

File tree

2 files changed

+311
-26
lines changed

2 files changed

+311
-26
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/floats/FloatBuiltins.java

Lines changed: 230 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,10 @@
9191
import com.oracle.graal.python.builtins.objects.floats.FloatBuiltinsClinicProviders.FormatNodeClinicProviderGen;
9292
import com.oracle.graal.python.builtins.objects.ints.PInt;
9393
import com.oracle.graal.python.builtins.objects.tuple.PTuple;
94-
import com.oracle.graal.python.lib.CanBeDoubleNode;
9594
import com.oracle.graal.python.lib.PyFloatAsDoubleNode;
95+
import com.oracle.graal.python.lib.PyFloatCheckNode;
96+
import com.oracle.graal.python.lib.PyLongAsDoubleNode;
97+
import com.oracle.graal.python.lib.PyLongCheckNode;
9698
import com.oracle.graal.python.lib.PyObjectHashNode;
9799
import com.oracle.graal.python.nodes.ErrorMessages;
98100
import com.oracle.graal.python.nodes.call.special.LookupAndCallTernaryNode;
@@ -106,6 +108,7 @@
106108
import com.oracle.graal.python.nodes.function.builtins.PythonUnaryBuiltinNode;
107109
import com.oracle.graal.python.nodes.function.builtins.clinic.ArgumentClinicProvider;
108110
import com.oracle.graal.python.nodes.object.GetClassNode;
111+
import com.oracle.graal.python.nodes.object.InlinedGetClassNode;
109112
import com.oracle.graal.python.nodes.object.InlinedGetClassNode.GetPythonObjectClassNode;
110113
import com.oracle.graal.python.nodes.truffle.PythonArithmeticTypes;
111114
import com.oracle.graal.python.runtime.exception.PythonErrorType;
@@ -119,12 +122,16 @@
119122
import com.oracle.truffle.api.dsl.Cached.Shared;
120123
import com.oracle.truffle.api.dsl.Fallback;
121124
import com.oracle.truffle.api.dsl.GenerateNodeFactory;
125+
import com.oracle.truffle.api.dsl.GenerateUncached;
122126
import com.oracle.truffle.api.dsl.ImportStatic;
123127
import com.oracle.truffle.api.dsl.NodeFactory;
124128
import com.oracle.truffle.api.dsl.ReportPolymorphism;
125129
import com.oracle.truffle.api.dsl.Specialization;
126130
import com.oracle.truffle.api.dsl.TypeSystemReference;
127131
import com.oracle.truffle.api.frame.VirtualFrame;
132+
import com.oracle.truffle.api.interop.InteropLibrary;
133+
import com.oracle.truffle.api.interop.UnsupportedMessageException;
134+
import com.oracle.truffle.api.library.CachedLibrary;
128135
import com.oracle.truffle.api.nodes.Node;
129136
import com.oracle.truffle.api.nodes.UnexpectedResultException;
130137
import com.oracle.truffle.api.profiles.InlinedConditionProfile;
@@ -290,6 +297,110 @@ static boolean isFloatSubtype(Node inliningTarget, PythonAbstractNativeObject ob
290297
}
291298
}
292299

300+
protected static final int DOUBLE_TYPE = 1;
301+
protected static final int LONG_TYPE = 2;
302+
protected static final int FOREIGN_TYPE = 3;
303+
protected static final int NOT_IMPLEMENTED = 4;
304+
305+
@GenerateUncached
306+
abstract static class ConvertToDoubleCheckNode extends Node {
307+
abstract int execute(Object obj);
308+
309+
@Specialization
310+
static int doDouble(@SuppressWarnings("unused") Double object) {
311+
return DOUBLE_TYPE;
312+
}
313+
314+
@Specialization
315+
static int doInt(@SuppressWarnings("unused") Integer object) {
316+
return LONG_TYPE;
317+
}
318+
319+
@Specialization
320+
static int doLong(@SuppressWarnings("unused") Long object) {
321+
return LONG_TYPE;
322+
}
323+
324+
@Specialization
325+
static int doBoolean(@SuppressWarnings("unused") Boolean object) {
326+
return LONG_TYPE;
327+
}
328+
329+
@Specialization
330+
static int doString(@SuppressWarnings("unused") TruffleString object) {
331+
return NOT_IMPLEMENTED;
332+
}
333+
334+
@Specialization
335+
static int doPBCT(@SuppressWarnings("unused") PythonBuiltinClassType object) {
336+
return NOT_IMPLEMENTED;
337+
}
338+
339+
@Specialization
340+
static int typeCheck(Object obj,
341+
@Cached PyFloatCheckNode floatCheckNode,
342+
@Cached PyLongCheckNode longCheckNode,
343+
@Bind("this") Node inliningTarget,
344+
@Cached InlinedGetClassNode getClassNode,
345+
@CachedLibrary(limit = "3") InteropLibrary interopLibrary) {
346+
if (floatCheckNode.execute(obj)) {
347+
return DOUBLE_TYPE;
348+
}
349+
if (longCheckNode.execute(obj)) {
350+
return LONG_TYPE;
351+
}
352+
Object type = getClassNode.execute(inliningTarget, obj);
353+
if (type == PythonBuiltinClassType.ForeignObject) {
354+
if (interopLibrary.fitsInDouble(obj) || interopLibrary.fitsInLong(obj) || interopLibrary.isBoolean(obj)) {
355+
return FOREIGN_TYPE;
356+
}
357+
}
358+
return NOT_IMPLEMENTED;
359+
}
360+
}
361+
362+
@ImportStatic(FloatBuiltins.class)
363+
@GenerateUncached
364+
abstract static class ConvertToDoubleNode extends Node {
365+
366+
abstract double execute(VirtualFrame frame, int type, Object obj);
367+
368+
@Specialization(guards = "type == DOUBLE_TYPE")
369+
static double doDouble(VirtualFrame frame, @SuppressWarnings("unused") int type, Object obj,
370+
@Cached PyFloatAsDoubleNode asDoubleNode) {
371+
return asDoubleNode.execute(frame, obj);
372+
}
373+
374+
@Specialization(guards = "type == LONG_TYPE")
375+
static double doLong(@SuppressWarnings("unused") int type, Object obj,
376+
@Cached PyLongAsDoubleNode asDoubleNode) {
377+
Object r = asDoubleNode.execute(obj);
378+
assert r != PNotImplemented.NOT_IMPLEMENTED : "Already been checked in ConvertToDoubleCheckNode";
379+
return (double) r;
380+
}
381+
382+
@Specialization(guards = "type == FOREIGN_TYPE")
383+
static double doForeign(@SuppressWarnings("unused") int type, Object obj,
384+
@CachedLibrary(limit = "3") InteropLibrary interopLibrary) {
385+
try {
386+
if (interopLibrary.fitsInDouble(obj)) {
387+
388+
return interopLibrary.asDouble(obj);
389+
}
390+
if (interopLibrary.fitsInLong(obj)) {
391+
return (double) interopLibrary.asLong(obj);
392+
}
393+
if (interopLibrary.isBoolean(obj)) {
394+
return interopLibrary.asBoolean(obj) ? 1.0 : 0.0;
395+
}
396+
} catch (UnsupportedMessageException e) {
397+
throw CompilerDirectives.shouldNotReachHere(e);
398+
}
399+
400+
throw CompilerDirectives.shouldNotReachHere("Should have been checked in ConvertToDoubleCheckNode");
401+
}
402+
}
403+
293404
@Builtin(name = J___RADD__, minNumOfPositionalArgs = 2)
294405
@Builtin(name = J___ADD__, minNumOfPositionalArgs = 2)
295406
@TypeSystemReference(PythonArithmeticTypes.class)
@@ -337,10 +448,25 @@ static Object doDP(VirtualFrame frame, PythonAbstractNativeObject left, double r
337448
}
338449
}
339450

340-
@SuppressWarnings("unused")
341451
@Fallback
342-
static PNotImplemented doGeneric(Object left, Object right) {
343-
return PNotImplemented.NOT_IMPLEMENTED;
452+
static Object doGeneric(VirtualFrame frame, Object left, Object right,
453+
@Cached ConvertToDoubleCheckNode convertToDoubleCheckNode,
454+
@Cached ConvertToDoubleNode convertToDoubleNode) {
455+
double leftDouble;
456+
double rightDouble;
457+
int leftType = convertToDoubleCheckNode.execute(left);
458+
if (leftType != NOT_IMPLEMENTED) {
459+
leftDouble = convertToDoubleNode.execute(frame, leftType, left);
460+
} else {
461+
return PNotImplemented.NOT_IMPLEMENTED;
462+
}
463+
int rightType = convertToDoubleCheckNode.execute(right);
464+
if (rightType != NOT_IMPLEMENTED) {
465+
rightDouble = convertToDoubleNode.execute(frame, rightType, right);
466+
} else {
467+
return PNotImplemented.NOT_IMPLEMENTED;
468+
}
469+
return leftDouble + rightDouble;
344470
}
345471
}
346472

@@ -374,10 +500,25 @@ static double doLD(long left, double right) {
374500
return left.doubleValueWithOverflow(getRaiseNode()) - right;
375501
}
376502

377-
@SuppressWarnings("unused")
378503
@Fallback
379-
static PNotImplemented doGeneric(Object left, Object right) {
380-
return PNotImplemented.NOT_IMPLEMENTED;
504+
static Object doGeneric(VirtualFrame frame, Object left, Object right,
505+
@Cached ConvertToDoubleCheckNode convertToDoubleCheckNode,
506+
@Cached ConvertToDoubleNode convertToDoubleNode) {
507+
double leftDouble;
508+
double rightDouble;
509+
int leftType = convertToDoubleCheckNode.execute(left);
510+
if (leftType != NOT_IMPLEMENTED) {
511+
leftDouble = convertToDoubleNode.execute(frame, leftType, left);
512+
} else {
513+
return PNotImplemented.NOT_IMPLEMENTED;
514+
}
515+
int rightType = convertToDoubleCheckNode.execute(right);
516+
if (rightType != NOT_IMPLEMENTED) {
517+
rightDouble = convertToDoubleNode.execute(frame, rightType, right);
518+
} else {
519+
return PNotImplemented.NOT_IMPLEMENTED;
520+
}
521+
return leftDouble - rightDouble;
381522
}
382523
}
383524

@@ -439,10 +580,25 @@ Object doDP(VirtualFrame frame, PythonAbstractNativeObject left, PInt right,
439580
}
440581
}
441582

442-
@SuppressWarnings("unused")
443583
@Fallback
444-
static PNotImplemented doGeneric(Object left, Object right) {
445-
return PNotImplemented.NOT_IMPLEMENTED;
584+
static Object doGeneric(VirtualFrame frame, Object left, Object right,
585+
@Cached ConvertToDoubleCheckNode convertToDoubleCheckNode,
586+
@Cached ConvertToDoubleNode convertToDoubleNode) {
587+
double leftDouble;
588+
double rightDouble;
589+
int leftType = convertToDoubleCheckNode.execute(left);
590+
if (leftType != NOT_IMPLEMENTED) {
591+
leftDouble = convertToDoubleNode.execute(frame, leftType, left);
592+
} else {
593+
return PNotImplemented.NOT_IMPLEMENTED;
594+
}
595+
int rightType = convertToDoubleCheckNode.execute(right);
596+
if (rightType != NOT_IMPLEMENTED) {
597+
rightDouble = convertToDoubleNode.execute(frame, rightType, right);
598+
} else {
599+
return PNotImplemented.NOT_IMPLEMENTED;
600+
}
601+
return leftDouble * rightDouble;
446602
}
447603
}
448604

@@ -555,21 +711,23 @@ Object doDPiToComplex(VirtualFrame frame, PInt left, double right, @SuppressWarn
555711

556712
@Specialization
557713
Object doGeneric(VirtualFrame frame, Object left, Object right, Object mod,
558-
@Cached CanBeDoubleNode canBeDoubleNode,
559-
@Cached PyFloatAsDoubleNode asDoubleNode,
714+
@Cached ConvertToDoubleCheckNode convertToDoubleCheckNode,
715+
@Cached ConvertToDoubleNode convertToDoubleNode,
560716
@Shared("powCall") @Cached("create(Pow)") LookupAndCallTernaryNode callPow) {
561717
if (!(mod instanceof PNone)) {
562718
throw raise(PythonBuiltinClassType.TypeError, ErrorMessages.POW_3RD_ARG_NOT_ALLOWED_UNLESS_INTEGERS);
563719
}
564720
double leftDouble;
565721
double rightDouble;
566-
if (canBeDoubleNode.execute(left)) {
567-
leftDouble = asDoubleNode.execute(frame, left);
722+
int leftType = convertToDoubleCheckNode.execute(left);
723+
if (leftType != NOT_IMPLEMENTED) {
724+
leftDouble = convertToDoubleNode.execute(frame, leftType, left);
568725
} else {
569726
return PNotImplemented.NOT_IMPLEMENTED;
570727
}
571-
if (canBeDoubleNode.execute(right)) {
572-
rightDouble = asDoubleNode.execute(frame, right);
728+
int rightType = convertToDoubleCheckNode.execute(right);
729+
if (rightType != NOT_IMPLEMENTED) {
730+
rightDouble = convertToDoubleNode.execute(frame, rightType, right);
573731
} else {
574732
return PNotImplemented.NOT_IMPLEMENTED;
575733
}
@@ -653,10 +811,25 @@ PTuple doGenericFloat(VirtualFrame frame, Object left, Object right,
653811
return factory().createTuple(new Object[]{floorDivNode.execute(frame, left, right), modNode.execute(frame, left, right)});
654812
}
655813

656-
@SuppressWarnings("unused")
657814
@Fallback
658-
static PNotImplemented doGeneric(Object left, Object right) {
659-
return PNotImplemented.NOT_IMPLEMENTED;
815+
Object doGeneric(VirtualFrame frame, Object left, Object right,
816+
@Cached ConvertToDoubleCheckNode convertToDoubleCheckNode,
817+
@Cached ConvertToDoubleNode convertToDoubleNode) {
818+
double leftDouble;
819+
double rightDouble;
820+
int leftType = convertToDoubleCheckNode.execute(left);
821+
if (leftType != NOT_IMPLEMENTED) {
822+
leftDouble = convertToDoubleNode.execute(frame, leftType, left);
823+
} else {
824+
return PNotImplemented.NOT_IMPLEMENTED;
825+
}
826+
int rightType = convertToDoubleCheckNode.execute(right);
827+
if (rightType != NOT_IMPLEMENTED) {
828+
rightDouble = convertToDoubleNode.execute(frame, rightType, right);
829+
} else {
830+
return PNotImplemented.NOT_IMPLEMENTED;
831+
}
832+
return doDD(leftDouble, rightDouble);
660833
}
661834

662835
protected static boolean accepts(Object obj) {
@@ -842,10 +1015,25 @@ public abstract static class ModNode extends FloatBinaryBuiltinNode {
8421015
return op(left.doubleValue(), right);
8431016
}
8441017

845-
@SuppressWarnings("unused")
8461018
@Fallback
847-
static PNotImplemented doGeneric(Object right, Object left) {
848-
return PNotImplemented.NOT_IMPLEMENTED;
1019+
Object doGeneric(VirtualFrame frame, Object left, Object right,
1020+
@Cached ConvertToDoubleCheckNode convertToDoubleCheckNode,
1021+
@Cached ConvertToDoubleNode convertToDoubleNode) {
1022+
double leftDouble;
1023+
double rightDouble;
1024+
int leftType = convertToDoubleCheckNode.execute(left);
1025+
if (leftType != NOT_IMPLEMENTED) {
1026+
leftDouble = convertToDoubleNode.execute(frame, leftType, left);
1027+
} else {
1028+
return PNotImplemented.NOT_IMPLEMENTED;
1029+
}
1030+
int rightType = convertToDoubleCheckNode.execute(right);
1031+
if (rightType != NOT_IMPLEMENTED) {
1032+
rightDouble = convertToDoubleNode.execute(frame, rightType, right);
1033+
} else {
1034+
return PNotImplemented.NOT_IMPLEMENTED;
1035+
}
1036+
return doDD(leftDouble, rightDouble);
8491037
}
8501038

8511039
public static double op(double left, double right) {
@@ -908,10 +1096,25 @@ Object doDP(VirtualFrame frame, long left, PythonAbstractNativeObject right,
9081096
}
9091097
}
9101098

911-
@SuppressWarnings("unused")
9121099
@Fallback
913-
static PNotImplemented doGeneric(Object left, Object right) {
914-
return PNotImplemented.NOT_IMPLEMENTED;
1100+
Object doGeneric(VirtualFrame frame, Object left, Object right,
1101+
@Cached ConvertToDoubleCheckNode convertToDoubleCheckNode,
1102+
@Cached ConvertToDoubleNode convertToDoubleNode) {
1103+
double leftDouble;
1104+
double rightDouble;
1105+
int leftType = convertToDoubleCheckNode.execute(left);
1106+
if (leftType != NOT_IMPLEMENTED) {
1107+
leftDouble = convertToDoubleNode.execute(frame, leftType, left);
1108+
} else {
1109+
return PNotImplemented.NOT_IMPLEMENTED;
1110+
}
1111+
int rightType = convertToDoubleCheckNode.execute(right);
1112+
if (rightType != NOT_IMPLEMENTED) {
1113+
rightDouble = convertToDoubleNode.execute(frame, rightType, right);
1114+
} else {
1115+
return PNotImplemented.NOT_IMPLEMENTED;
1116+
}
1117+
return doDD(leftDouble, rightDouble);
9151118
}
9161119
}
9171120

@@ -1579,7 +1782,8 @@ PTuple get(double self,
15791782
throw raise(PythonErrorType.OverflowError, ErrorMessages.CANNOT_CONVERT_S_TO_INT_RATIO, "Infinity");
15801783
}
15811784

1582-
// At the first time find mantissa and exponent. This is functionanlity of Math.frexp
1785+
// At the first time find mantissa and exponent. This is functionanlity of
1786+
// Math.frexp
15831787
// node basically.
15841788
int exponent = 0;
15851789
double mantissa = 0.0;

0 commit comments

Comments
 (0)