Skip to content

Commit 370b1af

Browse files
committed
[GR-23211] Make more test_complex tests pass (fixes comparison, hash, floordiv, mod)
PullRequest: graalpython/1172
2 parents b8dccb7 + 9693dd9 commit 370b1af

File tree

8 files changed

+182
-42
lines changed

8 files changed

+182
-42
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_float.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,23 @@ def test_floatasratio(self):
172172
self.assertRaises(OverflowError, float('-inf').as_integer_ratio)
173173
self.assertRaises(ValueError, float('nan').as_integer_ratio)
174174

175+
def test_compare(self):
176+
i = 2**53 + 1
177+
f = float(i)
178+
self.assertFalse(f == i)
179+
self.assertTrue(f != i)
180+
self.assertTrue(f < i)
181+
self.assertTrue(f <= i)
182+
self.assertFalse(f > i)
183+
self.assertFalse(f >= i)
184+
self.assertFalse(i == f)
185+
self.assertTrue(i != f)
186+
self.assertFalse(i < f)
187+
self.assertFalse(i <= f)
188+
self.assertTrue(i > f)
189+
self.assertTrue(i >= f)
190+
191+
175192
fromHex = float.fromhex
176193

177194
class HexFloatTests(unittest.TestCase):

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/BuiltinConstructors.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,10 @@ private PComplex createComplex(Object cls, double real, double imaginary) {
377377

378378
private PComplex createComplex(Object cls, PComplex value) {
379379
if (isPrimitiveProfile.profileClass(cls, PythonBuiltinClassType.PComplex)) {
380-
return value;
380+
if (isPrimitiveProfile.profileObject(value, PythonBuiltinClassType.PComplex)) {
381+
return value;
382+
}
383+
return factory().createComplex(value.getReal(), value.getImag());
381384
}
382385
return factory().createComplex(cls, value.getReal(), value.getImag());
383386
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/complex/ComplexBuiltins.java

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,15 @@
4646
import static com.oracle.graal.python.nodes.SpecialMethodNames.__BOOL__;
4747
import static com.oracle.graal.python.nodes.SpecialMethodNames.__DIVMOD__;
4848
import static com.oracle.graal.python.nodes.SpecialMethodNames.__EQ__;
49+
import static com.oracle.graal.python.nodes.SpecialMethodNames.__FLOORDIV__;
4950
import static com.oracle.graal.python.nodes.SpecialMethodNames.__FORMAT__;
5051
import static com.oracle.graal.python.nodes.SpecialMethodNames.__GETNEWARGS__;
5152
import static com.oracle.graal.python.nodes.SpecialMethodNames.__GE__;
5253
import static com.oracle.graal.python.nodes.SpecialMethodNames.__GT__;
5354
import static com.oracle.graal.python.nodes.SpecialMethodNames.__HASH__;
5455
import static com.oracle.graal.python.nodes.SpecialMethodNames.__LE__;
5556
import static com.oracle.graal.python.nodes.SpecialMethodNames.__LT__;
57+
import static com.oracle.graal.python.nodes.SpecialMethodNames.__MOD__;
5658
import static com.oracle.graal.python.nodes.SpecialMethodNames.__MUL__;
5759
import static com.oracle.graal.python.nodes.SpecialMethodNames.__NEG__;
5860
import static com.oracle.graal.python.nodes.SpecialMethodNames.__NE__;
@@ -81,7 +83,9 @@
8183
import com.oracle.graal.python.builtins.PythonBuiltins;
8284
import com.oracle.graal.python.builtins.objects.PNone;
8385
import com.oracle.graal.python.builtins.objects.PNotImplemented;
86+
import com.oracle.graal.python.builtins.objects.floats.FloatBuiltins;
8487
import com.oracle.graal.python.builtins.objects.ints.PInt;
88+
import com.oracle.graal.python.builtins.objects.object.PythonObjectLibrary;
8589
import com.oracle.graal.python.builtins.objects.tuple.PTuple;
8690
import com.oracle.graal.python.nodes.ErrorMessages;
8791
import com.oracle.graal.python.nodes.call.special.LookupAndCallUnaryNode;
@@ -591,30 +595,17 @@ boolean doComplex(PComplex left, PComplex right) {
591595

592596
@Specialization
593597
boolean doComplexInt(PComplex left, long right) {
594-
if (left.getImag() == 0) {
595-
return left.getReal() == right;
596-
}
597-
return false;
598+
return left.getImag() == 0 && FloatBuiltins.EqNode.compareDoubleToLong(left.getReal(), right) == 0;
598599
}
599600

600601
@Specialization
601602
boolean doComplexInt(PComplex left, PInt right) {
602-
if (left.getImag() == 0) {
603-
try {
604-
return left.getReal() == right.longValueExact();
605-
} catch (ArithmeticException e) {
606-
// do nothing -> return false;
607-
}
608-
}
609-
return false;
603+
return left.getImag() == 0 && FloatBuiltins.EqNode.compareDoubleToLargeInt(left.getReal(), right) == 0;
610604
}
611605

612606
@Specialization
613607
boolean doComplexInt(PComplex left, double right) {
614-
if (left.getImag() == 0) {
615-
return left.getReal() == right;
616-
}
617-
return false;
608+
return left.getImag() == 0 && left.getReal() == right;
618609
}
619610

620611
@SuppressWarnings("unused")
@@ -699,12 +690,12 @@ boolean doComplex(PComplex left, PComplex right) {
699690

700691
@Specialization
701692
boolean doComplex(PComplex left, long right) {
702-
return left.getImag() != 0 || left.getReal() != right;
693+
return left.getImag() != 0 || FloatBuiltins.EqNode.compareDoubleToLong(left.getReal(), right) != 0;
703694
}
704695

705696
@Specialization
706697
boolean doComplex(PComplex left, PInt right) {
707-
return left.getImag() != 0 || left.getReal() != right.doubleValue();
698+
return left.getImag() != 0 || FloatBuiltins.EqNode.compareDoubleToLargeInt(left.getReal(), right) != 0;
708699
}
709700

710701
@Specialization
@@ -836,11 +827,10 @@ abstract static class ImagNode extends PythonBuiltinNode {
836827
@Builtin(name = __HASH__, minNumOfPositionalArgs = 1)
837828
abstract static class HashNode extends PythonUnaryBuiltinNode {
838829
@Specialization
839-
@TruffleBoundary
840-
int hash(PComplex self) {
830+
long hash(PComplex self) {
841831
// just like CPython
842-
int realHash = Double.hashCode(self.getReal());
843-
int imagHash = Double.hashCode(self.getImag());
832+
long realHash = PythonObjectLibrary.hash(self.getReal());
833+
long imagHash = PythonObjectLibrary.hash(self.getImag());
844834
return realHash + PComplex.IMAG_MULTIPLIER * imagHash;
845835
}
846836
}
@@ -853,4 +843,26 @@ PComplex hash(PComplex self) {
853843
return factory().createComplex(self.getReal(), -self.getImag());
854844
}
855845
}
846+
847+
@GenerateNodeFactory
848+
@Builtin(name = __FLOORDIV__, minNumOfPositionalArgs = 2)
849+
@TypeSystemReference(PythonArithmeticTypes.class)
850+
abstract static class FloorDivNode extends PythonBinaryBuiltinNode {
851+
@Specialization
852+
@SuppressWarnings("unused")
853+
Object floorDiv(Object arg) {
854+
throw raise(TypeError, ErrorMessages.CANT_TAKE_FLOOR_OR_MOD_OF_COMPLEX);
855+
}
856+
}
857+
858+
@GenerateNodeFactory
859+
@Builtin(name = __MOD__, minNumOfPositionalArgs = 2)
860+
@TypeSystemReference(PythonArithmeticTypes.class)
861+
abstract static class ModNode extends PythonBinaryBuiltinNode {
862+
@Specialization
863+
@SuppressWarnings("unused")
864+
Object mod(Object arg) {
865+
throw raise(TypeError, ErrorMessages.CANT_TAKE_FLOOR_OR_MOD_OF_COMPLEX);
866+
}
867+
}
856868
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/floats/FloatBuiltins.java

Lines changed: 66 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,11 @@ abstract static class AddNode extends PythonBinaryBuiltinNode {
295295
return left + right;
296296
}
297297

298+
@Specialization
299+
double doLD(long left, double right) {
300+
return left + right;
301+
}
302+
298303
@Specialization
299304
double doDPi(double left, PInt right) {
300305
return left + right.doubleValueWithOverflow(getRaiseNode());
@@ -376,6 +381,11 @@ abstract static class MulNode extends PythonBinaryBuiltinNode {
376381
return left * right;
377382
}
378383

384+
@Specialization
385+
double doLD(long left, double right) {
386+
return left * right;
387+
}
388+
379389
@Specialization
380390
double doDD(double left, double right) {
381391
return left * right;
@@ -999,7 +1009,7 @@ long round(double x, @SuppressWarnings("unused") PNone none,
9991009
@Builtin(name = __EQ__, minNumOfPositionalArgs = 2)
10001010
@GenerateNodeFactory
10011011
@TypeSystemReference(PythonArithmeticTypes.class)
1002-
abstract static class EqNode extends PythonBinaryBuiltinNode {
1012+
public abstract static class EqNode extends PythonBinaryBuiltinNode {
10031013

10041014
@Specialization
10051015
boolean eqDbDb(double a, double b) {
@@ -1008,7 +1018,12 @@ boolean eqDbDb(double a, double b) {
10081018

10091019
@Specialization
10101020
boolean eqDbLn(double a, long b) {
1011-
return a == b;
1021+
return compareDoubleToLong(a, b) == 0;
1022+
}
1023+
1024+
@Specialization
1025+
boolean eqLnDb(long a, double b) {
1026+
return compareDoubleToLong(b, a) == 0;
10121027
}
10131028

10141029
@Specialization
@@ -1041,7 +1056,25 @@ PNotImplemented eq(Object a, Object b) {
10411056
}
10421057

10431058
// adapted from CPython's float_richcompare in floatobject.c
1044-
static double compareDoubleToLargeInt(double v, PInt w) {
1059+
@TruffleBoundary(allowInlining = true)
1060+
public static double compareDoubleToLong(double v, long w) {
1061+
if (!Double.isFinite(v)) {
1062+
return v;
1063+
}
1064+
int vsign = v == 0.0 ? 0 : v < 0.0 ? -1 : 1;
1065+
int wsign = Long.signum(w);
1066+
if (vsign != wsign) {
1067+
return vsign - wsign;
1068+
}
1069+
if (w > -0x1000000000000L && w < 0x1000000000000L) { // w is at most 48 bits
1070+
return v - w;
1071+
} else {
1072+
return compareUsingBigDecimal(v, PInt.longToBigInteger(w));
1073+
}
1074+
}
1075+
1076+
// adapted from CPython's float_richcompare in floatobject.c
1077+
public static double compareDoubleToLargeInt(double v, PInt w) {
10451078
if (!Double.isFinite(v)) {
10461079
return v;
10471080
}
@@ -1074,7 +1107,12 @@ boolean neDbDb(double a, double b) {
10741107

10751108
@Specialization
10761109
boolean neDbLn(double a, long b) {
1077-
return a != b;
1110+
return EqNode.compareDoubleToLong(a, b) != 0;
1111+
}
1112+
1113+
@Specialization
1114+
boolean neLnDb(long a, double b) {
1115+
return EqNode.compareDoubleToLong(b, a) != 0;
10781116
}
10791117

10801118
@Specialization
@@ -1118,7 +1156,12 @@ boolean doDD(double x, double y) {
11181156

11191157
@Specialization
11201158
boolean doDL(double x, long y) {
1121-
return x < y;
1159+
return EqNode.compareDoubleToLong(x, y) < 0;
1160+
}
1161+
1162+
@Specialization
1163+
boolean doLD(long x, double y) {
1164+
return EqNode.compareDoubleToLong(y, x) > 0;
11221165
}
11231166

11241167
@Specialization
@@ -1183,7 +1226,12 @@ boolean doDD(double x, double y) {
11831226

11841227
@Specialization
11851228
boolean doDL(double x, long y) {
1186-
return x <= y;
1229+
return EqNode.compareDoubleToLong(x, y) <= 0;
1230+
}
1231+
1232+
@Specialization
1233+
boolean doLD(long x, double y) {
1234+
return EqNode.compareDoubleToLong(y, x) >= 0;
11871235
}
11881236

11891237
@Specialization
@@ -1248,7 +1296,12 @@ boolean doDD(double x, double y) {
12481296

12491297
@Specialization
12501298
boolean doDL(double x, long y) {
1251-
return x > y;
1299+
return EqNode.compareDoubleToLong(x, y) > 0;
1300+
}
1301+
1302+
@Specialization
1303+
boolean doLD(long x, double y) {
1304+
return EqNode.compareDoubleToLong(y, x) < 0;
12521305
}
12531306

12541307
@Specialization
@@ -1313,7 +1366,12 @@ boolean doDD(double x, double y) {
13131366

13141367
@Specialization
13151368
boolean doDL(double x, long y) {
1316-
return x >= y;
1369+
return EqNode.compareDoubleToLong(x, y) >= 0;
1370+
}
1371+
1372+
@Specialization
1373+
boolean doLD(long x, double y) {
1374+
return EqNode.compareDoubleToLong(y, x) <= 0;
13171375
}
13181376

13191377
@Specialization

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/object/PythonObjectLibrary.java

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -262,23 +262,19 @@ public final long hashWithFrame(Object receiver, VirtualFrame frame) {
262262
}
263263
}
264264

265-
@SuppressWarnings("static-method")
266-
public final long hash(boolean receiver) {
265+
public static long hash(boolean receiver) {
267266
return DefaultPythonBooleanExports.hash(receiver);
268267
}
269268

270-
@SuppressWarnings("static-method")
271-
public final long hash(int receiver) {
269+
public static long hash(int receiver) {
272270
return DefaultPythonIntegerExports.hash(receiver);
273271
}
274272

275-
@SuppressWarnings("static-method")
276-
public final long hash(long receiver) {
273+
public static long hash(long receiver) {
277274
return DefaultPythonLongExports.hash(receiver);
278275
}
279276

280-
@SuppressWarnings("static-method")
281-
public final long hash(double receiver) {
277+
public static long hash(double receiver) {
282278
return DefaultPythonDoubleExports.hash(receiver);
283279
}
284280

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/call/special/LookupAndCallBinaryNode.java

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,20 @@ public abstract static class NotImplementedHandler extends PNodeWithContext {
9292

9393
public abstract long executeLong(VirtualFrame frame, Object arg, Object arg2) throws UnexpectedResultException;
9494

95+
public abstract double executeDouble(VirtualFrame frame, long arg, double arg2) throws UnexpectedResultException;
96+
97+
public abstract double executeDouble(VirtualFrame frame, double arg, long arg2) throws UnexpectedResultException;
98+
9599
public abstract double executeDouble(VirtualFrame frame, double arg, double arg2) throws UnexpectedResultException;
96100

97101
public abstract boolean executeBool(VirtualFrame frame, int arg, int arg2) throws UnexpectedResultException;
98102

99103
public abstract boolean executeBool(VirtualFrame frame, long arg, long arg2) throws UnexpectedResultException;
100104

105+
public abstract boolean executeBool(VirtualFrame frame, long arg, double arg2) throws UnexpectedResultException;
106+
107+
public abstract boolean executeBool(VirtualFrame frame, double arg, long arg2) throws UnexpectedResultException;
108+
101109
public abstract boolean executeBool(VirtualFrame frame, double arg, double arg2) throws UnexpectedResultException;
102110

103111
public abstract boolean executeBool(VirtualFrame frame, Object arg, Object arg2) throws UnexpectedResultException;
@@ -255,6 +263,32 @@ boolean callBoolean(VirtualFrame frame, long left, long right,
255263
}
256264
}
257265

266+
// long, double
267+
268+
@Specialization(guards = "function != null", rewriteOn = UnexpectedResultException.class)
269+
boolean callBoolean(VirtualFrame frame, long left, double right,
270+
@Cached("getBuiltin(right)") PythonBinaryBuiltinNode function) throws UnexpectedResultException {
271+
return function.executeBool(frame, left, right);
272+
}
273+
274+
@Specialization(guards = "function != null", rewriteOn = UnexpectedResultException.class)
275+
boolean callBoolean(VirtualFrame frame, double left, long right,
276+
@Cached("getBuiltin(left)") PythonBinaryBuiltinNode function) throws UnexpectedResultException {
277+
return function.executeBool(frame, left, right);
278+
}
279+
280+
@Specialization(guards = "function != null", rewriteOn = UnexpectedResultException.class)
281+
double callDouble(VirtualFrame frame, long left, double right,
282+
@Cached("getBuiltin(right)") PythonBinaryBuiltinNode function) throws UnexpectedResultException {
283+
return function.executeDouble(frame, left, right);
284+
}
285+
286+
@Specialization(guards = "function != null", rewriteOn = UnexpectedResultException.class)
287+
double callDouble(VirtualFrame frame, double left, long right,
288+
@Cached("getBuiltin(left)") PythonBinaryBuiltinNode function) throws UnexpectedResultException {
289+
return function.executeDouble(frame, left, right);
290+
}
291+
258292
// double, double
259293

260294
@Specialization(guards = "function != null", rewriteOn = UnexpectedResultException.class)

0 commit comments

Comments
 (0)