Skip to content

Commit b1c0f8e

Browse files
committed
Fix: float-to-int for special float values
1 parent bcdb029 commit b1c0f8e

File tree

4 files changed

+69
-48
lines changed

4 files changed

+69
-48
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_int.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,27 @@ def __index__(self):
349349
assert False, "expected TypeError"
350350

351351

352+
def test_create_int_from_float():
353+
assert int(123.0) == 123
354+
assert int(123.4) == 123
355+
try:
356+
int(float('nan'))
357+
except ValueError:
358+
assert True
359+
else:
360+
assert False, "expected ValueError"
361+
362+
class FloatSub(float):
363+
pass
364+
365+
try:
366+
int(FloatSub(float('nan')))
367+
except ValueError:
368+
assert True
369+
else:
370+
assert False, "expected ValueError"
371+
372+
352373
class FromBytesTests(unittest.TestCase):
353374

354375
def check(self, tests, byteorder, signed=False):

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/BuiltinConstructors.java

Lines changed: 24 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@
102102
import com.oracle.graal.python.builtins.objects.complex.PComplex;
103103
import com.oracle.graal.python.builtins.objects.dict.PDict;
104104
import com.oracle.graal.python.builtins.objects.enumerate.PEnumerate;
105+
import com.oracle.graal.python.builtins.objects.floats.FloatBuiltins;
106+
import com.oracle.graal.python.builtins.objects.floats.FloatBuiltinsFactory;
105107
import com.oracle.graal.python.builtins.objects.floats.PFloat;
106108
import com.oracle.graal.python.builtins.objects.frame.PFrame;
107109
import com.oracle.graal.python.builtins.objects.function.PBuiltinFunction;
@@ -928,7 +930,7 @@ public abstract static class IntNode extends PythonBuiltinNode {
928930
public abstract Object executeWith(VirtualFrame frame, Object cls, Object arg, Object keywordArg);
929931

930932
@TruffleBoundary(transferToInterpreterOnException = false)
931-
private Object stringToIntInternal(String num, int base) {
933+
private Object stringToIntInternal(String num, int base) throws NumberFormatException {
932934
String s = num.replace("_", "");
933935
if ((base >= 2 && base <= 32) || base == 0) {
934936
BigInteger bi = asciiToBigInteger(s, base, false);
@@ -942,7 +944,7 @@ private Object stringToIntInternal(String num, int base) {
942944
}
943945
}
944946

945-
private Object convertToIntInternal(LazyPythonClass cls, Object value, Object number, int base) {
947+
private Object convertToIntInternal(LazyPythonClass cls, Object value, Object number, int base) throws NumberFormatException {
946948
if (value == null) {
947949
throw raise(ValueError, "invalid literal for int() with base %s: %s", base, number);
948950
} else if (value instanceof BigInteger) {
@@ -959,44 +961,22 @@ private Object stringToInt(LazyPythonClass cls, String number, int base) {
959961
return convertToIntInternal(cls, value, number, base);
960962
}
961963

962-
private static Object doubleToIntInternal(double num) {
963-
if (num > Integer.MAX_VALUE || num < Integer.MIN_VALUE) {
964-
return BigInteger.valueOf((long) num);
965-
} else {
966-
return (int) num;
967-
}
968-
}
969-
970-
@TruffleBoundary(transferToInterpreterOnException = false)
971-
private Object toIntInternal(Object number) {
972-
if (number instanceof Integer || number instanceof BigInteger) {
973-
return number;
974-
} else if (number instanceof Boolean) {
975-
return (boolean) number ? 1 : 0;
976-
} else if (number instanceof Double) {
977-
return doubleToIntInternal((Double) number);
978-
} else if (number instanceof String) {
979-
return stringToIntInternal((String) number, 10);
980-
}
981-
return null;
982-
}
983-
984-
private Object toIntInternal(Object number, Object base) {
964+
private Object toIntInternal(Object number, Object base) throws NumberFormatException {
985965
if (number instanceof String && base instanceof Integer) {
986966
return stringToIntInternal((String) number, (Integer) base);
987967
} else {
988968
throw raise(ValueError, "invalid base or val for int()");
989969
}
990970
}
991971

992-
private Object toInt(LazyPythonClass cls, Object number, int base) {
972+
private Object toInt(LazyPythonClass cls, Object number, int base) throws NumberFormatException {
993973
Object value = toIntInternal(number, base);
994974
return convertToIntInternal(cls, value, number, base);
995975
}
996976

997977
// Copied directly from Jython
998978
@TruffleBoundary(transferToInterpreterOnException = false)
999-
private static BigInteger asciiToBigInteger(String str, int possibleBase, boolean isLong) {
979+
private static BigInteger asciiToBigInteger(String str, int possibleBase, boolean isLong) throws NumberFormatException {
1000980
int base = possibleBase;
1001981
int b = 0;
1002982
int e = str.length();
@@ -1090,15 +1070,15 @@ Object parseInt(LazyPythonClass cls, boolean arg, @SuppressWarnings("unused") PN
10901070
}
10911071

10921072
@Specialization(guards = "isNoValue(keywordArg)")
1093-
public Object createInt(LazyPythonClass cls, int arg, @SuppressWarnings("unused") PNone keywordArg) {
1073+
Object createInt(LazyPythonClass cls, int arg, @SuppressWarnings("unused") PNone keywordArg) {
10941074
if (isPrimitiveInt(cls)) {
10951075
return arg;
10961076
}
10971077
return factory().createInt(cls, arg);
10981078
}
10991079

11001080
@Specialization(guards = "isNoValue(keywordArg)")
1101-
public Object createInt(LazyPythonClass cls, long arg, @SuppressWarnings("unused") PNone keywordArg,
1081+
Object createInt(LazyPythonClass cls, long arg, @SuppressWarnings("unused") PNone keywordArg,
11021082
@Cached("createBinaryProfile()") ConditionProfile isIntProfile) {
11031083
if (isPrimitiveInt(cls)) {
11041084
int intValue = (int) arg;
@@ -1112,7 +1092,7 @@ public Object createInt(LazyPythonClass cls, long arg, @SuppressWarnings("unused
11121092
}
11131093

11141094
@Specialization(guards = "isNoValue(keywordArg)")
1115-
public Object createInt(LazyPythonClass cls, PythonNativeVoidPtr arg, @SuppressWarnings("unused") PNone keywordArg) {
1095+
Object createInt(LazyPythonClass cls, PythonNativeVoidPtr arg, @SuppressWarnings("unused") PNone keywordArg) {
11161096
if (isPrimitiveInt(cls)) {
11171097
return arg;
11181098
} else {
@@ -1122,24 +1102,23 @@ public Object createInt(LazyPythonClass cls, PythonNativeVoidPtr arg, @SuppressW
11221102
}
11231103

11241104
@Specialization(guards = "isNoValue(keywordArg)")
1125-
public Object createInt(LazyPythonClass cls, double arg, @SuppressWarnings("unused") PNone keywordArg,
1126-
@Cached("createBinaryProfile()") ConditionProfile isIntProfile) {
1127-
if (isPrimitiveInt(cls) && isIntProfile.profile(arg >= Integer.MIN_VALUE && arg <= Integer.MAX_VALUE)) {
1128-
return (int) arg;
1129-
}
1130-
return factory().createInt(cls, (long) arg);
1105+
Object createInt(VirtualFrame frame, LazyPythonClass cls, double arg, @SuppressWarnings("unused") PNone keywordArg,
1106+
@Cached("createFloatInt()") FloatBuiltins.IntNode intNode,
1107+
@Cached("createGeneric()") CreateIntFromObjectNode createIntFromObjectNode) {
1108+
Object result = intNode.executeWithDouble(arg);
1109+
return createIntFromObjectNode.execute(frame, cls, result);
11311110
}
11321111

11331112
@Specialization
1134-
public Object createInt(LazyPythonClass cls, @SuppressWarnings("unused") PNone none, @SuppressWarnings("unused") PNone keywordArg) {
1113+
Object createInt(LazyPythonClass cls, @SuppressWarnings("unused") PNone none, @SuppressWarnings("unused") PNone keywordArg) {
11351114
if (isPrimitiveInt(cls)) {
11361115
return 0;
11371116
}
11381117
return factory().createInt(cls, 0);
11391118
}
11401119

11411120
@Specialization(guards = "isNoValue(keywordArg)")
1142-
public Object createInt(LazyPythonClass cls, String arg, @SuppressWarnings("unused") PNone keywordArg) {
1121+
Object createInt(LazyPythonClass cls, String arg, @SuppressWarnings("unused") PNone keywordArg) {
11431122
try {
11441123
return stringToInt(cls, arg, 10);
11451124
} catch (NumberFormatException e) {
@@ -1222,13 +1201,13 @@ Object parsePIntError(LazyPythonClass cls, String number, int base) {
12221201
}
12231202

12241203
@Specialization(guards = "!isNoValue(base)", rewriteOn = NumberFormatException.class)
1225-
public Object parsePIntWithBaseObject(LazyPythonClass cls, String number, Object base,
1204+
Object parsePIntWithBaseObject(LazyPythonClass cls, String number, Object base,
12261205
@Cached CastToIndexNode castToIndexNode) {
12271206
return toInt(cls, number, castToIndexNode.execute(base));
12281207
}
12291208

12301209
@Specialization(guards = "!isNoValue(base)", replaces = "parsePIntWithBaseObject")
1231-
public Object createIntError(LazyPythonClass cls, String number, Object base,
1210+
Object createIntError(LazyPythonClass cls, String number, Object base,
12321211
@Cached CastToIndexNode castToIndexNode) {
12331212
try {
12341213
return toInt(cls, number, castToIndexNode.execute(base));
@@ -1244,7 +1223,7 @@ Object fail(LazyPythonClass cls, Object arg, Object keywordArg) {
12441223
}
12451224

12461225
@Specialization(guards = {"isNoValue(keywordArg)", "!isNoValue(obj)", "!isHandledType(obj)"})
1247-
public Object createInt(VirtualFrame frame, LazyPythonClass cls, Object obj, @SuppressWarnings("unused") PNone keywordArg,
1226+
Object createInt(VirtualFrame frame, LazyPythonClass cls, Object obj, @SuppressWarnings("unused") PNone keywordArg,
12481227
@Cached("createGeneric()") CreateIntFromObjectNode createIntFromObjectNode) {
12491228
return createIntFromObjectNode.execute(frame, cls, obj);
12501229
}
@@ -1257,6 +1236,10 @@ protected static CreateIntFromObjectNode createGeneric() {
12571236
return CreateIntFromObjectNode.create(true, () -> LookupAndCallUnaryNode.create(SpecialMethodNames.__INT__));
12581237
}
12591238

1239+
protected static FloatBuiltins.IntNode createFloatInt() {
1240+
return FloatBuiltinsFactory.IntNodeFactory.create();
1241+
}
1242+
12601243
private String toString(VirtualFrame frame, PIBytesLike pByteArray) {
12611244
if (toByteArrayNode == null) {
12621245
CompilerDirectives.transferToInterpreterAndInvalidate();

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/floats/FloatBuiltins.java

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
*/
2626
package com.oracle.graal.python.builtins.objects.floats;
2727

28+
import static com.oracle.graal.python.builtins.PythonBuiltinClassType.ValueError;
2829
import static com.oracle.graal.python.nodes.SpecialMethodNames.__ABS__;
2930
import static com.oracle.graal.python.nodes.SpecialMethodNames.__ADD__;
3031
import static com.oracle.graal.python.nodes.SpecialMethodNames.__BOOL__;
@@ -58,6 +59,7 @@
5859
import static com.oracle.graal.python.nodes.SpecialMethodNames.__TRUNC__;
5960

6061
import java.math.BigDecimal;
62+
import java.math.BigInteger;
6163
import java.math.RoundingMode;
6264
import java.nio.ByteOrder;
6365
import java.util.List;
@@ -233,7 +235,9 @@ boolean bool(double self) {
233235
@GenerateNodeFactory
234236
@ImportStatic(MathGuards.class)
235237
@TypeSystemReference(PythonArithmeticTypes.class)
236-
abstract static class IntNode extends PythonUnaryBuiltinNode {
238+
public abstract static class IntNode extends PythonUnaryBuiltinNode {
239+
240+
public abstract Object executeWithDouble(double self);
237241

238242
@Specialization(guards = "fitInt(self)")
239243
int doIntRange(double self) {
@@ -245,10 +249,23 @@ long doLongRange(double self) {
245249
return (long) self;
246250
}
247251

248-
@Specialization(guards = "!fitLong(self)")
249-
@TruffleBoundary
250-
PInt doGeneric(double self) {
251-
return factory().createInt(BigDecimal.valueOf(self).toBigInteger());
252+
@Specialization(guards = "!fitLong(self)", rewriteOn = NumberFormatException.class)
253+
PInt doDoubleGeneric(double self) {
254+
return factory().createInt(fromDouble(self));
255+
}
256+
257+
@Specialization(guards = "!fitLong(self)", replaces = "doDoubleGeneric")
258+
PInt doDoubleGenericError(double self) {
259+
try {
260+
return factory().createInt(fromDouble(self));
261+
} catch (NumberFormatException e) {
262+
throw raise(ValueError, "cannot convert float %f to integer", self);
263+
}
264+
}
265+
266+
@TruffleBoundary(transferToInterpreterOnException = false)
267+
private static BigInteger fromDouble(double self) {
268+
return BigDecimal.valueOf(self).toBigInteger();
252269
}
253270
}
254271

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/ints/IntBuiltins.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2361,7 +2361,7 @@ abstract static class TruncNode extends IntNode {
23612361
@Builtin(name = SpecialMethodNames.__INT__, minNumOfPositionalArgs = 1)
23622362
@GenerateNodeFactory
23632363
@TypeSystemReference(PythonArithmeticTypes.class)
2364-
abstract static class IntNode extends PythonBuiltinNode {
2364+
abstract static class IntNode extends PythonUnaryBuiltinNode {
23652365
@Child private GetLazyClassNode getClassNode;
23662366

23672367
protected LazyPythonClass getClass(Object value) {
@@ -2411,7 +2411,7 @@ abstract static class IndexNode extends IntNode {
24112411
@Builtin(name = SpecialMethodNames.__FLOAT__, minNumOfPositionalArgs = 1)
24122412
@GenerateNodeFactory
24132413
@TypeSystemReference(PythonArithmeticTypes.class)
2414-
abstract static class FloatNode extends PythonBuiltinNode {
2414+
abstract static class FloatNode extends PythonUnaryBuiltinNode {
24152415
@Specialization
24162416
double doBoolean(boolean self) {
24172417
return self ? 1.0 : 0.0;

0 commit comments

Comments
 (0)