Skip to content

Commit 2a7931e

Browse files
committed
Use singleton for float NaN.
1 parent b90cc09 commit 2a7931e

File tree

6 files changed

+104
-56
lines changed

6 files changed

+104
-56
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_float.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,11 @@ def test_hex(self):
127127
f = float(input)
128128
self.assertEqual(toHex(f), expected);
129129

130+
def test_nan(self):
131+
self.assertNotEqual(NAN, NAN)
132+
self.assertNotEqual(float('nan'), float('nan'))
133+
self.assertTrue(NAN is NAN)
134+
130135

131136
fromHex = float.fromhex
132137

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/Python3Core.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@
110110
import com.oracle.graal.python.builtins.objects.exception.BaseExceptionBuiltins;
111111
import com.oracle.graal.python.builtins.objects.exception.PBaseException;
112112
import com.oracle.graal.python.builtins.objects.floats.FloatBuiltins;
113+
import com.oracle.graal.python.builtins.objects.floats.PFloat;
113114
import com.oracle.graal.python.builtins.objects.foreign.ForeignObjectBuiltins;
114115
import com.oracle.graal.python.builtins.objects.frame.FrameBuiltins;
115116
import com.oracle.graal.python.builtins.objects.function.AbstractFunctionBuiltins;
@@ -376,6 +377,7 @@ private static final PythonBuiltins[] initializeBuiltins() {
376377

377378
@CompilationFinal private PInt pyTrue;
378379
@CompilationFinal private PInt pyFalse;
380+
@CompilationFinal private PFloat pyNaN;
379381

380382
private final PythonParser parser;
381383

@@ -528,8 +530,9 @@ private void initializeTypes() {
528530
}
529531
}
530532
// now initialize well-known objects
531-
pyTrue = new PInt(lookupType(PythonBuiltinClassType.Boolean), BigInteger.ONE);
532-
pyFalse = new PInt(lookupType(PythonBuiltinClassType.Boolean), BigInteger.ZERO);
533+
pyTrue = new PInt(PythonBuiltinClassType.Boolean, BigInteger.ONE);
534+
pyFalse = new PInt(PythonBuiltinClassType.Boolean, BigInteger.ZERO);
535+
pyNaN = new PFloat(PythonBuiltinClassType.PFloat, Double.NaN);
533536
}
534537

535538
private void populateBuiltins() {
@@ -645,6 +648,10 @@ public PInt getFalse() {
645648
return pyFalse;
646649
}
647650

651+
public PFloat getNaN() {
652+
return pyNaN;
653+
}
654+
648655
public RuntimeException raiseInvalidSyntax(Source source, SourceSection section, String message, Object... arguments) {
649656
Node location = new Node() {
650657
@Override

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/BuiltinConstructors.java

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -679,6 +679,7 @@ public abstract static class FloatNode extends PythonBuiltinNode {
679679
@Child private BytesNodes.ToBytesNode toByteArrayNode;
680680

681681
private final IsBuiltinClassProfile isPrimitiveProfile = IsBuiltinClassProfile.create();
682+
@CompilationFinal private ConditionProfile isNanProfile;
682683

683684
public abstract Object executeWith(VirtualFrame frame, Object cls, Object arg);
684685

@@ -687,66 +688,65 @@ protected final boolean isPrimitiveFloat(LazyPythonClass cls) {
687688
}
688689

689690
@Specialization(guards = "!isNativeClass(cls)")
690-
public Object floatFromInt(LazyPythonClass cls, int arg) {
691+
Object floatFromInt(LazyPythonClass cls, int arg) {
691692
if (isPrimitiveFloat(cls)) {
692693
return (double) arg;
693694
}
694695
return factory().createFloat(cls, arg);
695696
}
696697

697698
@Specialization(guards = "!isNativeClass(cls)")
698-
public Object floatFromBoolean(LazyPythonClass cls, boolean arg) {
699+
Object floatFromBoolean(LazyPythonClass cls, boolean arg) {
699700
if (isPrimitiveFloat(cls)) {
700701
return arg ? 1d : 0d;
701702
}
702703
return factory().createFloat(cls, arg ? 1d : 0d);
703704
}
704705

705706
@Specialization(guards = "!isNativeClass(cls)")
706-
public Object floatFromLong(LazyPythonClass cls, long arg) {
707+
Object floatFromLong(LazyPythonClass cls, long arg) {
707708
if (isPrimitiveFloat(cls)) {
708709
return (double) arg;
709710
}
710711
return factory().createFloat(cls, arg);
711712
}
712713

713714
@Specialization(guards = "!isNativeClass(cls)")
714-
public Object floatFromPInt(LazyPythonClass cls, PInt arg) {
715+
Object floatFromPInt(LazyPythonClass cls, PInt arg) {
715716
if (isPrimitiveFloat(cls)) {
716717
return arg.doubleValue();
717718
}
718719
return factory().createFloat(cls, arg.doubleValue());
719720
}
720721

721722
@Specialization(guards = "!isNativeClass(cls)")
722-
public Object floatFromFloat(LazyPythonClass cls, double arg) {
723+
Object floatFromDouble(LazyPythonClass cls, double arg) {
723724
if (isPrimitiveFloat(cls)) {
724725
return arg;
725726
}
726-
return factory().createFloat(cls, arg);
727+
return factoryCreateFloat(cls, arg);
727728
}
728729

729730
@Specialization(guards = "!isNativeClass(cls)")
730-
public Object floatFromString(LazyPythonClass cls, String arg) {
731+
Object floatFromString(LazyPythonClass cls, String arg) {
731732
double value = convertStringToDouble(arg);
732733
if (isPrimitiveFloat(cls)) {
733734
return value;
734735
}
735-
return factory().createFloat(cls, value);
736+
return factoryCreateFloat(cls, value);
736737
}
737738

738739
@Specialization(guards = "!isNativeClass(cls)")
739-
public Object floatFromBytes(VirtualFrame frame, LazyPythonClass cls, PIBytesLike arg) {
740+
Object floatFromBytes(VirtualFrame frame, LazyPythonClass cls, PIBytesLike arg) {
740741
double value = convertBytesToDouble(frame, arg);
741742
if (isPrimitiveFloat(cls)) {
742743
return value;
743744
}
744-
return factory().createFloat(cls, value);
745+
return factoryCreateFloat(cls, value);
745746
}
746747

747748
private double convertBytesToDouble(VirtualFrame frame, PIBytesLike arg) {
748-
double value = convertStringToDouble(createString(getByteArray(frame, arg)));
749-
return value;
749+
return convertStringToDouble(createString(getByteArray(frame, arg)));
750750
}
751751

752752
@TruffleBoundary
@@ -796,7 +796,7 @@ private double convertStringToDouble(String str) {
796796
}
797797

798798
@Specialization(guards = "!isNativeClass(cls)")
799-
public Object floatFromNone(LazyPythonClass cls, @SuppressWarnings("unused") PNone arg) {
799+
Object floatFromNone(LazyPythonClass cls, @SuppressWarnings("unused") PNone arg) {
800800
if (isPrimitiveFloat(cls)) {
801801
return 0.0;
802802
}
@@ -836,7 +836,7 @@ public Object floatFromNone(LazyPythonClass cls, @SuppressWarnings("unused") PNo
836836
Object doPythonObject(VirtualFrame frame, LazyPythonClass cls, Object obj,
837837
@Cached("create(__FLOAT__)") LookupAndCallUnaryNode callFloatNode,
838838
@Cached("create()") BranchProfile gotException) {
839-
return floatFromFloat(cls, doubleFromObject(frame, cls, obj, callFloatNode, gotException));
839+
return floatFromDouble(cls, doubleFromObject(frame, cls, obj, callFloatNode, gotException));
840840
}
841841

842842
// logic similar to float_subtype_new(PyTypeObject *type, PyObject *x) from CPython
@@ -854,7 +854,7 @@ Object doPythonObject(VirtualFrame frame, PythonNativeClass cls, Object obj,
854854

855855
@Fallback
856856
@TruffleBoundary
857-
public Object floatFromObject(@SuppressWarnings("unused") Object cls, Object arg) {
857+
Object floatFromObject(@SuppressWarnings("unused") Object cls, Object arg) {
858858
throw raise(TypeError, "can't convert %s to float", arg.getClass().getSimpleName());
859859
}
860860

@@ -869,6 +869,21 @@ private byte[] getByteArray(VirtualFrame frame, PIBytesLike pByteArray) {
869869
}
870870
return toByteArrayNode.execute(frame, pByteArray);
871871
}
872+
873+
private PFloat factoryCreateFloat(LazyPythonClass cls, double arg) {
874+
if (isNaN(arg)) {
875+
return getCore().getNaN();
876+
}
877+
return factory().createFloat(cls, arg);
878+
}
879+
880+
private boolean isNaN(double d) {
881+
if (isNanProfile == null) {
882+
CompilerDirectives.transferToInterpreterAndInvalidate();
883+
isNanProfile = ConditionProfile.createBinaryProfile();
884+
}
885+
return isNanProfile.profile(Double.isNaN(d));
886+
}
872887
}
873888

874889
// frozenset([iterable])

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/cext/CExtNodes.java

Lines changed: 54 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@
7070
import com.oracle.graal.python.builtins.objects.cext.CExtNodesFactory.PointerCompareNodeGen;
7171
import com.oracle.graal.python.builtins.objects.cext.CExtNodesFactory.ToJavaNodeFactory.ToJavaCachedNodeGen;
7272
import com.oracle.graal.python.builtins.objects.cext.DynamicObjectNativeWrapper.PrimitiveNativeWrapper;
73-
import com.oracle.graal.python.builtins.objects.cext.DynamicObjectNativeWrapper.PythonObjectNativeWrapper;
7473
import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes;
7574
import com.oracle.graal.python.builtins.objects.floats.PFloat;
7675
import com.oracle.graal.python.builtins.objects.function.PFunction;
@@ -355,11 +354,28 @@ Object doLong(long l) {
355354
return DynamicObjectNativeWrapper.PrimitiveNativeWrapper.createLong(l);
356355
}
357356

358-
@Specialization
357+
@Specialization(guards = "!isNaN(d)")
359358
Object doDouble(double d) {
360359
return DynamicObjectNativeWrapper.PrimitiveNativeWrapper.createDouble(d);
361360
}
362361

362+
@Specialization(guards = "isNaN(d)")
363+
Object doDouble(@SuppressWarnings("unused") double d,
364+
@CachedContext(PythonLanguage.class) PythonContext context,
365+
@Cached("createCountingProfile()") ConditionProfile noWrapperProfile) {
366+
PFloat boxed = context.getCore().getNaN();
367+
DynamicObjectNativeWrapper nativeWrapper = boxed.getNativeWrapper();
368+
// Use a counting profile since we should enter the branch just once per context.
369+
if (noWrapperProfile.profile(nativeWrapper == null)) {
370+
// This deliberately uses 'CompilerDirectives.transferToInterpreter()' because this
371+
// code will happen just once per context.
372+
CompilerDirectives.transferToInterpreter();
373+
nativeWrapper = DynamicObjectNativeWrapper.PrimitiveNativeWrapper.createDouble(Double.NaN);
374+
boxed.setNativeWrapper(nativeWrapper);
375+
}
376+
return nativeWrapper;
377+
}
378+
363379
@Specialization
364380
Object doNativeClass(PythonAbstractNativeObject nativeClass) {
365381
return nativeClass.getPtr();
@@ -413,43 +429,17 @@ protected static PythonClassNativeWrapper wrapNativeClass(PythonManagedClass obj
413429
return PythonClassNativeWrapper.wrap(object, GetNameNode.doSlowPath(object));
414430
}
415431

416-
// TODO(fa): Workaround for DSL bug: did not import factory at users
432+
protected static boolean isNaN(double d) {
433+
return Double.isNaN(d);
434+
}
435+
417436
public static ToSulongNode create() {
418437
return CExtNodesFactory.ToSulongNodeGen.create();
419438
}
420439

421-
// TODO(fa): Workaround for DSL bug: did not import factory at users
422440
public static ToSulongNode getUncached() {
423441
return CExtNodesFactory.ToSulongNodeGen.getUncached();
424442
}
425-
426-
@TruffleBoundary
427-
public static Object doSlowPath(Object o) {
428-
if (o instanceof String) {
429-
return PythonObjectNativeWrapper.wrapSlowPath(PythonLanguage.getCore().factory().createString((String) o));
430-
} else if (o instanceof Integer) {
431-
return PrimitiveNativeWrapper.createInt((Integer) o);
432-
} else if (o instanceof Long) {
433-
return PrimitiveNativeWrapper.createLong((Long) o);
434-
} else if (o instanceof Double) {
435-
return PrimitiveNativeWrapper.createDouble((Double) o);
436-
} else if (PythonNativeClass.isInstance(o)) {
437-
return ((PythonNativeClass) o).getPtr();
438-
} else if (PythonNativeObject.isInstance(o)) {
439-
return PythonNativeObject.cast(o).getPtr();
440-
} else if (o instanceof PythonNativeNull) {
441-
return ((PythonNativeNull) o).getPtr();
442-
} else if (o instanceof PythonManagedClass) {
443-
return wrapNativeClass((PythonManagedClass) o);
444-
} else if (o instanceof PythonAbstractObject) {
445-
assert !PGuards.isClass(o);
446-
return PythonObjectNativeWrapper.wrapSlowPath((PythonAbstractObject) o);
447-
} else if (PGuards.isForeignObject(o)) {
448-
return TruffleObjectNativeWrapper.wrap((TruffleObject) o);
449-
}
450-
assert o != null : "Java 'null' cannot be a Sulong value";
451-
return o;
452-
}
453443
}
454444

455445
// -----------------------------------------------------------------------------------------------------------------
@@ -610,9 +600,13 @@ public abstract static class MaterializeDelegateNode extends CExtBaseNode {
610600
@Specialization(guards = {"!isMaterialized(object)", "object.isBool()"})
611601
PInt doBoolNativeWrapper(DynamicObjectNativeWrapper.PrimitiveNativeWrapper object,
612602
@CachedContext(PythonLanguage.class) PythonContext context) {
603+
// Special case for True and False: use singletons
613604
PythonCore core = context.getCore();
614605
PInt materializedInt = object.getBool() ? core.getTrue() : core.getFalse();
615606
object.setMaterializedObject(materializedInt);
607+
608+
// If the singleton already has a native wrapper, we may need to update the pointer
609+
// of wrapper 'object' since the native could code see the same pointer.
616610
if (materializedInt.getNativeWrapper() != null) {
617611
object.setNativePointer(materializedInt.getNativeWrapper().getNativePointer());
618612
} else {
@@ -623,7 +617,7 @@ PInt doBoolNativeWrapper(DynamicObjectNativeWrapper.PrimitiveNativeWrapper objec
623617

624618
@Specialization(guards = {"!isMaterialized(object)", "object.isByte()"})
625619
PInt doByteNativeWrapper(DynamicObjectNativeWrapper.PrimitiveNativeWrapper object,
626-
@Cached PythonObjectFactory factory) {
620+
@Shared("factory") @Cached PythonObjectFactory factory) {
627621
PInt materializedInt = factory.createInt(object.getByte());
628622
object.setMaterializedObject(materializedInt);
629623
materializedInt.setNativeWrapper(object);
@@ -632,7 +626,7 @@ PInt doByteNativeWrapper(DynamicObjectNativeWrapper.PrimitiveNativeWrapper objec
632626

633627
@Specialization(guards = {"!isMaterialized(object)", "object.isInt()"})
634628
PInt doIntNativeWrapper(DynamicObjectNativeWrapper.PrimitiveNativeWrapper object,
635-
@Cached PythonObjectFactory factory) {
629+
@Shared("factory") @Cached PythonObjectFactory factory) {
636630
PInt materializedInt = factory.createInt(object.getInt());
637631
object.setMaterializedObject(materializedInt);
638632
materializedInt.setNativeWrapper(object);
@@ -641,22 +635,39 @@ PInt doIntNativeWrapper(DynamicObjectNativeWrapper.PrimitiveNativeWrapper object
641635

642636
@Specialization(guards = {"!isMaterialized(object)", "object.isLong()"})
643637
PInt doLongNativeWrapper(DynamicObjectNativeWrapper.PrimitiveNativeWrapper object,
644-
@Cached PythonObjectFactory factory) {
638+
@Shared("factory") @Cached PythonObjectFactory factory) {
645639
PInt materializedInt = factory.createInt(object.getLong());
646640
object.setMaterializedObject(materializedInt);
647641
materializedInt.setNativeWrapper(object);
648642
return materializedInt;
649643
}
650644

651-
@Specialization(guards = {"!isMaterialized(object)", "object.isDouble()"})
645+
@Specialization(guards = {"!isMaterialized(object)", "object.isDouble()", "!isNaN(object)"})
652646
PFloat doDoubleNativeWrapper(DynamicObjectNativeWrapper.PrimitiveNativeWrapper object,
653-
@Cached PythonObjectFactory factory) {
647+
@Shared("factory") @Cached PythonObjectFactory factory) {
654648
PFloat materializedInt = factory.createFloat(object.getDouble());
655-
object.setMaterializedObject(materializedInt);
656649
materializedInt.setNativeWrapper(object);
650+
object.setMaterializedObject(materializedInt);
657651
return materializedInt;
658652
}
659653

654+
@Specialization(guards = {"!isMaterialized(object)", "object.isDouble()", "isNaN(object)"})
655+
PFloat doDoubleNativeWrapperNaN(DynamicObjectNativeWrapper.PrimitiveNativeWrapper object,
656+
@CachedContext(PythonLanguage.class) PythonContext context) {
657+
// Special case for double NaN: use singleton
658+
PFloat materializedFloat = context.getCore().getNaN();
659+
object.setMaterializedObject(materializedFloat);
660+
661+
// If the NaN singleton already has a native wrapper, we may need to update the pointer
662+
// of wrapper 'object' since the native code should see the same pointer.
663+
if (materializedFloat.getNativeWrapper() != null) {
664+
object.setNativePointer(materializedFloat.getNativeWrapper().getNativePointer());
665+
} else {
666+
materializedFloat.setNativeWrapper(object);
667+
}
668+
return materializedFloat;
669+
}
670+
660671
@Specialization(guards = {"object.getClass() == cachedClass", "isMaterialized(object)"})
661672
Object doMaterialized(DynamicObjectNativeWrapper.PrimitiveNativeWrapper object,
662673
@SuppressWarnings("unused") @Cached("object.getClass()") Class<? extends DynamicObjectNativeWrapper.PrimitiveNativeWrapper> cachedClass) {
@@ -677,6 +688,11 @@ Object doNativeWrapperGeneric(PythonNativeWrapper object) {
677688
protected static boolean isPrimitiveNativeWrapper(PythonNativeWrapper object) {
678689
return object instanceof DynamicObjectNativeWrapper.PrimitiveNativeWrapper;
679690
}
691+
692+
protected static boolean isNaN(PrimitiveNativeWrapper object) {
693+
assert object.isDouble();
694+
return Double.isNaN(object.getDouble());
695+
}
680696
}
681697

682698
// -----------------------------------------------------------------------------------------------------------------

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/expression/IsExpressionNode.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,9 @@ boolean doDL(double left, long right) {
174174

175175
@Specialization
176176
boolean doDD(double left, double right) {
177-
return left == right;
177+
// n.b. we simulate that the primitive NaN is a singleton; this is required to make
178+
// 'nan = float("nan"); nan is nan' work
179+
return left == right || (Double.isNaN(left) && Double.isNaN(right));
178180
}
179181

180182
@Specialization

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/runtime/PythonCore.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
import com.oracle.graal.python.PythonLanguage;
2929
import com.oracle.graal.python.builtins.PythonBuiltinClassType;
30+
import com.oracle.graal.python.builtins.objects.floats.PFloat;
3031
import com.oracle.graal.python.builtins.objects.ints.PInt;
3132
import com.oracle.graal.python.builtins.objects.module.PythonModule;
3233
import com.oracle.graal.python.builtins.objects.type.PythonBuiltinClass;
@@ -81,6 +82,8 @@ public interface PythonCore extends ParserErrorCallback {
8182

8283
public PInt getFalse();
8384

85+
public PFloat getNaN();
86+
8487
public PythonModule getBuiltins();
8588

8689
static void writeInfo(String message) {

0 commit comments

Comments
 (0)