Skip to content

Commit 06593ac

Browse files
committed
[GR-44783] Improve float binops compatibility
PullRequest: graalpython/2679
2 parents cca1881 + 3b082a5 commit 06593ac

File tree

3 files changed

+203
-36
lines changed

3 files changed

+203
-36
lines changed

graalpython/com.oracle.graal.python.test/src/tests/cpyext/test_object.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,38 @@ def test_int(self):
109109
tester = TestInt()
110110
assert int(tester) == 42
111111

112+
def test_float_binops(self):
113+
TestFloatBinop = CPyExtType("TestFloatBinop",
114+
"""
115+
PyObject* test_float_impl(PyObject* self) {
116+
PyErr_SetString(PyExc_RuntimeError, "Should not call __float__");
117+
return NULL;
118+
}
119+
PyObject* test_add_impl(PyObject* a, PyObject* b) {
120+
return PyLong_FromLong(42);
121+
}
122+
PyObject* test_sub_impl(PyObject* a, PyObject* b) {
123+
return PyLong_FromLong(4242);
124+
}
125+
PyObject* test_mul_impl(PyObject* a, PyObject* b) {
126+
return PyLong_FromLong(424242);
127+
}
128+
PyObject* test_pow_impl(PyObject* a, PyObject* b, PyObject* c) {
129+
return PyLong_FromLong(42424242);
130+
}
131+
""",
132+
nb_float="test_float_impl",
133+
nb_add="test_add_impl",
134+
nb_subtract="test_sub_impl",
135+
nb_multiply="test_mul_impl",
136+
nb_power="test_pow_impl"
137+
)
138+
x = TestFloatBinop()
139+
assert 10.0 + x == 42
140+
assert 10.0 - x == 4242
141+
assert 10.0 * x == 424242
142+
assert 10.0 ** x == 42424242
143+
112144
def test_index(self):
113145
TestIndex = CPyExtType("TestIndex",
114146
"""

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/floats/FloatBuiltins.java

Lines changed: 120 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,6 @@
9191
import com.oracle.graal.python.builtins.objects.floats.FloatBuiltinsClinicProviders.FormatNodeClinicProviderGen;
9292
import com.oracle.graal.python.builtins.objects.ints.PInt;
9393
import com.oracle.graal.python.builtins.objects.tuple.PTuple;
94-
import com.oracle.graal.python.lib.CanBeDoubleNode;
95-
import com.oracle.graal.python.lib.PyFloatAsDoubleNode;
9694
import com.oracle.graal.python.lib.PyObjectHashNode;
9795
import com.oracle.graal.python.nodes.ErrorMessages;
9896
import com.oracle.graal.python.nodes.call.special.LookupAndCallTernaryNode;
@@ -108,6 +106,8 @@
108106
import com.oracle.graal.python.nodes.object.GetClassNode;
109107
import com.oracle.graal.python.nodes.object.InlinedGetClassNode.GetPythonObjectClassNode;
110108
import com.oracle.graal.python.nodes.truffle.PythonArithmeticTypes;
109+
import com.oracle.graal.python.nodes.util.CannotCastException;
110+
import com.oracle.graal.python.nodes.util.CastToJavaDoubleNode;
111111
import com.oracle.graal.python.runtime.exception.PythonErrorType;
112112
import com.oracle.graal.python.runtime.formatting.FloatFormatter;
113113
import com.oracle.graal.python.runtime.formatting.InternalFormat;
@@ -290,6 +290,16 @@ static boolean isFloatSubtype(Node inliningTarget, PythonAbstractNativeObject ob
290290
}
291291
}
292292

293+
static Object convertToDouble(Object obj,
294+
CastToJavaDoubleNode asDoubleNode) {
295+
try {
296+
return asDoubleNode.execute(obj);
297+
} catch (CannotCastException e) {
298+
// This can only happen to values that are expected to be long.
299+
return PNotImplemented.NOT_IMPLEMENTED;
300+
}
301+
}
302+
293303
@Builtin(name = J___RADD__, minNumOfPositionalArgs = 2)
294304
@Builtin(name = J___ADD__, minNumOfPositionalArgs = 2)
295305
@TypeSystemReference(PythonArithmeticTypes.class)
@@ -337,10 +347,23 @@ static Object doDP(VirtualFrame frame, PythonAbstractNativeObject left, double r
337347
}
338348
}
339349

340-
@SuppressWarnings("unused")
341350
@Fallback
342-
static PNotImplemented doGeneric(Object left, Object right) {
343-
return PNotImplemented.NOT_IMPLEMENTED;
351+
Object doGeneric(Object left, Object right,
352+
@Cached CastToJavaDoubleNode castToJavaDoubleNode) {
353+
354+
Object objLeft = convertToDouble(left, castToJavaDoubleNode);
355+
if (objLeft == PNotImplemented.NOT_IMPLEMENTED) {
356+
return PNotImplemented.NOT_IMPLEMENTED;
357+
}
358+
359+
Object objRight = convertToDouble(right, castToJavaDoubleNode);
360+
if (objRight == PNotImplemented.NOT_IMPLEMENTED) {
361+
return PNotImplemented.NOT_IMPLEMENTED;
362+
}
363+
364+
double leftDouble = (double) objLeft;
365+
double rightDouble = (double) objRight;
366+
return leftDouble + rightDouble;
344367
}
345368
}
346369

@@ -374,10 +397,23 @@ static double doLD(long left, double right) {
374397
return left.doubleValueWithOverflow(getRaiseNode()) - right;
375398
}
376399

377-
@SuppressWarnings("unused")
378400
@Fallback
379-
static PNotImplemented doGeneric(Object left, Object right) {
380-
return PNotImplemented.NOT_IMPLEMENTED;
401+
Object doGeneric(Object left, Object right,
402+
@Cached CastToJavaDoubleNode castToJavaDoubleNode) {
403+
404+
Object objLeft = convertToDouble(left, castToJavaDoubleNode);
405+
if (objLeft == PNotImplemented.NOT_IMPLEMENTED) {
406+
return PNotImplemented.NOT_IMPLEMENTED;
407+
}
408+
409+
Object objRight = convertToDouble(right, castToJavaDoubleNode);
410+
if (objRight == PNotImplemented.NOT_IMPLEMENTED) {
411+
return PNotImplemented.NOT_IMPLEMENTED;
412+
}
413+
414+
double leftDouble = (double) objLeft;
415+
double rightDouble = (double) objRight;
416+
return leftDouble - rightDouble;
381417
}
382418
}
383419

@@ -439,10 +475,23 @@ Object doDP(VirtualFrame frame, PythonAbstractNativeObject left, PInt right,
439475
}
440476
}
441477

442-
@SuppressWarnings("unused")
443478
@Fallback
444-
static PNotImplemented doGeneric(Object left, Object right) {
445-
return PNotImplemented.NOT_IMPLEMENTED;
479+
Object doGeneric(Object left, Object right,
480+
@Cached CastToJavaDoubleNode castToJavaDoubleNode) {
481+
482+
Object objLeft = convertToDouble(left, castToJavaDoubleNode);
483+
if (objLeft == PNotImplemented.NOT_IMPLEMENTED) {
484+
return PNotImplemented.NOT_IMPLEMENTED;
485+
}
486+
487+
Object objRight = convertToDouble(right, castToJavaDoubleNode);
488+
if (objRight == PNotImplemented.NOT_IMPLEMENTED) {
489+
return PNotImplemented.NOT_IMPLEMENTED;
490+
}
491+
492+
double leftDouble = (double) objLeft;
493+
double rightDouble = (double) objRight;
494+
return leftDouble * rightDouble;
446495
}
447496
}
448497

@@ -555,24 +604,24 @@ Object doDPiToComplex(VirtualFrame frame, PInt left, double right, @SuppressWarn
555604

556605
@Specialization
557606
Object doGeneric(VirtualFrame frame, Object left, Object right, Object mod,
558-
@Cached CanBeDoubleNode canBeDoubleNode,
559-
@Cached PyFloatAsDoubleNode asDoubleNode,
607+
@Cached CastToJavaDoubleNode castToJavaDoubleNode,
560608
@Shared("powCall") @Cached("create(Pow)") LookupAndCallTernaryNode callPow) {
561609
if (!(mod instanceof PNone)) {
562610
throw raise(PythonBuiltinClassType.TypeError, ErrorMessages.POW_3RD_ARG_NOT_ALLOWED_UNLESS_INTEGERS);
563611
}
564-
double leftDouble;
565-
double rightDouble;
566-
if (canBeDoubleNode.execute(left)) {
567-
leftDouble = asDoubleNode.execute(frame, left);
568-
} else {
612+
613+
Object objLeft = convertToDouble(left, castToJavaDoubleNode);
614+
if (objLeft == PNotImplemented.NOT_IMPLEMENTED) {
569615
return PNotImplemented.NOT_IMPLEMENTED;
570616
}
571-
if (canBeDoubleNode.execute(right)) {
572-
rightDouble = asDoubleNode.execute(frame, right);
573-
} else {
617+
618+
Object objRight = convertToDouble(right, castToJavaDoubleNode);
619+
if (objRight == PNotImplemented.NOT_IMPLEMENTED) {
574620
return PNotImplemented.NOT_IMPLEMENTED;
575621
}
622+
623+
double leftDouble = (double) objLeft;
624+
double rightDouble = (double) objRight;
576625
return doDDToComplex(frame, leftDouble, rightDouble, PNone.NONE, callPow);
577626
}
578627

@@ -653,10 +702,23 @@ PTuple doGenericFloat(VirtualFrame frame, Object left, Object right,
653702
return factory().createTuple(new Object[]{floorDivNode.execute(frame, left, right), modNode.execute(frame, left, right)});
654703
}
655704

656-
@SuppressWarnings("unused")
657705
@Fallback
658-
static PNotImplemented doGeneric(Object left, Object right) {
659-
return PNotImplemented.NOT_IMPLEMENTED;
706+
Object doGeneric(Object left, Object right,
707+
@Cached CastToJavaDoubleNode castToJavaDoubleNode) {
708+
709+
Object objLeft = convertToDouble(left, castToJavaDoubleNode);
710+
if (objLeft == PNotImplemented.NOT_IMPLEMENTED) {
711+
return PNotImplemented.NOT_IMPLEMENTED;
712+
}
713+
714+
Object objRight = convertToDouble(right, castToJavaDoubleNode);
715+
if (objRight == PNotImplemented.NOT_IMPLEMENTED) {
716+
return PNotImplemented.NOT_IMPLEMENTED;
717+
}
718+
719+
double leftDouble = (double) objLeft;
720+
double rightDouble = (double) objRight;
721+
return doDD(leftDouble, rightDouble);
660722
}
661723

662724
protected static boolean accepts(Object obj) {
@@ -842,10 +904,23 @@ public abstract static class ModNode extends FloatBinaryBuiltinNode {
842904
return op(left.doubleValue(), right);
843905
}
844906

845-
@SuppressWarnings("unused")
846907
@Fallback
847-
static PNotImplemented doGeneric(Object right, Object left) {
848-
return PNotImplemented.NOT_IMPLEMENTED;
908+
Object doGeneric(Object left, Object right,
909+
@Cached CastToJavaDoubleNode castToJavaDoubleNode) {
910+
911+
Object objLeft = convertToDouble(left, castToJavaDoubleNode);
912+
if (objLeft == PNotImplemented.NOT_IMPLEMENTED) {
913+
return PNotImplemented.NOT_IMPLEMENTED;
914+
}
915+
916+
Object objRight = convertToDouble(right, castToJavaDoubleNode);
917+
if (objRight == PNotImplemented.NOT_IMPLEMENTED) {
918+
return PNotImplemented.NOT_IMPLEMENTED;
919+
}
920+
921+
double leftDouble = (double) objLeft;
922+
double rightDouble = (double) objRight;
923+
return doDD(leftDouble, rightDouble);
849924
}
850925

851926
public static double op(double left, double right) {
@@ -908,10 +983,23 @@ Object doDP(VirtualFrame frame, long left, PythonAbstractNativeObject right,
908983
}
909984
}
910985

911-
@SuppressWarnings("unused")
912986
@Fallback
913-
static PNotImplemented doGeneric(Object left, Object right) {
914-
return PNotImplemented.NOT_IMPLEMENTED;
987+
Object doGeneric(Object left, Object right,
988+
@Cached CastToJavaDoubleNode castToJavaDoubleNode) {
989+
990+
Object objLeft = convertToDouble(left, castToJavaDoubleNode);
991+
if (objLeft == PNotImplemented.NOT_IMPLEMENTED) {
992+
return PNotImplemented.NOT_IMPLEMENTED;
993+
}
994+
995+
Object objRight = convertToDouble(right, castToJavaDoubleNode);
996+
if (objRight == PNotImplemented.NOT_IMPLEMENTED) {
997+
return PNotImplemented.NOT_IMPLEMENTED;
998+
}
999+
1000+
double leftDouble = (double) objLeft;
1001+
double rightDouble = (double) objRight;
1002+
return doDD(leftDouble, rightDouble);
9151003
}
9161004
}
9171005

@@ -1579,7 +1667,8 @@ PTuple get(double self,
15791667
throw raise(PythonErrorType.OverflowError, ErrorMessages.CANNOT_CONVERT_S_TO_INT_RATIO, "Infinity");
15801668
}
15811669

1582-
// At the first time find mantissa and exponent. This is functionanlity of Math.frexp
1670+
// At the first time find mantissa and exponent. This is functionanlity of
1671+
// Math.frexp
15831672
// node basically.
15841673
int exponent = 0;
15851674
double mantissa = 0.0;

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/util/CastToJavaDoubleNode.java

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,15 +52,18 @@
5252
import com.oracle.graal.python.nodes.classes.IsSubtypeNode;
5353
import com.oracle.graal.python.nodes.object.InlinedGetClassNode.GetPythonObjectClassNode;
5454
import com.oracle.graal.python.nodes.truffle.PythonArithmeticTypes;
55+
import com.oracle.truffle.api.CompilerDirectives;
5556
import com.oracle.truffle.api.dsl.Bind;
5657
import com.oracle.truffle.api.dsl.Cached;
5758
import com.oracle.truffle.api.dsl.GenerateUncached;
5859
import com.oracle.truffle.api.dsl.ImportStatic;
5960
import com.oracle.truffle.api.dsl.Specialization;
6061
import com.oracle.truffle.api.dsl.TypeSystemReference;
6162
import com.oracle.truffle.api.interop.InteropLibrary;
63+
import com.oracle.truffle.api.interop.UnsupportedMessageException;
6264
import com.oracle.truffle.api.library.CachedLibrary;
6365
import com.oracle.truffle.api.nodes.Node;
66+
import com.oracle.truffle.api.strings.TruffleString;
6467

6568
/**
6669
* Casts a Python "number" to a Java double without coercion. <b>ATTENTION:</b> If the cast fails,
@@ -74,21 +77,41 @@ public abstract class CastToJavaDoubleNode extends PNodeWithContext {
7477
public abstract double execute(Object x);
7578

7679
@Specialization
77-
static double toDouble(long x) {
80+
static double toDouble(double x) {
7881
return x;
7982
}
8083

8184
@Specialization
82-
static double toDouble(double x) {
85+
static double doBoolean(boolean x) {
86+
return x ? 1.0 : 0.0;
87+
}
88+
89+
@Specialization
90+
static double toInt(int x) {
8391
return x;
8492
}
8593

8694
@Specialization
87-
static double toDouble(PInt x,
95+
static double toLong(long x) {
96+
return x;
97+
}
98+
99+
@Specialization
100+
static double toPInt(PInt x,
88101
@Cached PRaiseNode raise) {
89102
return x.doubleValueWithOverflow(raise);
90103
}
91104

105+
@Specialization
106+
static double doString(@SuppressWarnings("unused") TruffleString object) {
107+
throw CannotCastException.INSTANCE;
108+
}
109+
110+
@Specialization
111+
static double doPBCT(@SuppressWarnings("unused") PythonBuiltinClassType object) {
112+
throw CannotCastException.INSTANCE;
113+
}
114+
92115
@Specialization
93116
static double doNativeObject(PythonAbstractNativeObject x,
94117
@Bind("this") Node node,
@@ -104,8 +127,31 @@ static double doNativeObject(PythonAbstractNativeObject x,
104127
throw CannotCastException.INSTANCE;
105128
}
106129

107-
@Specialization(guards = "!isNumber(x)")
108-
static double doUnsupported(@SuppressWarnings("unused") Object x) {
130+
public static Double doInterop(Object obj,
131+
InteropLibrary interopLibrary) {
132+
try {
133+
if (interopLibrary.fitsInDouble(obj)) {
134+
return interopLibrary.asDouble(obj);
135+
}
136+
if (interopLibrary.fitsInLong(obj)) {
137+
return (double) interopLibrary.asLong(obj);
138+
}
139+
if (interopLibrary.isBoolean(obj)) {
140+
return interopLibrary.asBoolean(obj) ? 1.0 : 0.0;
141+
}
142+
} catch (UnsupportedMessageException e) {
143+
throw CompilerDirectives.shouldNotReachHere(e);
144+
}
145+
return null;
146+
}
147+
148+
@Specialization(guards = "!isNumber(obj)")
149+
static double doGeneric(Object obj,
150+
@CachedLibrary(limit = "3") InteropLibrary interopLibrary) {
151+
Double d = doInterop(obj, interopLibrary);
152+
if (d != null) {
153+
return d;
154+
}
109155
throw CannotCastException.INSTANCE;
110156
}
111157

0 commit comments

Comments
 (0)