Skip to content

Commit 6acf3c2

Browse files
committed
[GR-25532] Fixed performance regression of double to long comparison
PullRequest: graalpython/1188
2 parents b927eb7 + bb15522 commit 6acf3c2

File tree

5 files changed

+169
-43
lines changed

5 files changed

+169
-43
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/complex/ComplexBuiltins.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -596,8 +596,9 @@ boolean doComplex(PComplex left, PComplex right) {
596596
}
597597

598598
@Specialization
599-
boolean doComplexInt(PComplex left, long right) {
600-
return left.getImag() == 0 && FloatBuiltins.EqNode.compareDoubleToLong(left.getReal(), right) == 0;
599+
boolean doComplexInt(PComplex left, long right,
600+
@Cached ConditionProfile longFitsToDoubleProfile) {
601+
return left.getImag() == 0 && FloatBuiltins.EqNode.compareDoubleToLong(left.getReal(), right, longFitsToDoubleProfile) == 0;
601602
}
602603

603604
@Specialization
@@ -691,8 +692,9 @@ boolean doComplex(PComplex left, PComplex right) {
691692
}
692693

693694
@Specialization
694-
boolean doComplex(PComplex left, long right) {
695-
return left.getImag() != 0 || FloatBuiltins.EqNode.compareDoubleToLong(left.getReal(), right) != 0;
695+
boolean doComplex(PComplex left, long right,
696+
@Cached ConditionProfile longFitsToDoubleProfile) {
697+
return left.getImag() != 0 || FloatBuiltins.EqNode.compareDoubleToLong(left.getReal(), right, longFitsToDoubleProfile) != 0;
696698
}
697699

698700
@Specialization

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/floats/FloatBuiltins.java

Lines changed: 99 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -991,13 +991,25 @@ boolean eqDbDb(double a, double b) {
991991
}
992992

993993
@Specialization
994-
boolean eqDbLn(double a, long b) {
995-
return compareDoubleToLong(a, b) == 0;
994+
boolean doDI(double x, int y) {
995+
return x == y;
996996
}
997997

998998
@Specialization
999-
boolean eqLnDb(long a, double b) {
1000-
return compareDoubleToLong(b, a) == 0;
999+
boolean doID(int x, double y) {
1000+
return x == y;
1001+
}
1002+
1003+
@Specialization
1004+
boolean eqDbLn(double a, long b,
1005+
@Shared("longFitsToDouble") @Cached ConditionProfile longFitsToDoubleProfile) {
1006+
return compareDoubleToLong(a, b, longFitsToDoubleProfile) == 0;
1007+
}
1008+
1009+
@Specialization
1010+
boolean eqLnDb(long a, double b,
1011+
@Shared("longFitsToDouble") @Cached ConditionProfile longFitsToDoubleProfile) {
1012+
return compareDoubleToLong(b, a, longFitsToDoubleProfile) == 0;
10011013
}
10021014

10031015
@Specialization
@@ -1030,17 +1042,9 @@ PNotImplemented eq(Object a, Object b) {
10301042
}
10311043

10321044
// adapted from CPython's float_richcompare in floatobject.c
1033-
@TruffleBoundary(allowInlining = true)
1034-
public static double compareDoubleToLong(double v, long w) {
1035-
if (!Double.isFinite(v)) {
1036-
return v;
1037-
}
1038-
int vsign = v == 0.0 ? 0 : v < 0.0 ? -1 : 1;
1039-
int wsign = Long.signum(w);
1040-
if (vsign != wsign) {
1041-
return vsign - wsign;
1042-
}
1043-
if (w > -0x1000000000000L && w < 0x1000000000000L) { // w is at most 48 bits
1045+
public static double compareDoubleToLong(double v, long w, ConditionProfile wFitsInDoubleProfile) {
1046+
if (wFitsInDoubleProfile.profile(w > -0x1000000000000L && w < 0x1000000000000L)) {
1047+
// w is at most 48 bits and thus fits into a double without any loss
10441048
return v - w;
10451049
} else {
10461050
return compareUsingBigDecimal(v, PInt.longToBigInteger(w));
@@ -1080,13 +1084,25 @@ boolean neDbDb(double a, double b) {
10801084
}
10811085

10821086
@Specialization
1083-
boolean neDbLn(double a, long b) {
1084-
return EqNode.compareDoubleToLong(a, b) != 0;
1087+
boolean doDI(double x, int y) {
1088+
return x != y;
10851089
}
10861090

10871091
@Specialization
1088-
boolean neLnDb(long a, double b) {
1089-
return EqNode.compareDoubleToLong(b, a) != 0;
1092+
boolean doID(int x, double y) {
1093+
return x != y;
1094+
}
1095+
1096+
@Specialization
1097+
boolean neDbLn(double a, long b,
1098+
@Shared("longFitsToDouble") @Cached ConditionProfile longFitsToDoubleProfile) {
1099+
return EqNode.compareDoubleToLong(a, b, longFitsToDoubleProfile) != 0;
1100+
}
1101+
1102+
@Specialization
1103+
boolean neLnDb(long a, double b,
1104+
@Shared("longFitsToDouble") @Cached ConditionProfile longFitsToDoubleProfile) {
1105+
return EqNode.compareDoubleToLong(b, a, longFitsToDoubleProfile) != 0;
10901106
}
10911107

10921108
@Specialization
@@ -1129,13 +1145,25 @@ boolean doDD(double x, double y) {
11291145
}
11301146

11311147
@Specialization
1132-
boolean doDL(double x, long y) {
1133-
return EqNode.compareDoubleToLong(x, y) < 0;
1148+
boolean doDI(double x, int y) {
1149+
return x < y;
11341150
}
11351151

11361152
@Specialization
1137-
boolean doLD(long x, double y) {
1138-
return EqNode.compareDoubleToLong(y, x) > 0;
1153+
boolean doID(int x, double y) {
1154+
return x < y;
1155+
}
1156+
1157+
@Specialization
1158+
boolean doDL(double x, long y,
1159+
@Shared("longFitsToDouble") @Cached ConditionProfile longFitsToDoubleProfile) {
1160+
return EqNode.compareDoubleToLong(x, y, longFitsToDoubleProfile) < 0;
1161+
}
1162+
1163+
@Specialization
1164+
boolean doLD(long x, double y,
1165+
@Shared("longFitsToDouble") @Cached ConditionProfile longFitsToDoubleProfile) {
1166+
return EqNode.compareDoubleToLong(y, x, longFitsToDoubleProfile) > 0;
11391167
}
11401168

11411169
@Specialization
@@ -1199,13 +1227,25 @@ boolean doDD(double x, double y) {
11991227
}
12001228

12011229
@Specialization
1202-
boolean doDL(double x, long y) {
1203-
return EqNode.compareDoubleToLong(x, y) <= 0;
1230+
boolean doDI(double x, int y) {
1231+
return x <= y;
12041232
}
12051233

12061234
@Specialization
1207-
boolean doLD(long x, double y) {
1208-
return EqNode.compareDoubleToLong(y, x) >= 0;
1235+
boolean doID(int x, double y) {
1236+
return x <= y;
1237+
}
1238+
1239+
@Specialization
1240+
boolean doDL(double x, long y,
1241+
@Shared("longFitsToDouble") @Cached ConditionProfile longFitsToDoubleProfile) {
1242+
return EqNode.compareDoubleToLong(x, y, longFitsToDoubleProfile) <= 0;
1243+
}
1244+
1245+
@Specialization
1246+
boolean doLD(long x, double y,
1247+
@Shared("longFitsToDouble") @Cached ConditionProfile longFitsToDoubleProfile) {
1248+
return EqNode.compareDoubleToLong(y, x, longFitsToDoubleProfile) >= 0;
12091249
}
12101250

12111251
@Specialization
@@ -1269,13 +1309,25 @@ boolean doDD(double x, double y) {
12691309
}
12701310

12711311
@Specialization
1272-
boolean doDL(double x, long y) {
1273-
return EqNode.compareDoubleToLong(x, y) > 0;
1312+
boolean doDI(double x, int y) {
1313+
return x > y;
12741314
}
12751315

12761316
@Specialization
1277-
boolean doLD(long x, double y) {
1278-
return EqNode.compareDoubleToLong(y, x) < 0;
1317+
boolean doID(int x, double y) {
1318+
return x > y;
1319+
}
1320+
1321+
@Specialization
1322+
boolean doDL(double x, long y,
1323+
@Shared("longFitsToDouble") @Cached ConditionProfile longFitsToDoubleProfile) {
1324+
return EqNode.compareDoubleToLong(x, y, longFitsToDoubleProfile) > 0;
1325+
}
1326+
1327+
@Specialization
1328+
boolean doLD(long x, double y,
1329+
@Shared("longFitsToDouble") @Cached ConditionProfile longFitsToDoubleProfile) {
1330+
return EqNode.compareDoubleToLong(y, x, longFitsToDoubleProfile) < 0;
12791331
}
12801332

12811333
@Specialization
@@ -1339,13 +1391,25 @@ boolean doDD(double x, double y) {
13391391
}
13401392

13411393
@Specialization
1342-
boolean doDL(double x, long y) {
1343-
return EqNode.compareDoubleToLong(x, y) >= 0;
1394+
boolean doDI(double x, int y) {
1395+
return x >= y;
1396+
}
1397+
1398+
@Specialization
1399+
boolean doID(int x, double y) {
1400+
return x >= y;
1401+
}
1402+
1403+
@Specialization
1404+
boolean doDL(double x, long y,
1405+
@Shared("longFitsToDouble") @Cached ConditionProfile longFitsToDoubleProfile) {
1406+
return EqNode.compareDoubleToLong(x, y, longFitsToDoubleProfile) >= 0;
13441407
}
13451408

13461409
@Specialization
1347-
boolean doLD(long x, double y) {
1348-
return EqNode.compareDoubleToLong(y, x) <= 0;
1410+
boolean doLD(long x, double y,
1411+
@Shared("longFitsToDouble") @Cached ConditionProfile longFitsToDoubleProfile) {
1412+
return EqNode.compareDoubleToLong(y, x, longFitsToDoubleProfile) <= 0;
13491413
}
13501414

13511415
@Specialization

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/call/special/LookupAndCallBinaryNode.java

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,12 @@
6464
// Order operations are tried until either a valid result or error: w.op(v,w)[*], v.op(v,w), w.op(v,w)
6565
//
6666
// [*] only when v->ob_type != w->ob_type && w->ob_type is a subclass of v->ob_type
67+
//
68+
// The (long, double) and (double, long) specializations are needed since long->double conversion
69+
// is not always correct (it can lose information). See FloatBuiltins.EqNode.compareDoubleToLong().
70+
// The (int, double) and (double, int) specializations are needed to avoid int->long conversion.
71+
// Although it would produce correct results, the special handling of long to double comparison
72+
// is slower than converting int->double, which is always correct.
6773
public abstract class LookupAndCallBinaryNode extends Node {
6874

6975
public abstract static class NotImplementedHandler extends PNodeWithContext {
@@ -92,6 +98,10 @@ public abstract static class NotImplementedHandler extends PNodeWithContext {
9298

9399
public abstract long executeLong(VirtualFrame frame, Object arg, Object arg2) throws UnexpectedResultException;
94100

101+
public abstract double executeDouble(VirtualFrame frame, int arg, double arg2) throws UnexpectedResultException;
102+
103+
public abstract double executeDouble(VirtualFrame frame, double arg, int arg2) throws UnexpectedResultException;
104+
95105
public abstract double executeDouble(VirtualFrame frame, long arg, double arg2) throws UnexpectedResultException;
96106

97107
public abstract double executeDouble(VirtualFrame frame, double arg, long arg2) throws UnexpectedResultException;
@@ -100,6 +110,10 @@ public abstract static class NotImplementedHandler extends PNodeWithContext {
100110

101111
public abstract boolean executeBool(VirtualFrame frame, int arg, int arg2) throws UnexpectedResultException;
102112

113+
public abstract boolean executeBool(VirtualFrame frame, int arg, double arg2) throws UnexpectedResultException;
114+
115+
public abstract boolean executeBool(VirtualFrame frame, double arg, int arg2) throws UnexpectedResultException;
116+
103117
public abstract boolean executeBool(VirtualFrame frame, long arg, long arg2) throws UnexpectedResultException;
104118

105119
public abstract boolean executeBool(VirtualFrame frame, long arg, double arg2) throws UnexpectedResultException;
@@ -263,6 +277,32 @@ boolean callBoolean(VirtualFrame frame, long left, long right,
263277
}
264278
}
265279

280+
// int, double
281+
282+
@Specialization(guards = "function != null", rewriteOn = UnexpectedResultException.class)
283+
boolean callBoolean(VirtualFrame frame, int left, double right,
284+
@Cached("getBuiltin(right)") PythonBinaryBuiltinNode function) throws UnexpectedResultException {
285+
return function.executeBool(frame, left, right);
286+
}
287+
288+
@Specialization(guards = "function != null", rewriteOn = UnexpectedResultException.class)
289+
boolean callBoolean(VirtualFrame frame, double left, int right,
290+
@Cached("getBuiltin(left)") PythonBinaryBuiltinNode function) throws UnexpectedResultException {
291+
return function.executeBool(frame, left, right);
292+
}
293+
294+
@Specialization(guards = "function != null", rewriteOn = UnexpectedResultException.class)
295+
double callDouble(VirtualFrame frame, int left, double right,
296+
@Cached("getBuiltin(right)") PythonBinaryBuiltinNode function) throws UnexpectedResultException {
297+
return function.executeDouble(frame, left, right);
298+
}
299+
300+
@Specialization(guards = "function != null", rewriteOn = UnexpectedResultException.class)
301+
double callDouble(VirtualFrame frame, double left, int right,
302+
@Cached("getBuiltin(left)") PythonBinaryBuiltinNode function) throws UnexpectedResultException {
303+
return function.executeDouble(frame, left, right);
304+
}
305+
266306
// long, double
267307

268308
@Specialization(guards = "function != null", rewriteOn = UnexpectedResultException.class)

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/expression/BinaryComparisonNode.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,10 @@ private static int asInt(boolean left) {
9090
return left ? 1 : 0;
9191
}
9292

93+
private static long asLong(boolean left) {
94+
return left ? 1L : 0L;
95+
}
96+
9397
private static double asDouble(boolean left) {
9498
return left ? 1.0 : 0.0;
9599
}
@@ -121,7 +125,7 @@ boolean doBI(VirtualFrame frame, boolean left, int right) {
121125
@Specialization
122126
boolean doBL(VirtualFrame frame, boolean left, long right) {
123127
try {
124-
return profileCondition(callNode.executeBool(frame, asInt(left), right));
128+
return profileCondition(callNode.executeBool(frame, asLong(left), right));
125129
} catch (UnexpectedResultException e) {
126130
CompilerDirectives.transferToInterpreterAndInvalidate();
127131
throw new IllegalStateException("Comparison on primitive values didn't return a boolean");
@@ -161,7 +165,7 @@ boolean doII(VirtualFrame frame, int left, int right) {
161165
@Specialization
162166
boolean doIL(VirtualFrame frame, int left, long right) {
163167
try {
164-
return profileCondition(callNode.executeBool(frame, left, right));
168+
return profileCondition(callNode.executeBool(frame, (long) left, right));
165169
} catch (UnexpectedResultException e) {
166170
CompilerDirectives.transferToInterpreterAndInvalidate();
167171
throw new IllegalStateException("Comparison on primitive values didn't return a boolean");
@@ -181,7 +185,7 @@ boolean doID(VirtualFrame frame, int left, double right) {
181185
@Specialization
182186
boolean doLB(VirtualFrame frame, long left, boolean right) {
183187
try {
184-
return profileCondition(callNode.executeBool(frame, left, asInt(right)));
188+
return profileCondition(callNode.executeBool(frame, left, asLong(right)));
185189
} catch (UnexpectedResultException e) {
186190
CompilerDirectives.transferToInterpreterAndInvalidate();
187191
throw new IllegalStateException("Comparison on primitive values didn't return a boolean");
@@ -191,7 +195,7 @@ boolean doLB(VirtualFrame frame, long left, boolean right) {
191195
@Specialization
192196
boolean doLI(VirtualFrame frame, long left, int right) {
193197
try {
194-
return profileCondition(callNode.executeBool(frame, left, right));
198+
return profileCondition(callNode.executeBool(frame, left, (long) right));
195199
} catch (UnexpectedResultException e) {
196200
CompilerDirectives.transferToInterpreterAndInvalidate();
197201
throw new IllegalStateException("Comparison on primitive values didn't return a boolean");

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/function/builtins/PythonBinaryBuiltinNode.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,14 @@ public int executeInt(VirtualFrame frame, int arg, int arg2) throws UnexpectedRe
6060
return PGuards.expectInteger(execute(frame, arg, arg2));
6161
}
6262

63+
public double executeDouble(VirtualFrame frame, int arg, double arg2) throws UnexpectedResultException {
64+
return PGuards.expectDouble(execute(frame, arg, arg2));
65+
}
66+
67+
public double executeDouble(VirtualFrame frame, double arg, int arg2) throws UnexpectedResultException {
68+
return PGuards.expectDouble(execute(frame, arg, arg2));
69+
}
70+
6371
public long executeLong(VirtualFrame frame, long arg, long arg2) throws UnexpectedResultException {
6472
return PGuards.expectLong(execute(frame, arg, arg2));
6573
}
@@ -80,6 +88,14 @@ public boolean executeBool(VirtualFrame frame, int arg, int arg2) throws Unexpec
8088
return PGuards.expectBoolean(execute(frame, arg, arg2));
8189
}
8290

91+
public boolean executeBool(VirtualFrame frame, int arg, double arg2) throws UnexpectedResultException {
92+
return PGuards.expectBoolean(execute(frame, arg, arg2));
93+
}
94+
95+
public boolean executeBool(VirtualFrame frame, double arg, int arg2) throws UnexpectedResultException {
96+
return PGuards.expectBoolean(execute(frame, arg, arg2));
97+
}
98+
8399
public boolean executeBool(VirtualFrame frame, long arg, long arg2) throws UnexpectedResultException {
84100
return PGuards.expectBoolean(execute(frame, arg, arg2));
85101
}

0 commit comments

Comments
 (0)