Skip to content

Commit a3d56c1

Browse files
committed
[GR-44691] Improve support of (native) float subclasses.
PullRequest: graalpython/2745
2 parents 46ce27e + 18075c0 commit a3d56c1

File tree

9 files changed

+139
-54
lines changed

9 files changed

+139
-54
lines changed

graalpython/com.oracle.graal.python.test/src/tests/cpyext/test_object.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -473,12 +473,41 @@ def ignore_test_float_subclass(self):
473473
nb_add="fp_add",
474474
tp_new="fp_tpnew",
475475
post_ready_code="testFloatSubclassPtr = &TestFloatSubclassType; Py_INCREF(testFloatSubclassPtr);"
476-
)
476+
)
477477
tester = TestFloatSubclass(41.0)
478478
res = tester + 1
479479
assert res == 42.0, "expected 42.0 but was %s" % res
480480
assert hash(tester) != 0
481481

482+
def test_float_subclass2(self):
483+
NativeFloatSubclass = CPyExtType(
484+
"NativeFloatSubclass",
485+
"""
486+
static PyObject* fp_tp_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {
487+
PyObject *result = PyFloat_Type.tp_new(type, args, kwds);
488+
NativeFloatSubclassObject *nfs = (NativeFloatSubclassObject *)result;
489+
nfs->myobval = PyFloat_AsDouble(result);
490+
return result;
491+
}
492+
493+
static PyObject* fp_tp_repr(PyObject* self) {
494+
NativeFloatSubclassObject *nfs = (NativeFloatSubclassObject *)self;
495+
return PyUnicode_FromFormat("native %S", PyFloat_FromDouble(nfs->myobval));
496+
}
497+
""",
498+
struct_base="PyFloatObject base",
499+
cmembers="double myobval;",
500+
tp_base="&PyFloat_Type",
501+
tp_new="fp_tp_new",
502+
tp_repr="fp_tp_repr"
503+
)
504+
class MyFloat(NativeFloatSubclass):
505+
pass
506+
assert MyFloat() == 0.0
507+
assert MyFloat(123.0) == 123.0
508+
assert repr(MyFloat()) == "native 0.0"
509+
assert repr(MyFloat(123.0)) == "native 123.0"
510+
482511
def test_custom_basicsize(self):
483512
TestCustomBasicsize = CPyExtType("TestCustomBasicsize",
484513
'''

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/BuiltinConstructors.java

Lines changed: 41 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -968,59 +968,54 @@ abstract static class FloatNode extends PythonBinaryBuiltinNode {
968968
// Used for the recursive call
969969
protected abstract double executeDouble(VirtualFrame frame, PythonBuiltinClassType cls, Object arg) throws UnexpectedResultException;
970970

971-
protected final boolean isPrimitiveFloat(Node inliningTarget, Object cls, InlineIsBuiltinClassProfile isPrimitiveProfile) {
972-
return isPrimitiveProfile.profileIsBuiltinClass(inliningTarget, cls, PythonBuiltinClassType.PFloat);
973-
}
974-
975971
@Specialization(guards = "isPrimitiveFloat(this, cls, isPrimitiveFloatProfile)", limit = "1")
976-
double floatFromDouble(@SuppressWarnings("unused") Object cls, double arg,
977-
@Bind("this") Node inliningTarget,
978-
@Shared("isFloat") @Cached InlineIsBuiltinClassProfile isPrimitiveFloatProfile) {
972+
static double floatFromDouble(@SuppressWarnings("unused") Object cls, double arg,
973+
@SuppressWarnings("unused") @Bind("this") Node inliningTarget,
974+
@SuppressWarnings("unused") @Shared("isFloat") @Cached InlineIsBuiltinClassProfile isPrimitiveFloatProfile) {
979975
return arg;
980976
}
981977

982978
@Specialization(guards = "isPrimitiveFloat(this, cls, isPrimitiveFloatProfile)", limit = "1")
983-
double floatFromInt(@SuppressWarnings("unused") Object cls, int arg,
984-
@Bind("this") Node inliningTarget,
985-
@Shared("isFloat") @Cached InlineIsBuiltinClassProfile isPrimitiveFloatProfile) {
979+
static double floatFromInt(@SuppressWarnings("unused") Object cls, int arg,
980+
@SuppressWarnings("unused") @Bind("this") Node inliningTarget,
981+
@SuppressWarnings("unused") @Shared("isFloat") @Cached InlineIsBuiltinClassProfile isPrimitiveFloatProfile) {
986982
return arg;
987983
}
988984

989985
@Specialization(guards = "isPrimitiveFloat(this, cls, isPrimitiveFloatProfile)", limit = "1")
990-
double floatFromLong(@SuppressWarnings("unused") Object cls, long arg,
991-
@Bind("this") Node inliningTarget,
992-
@Shared("isFloat") @Cached InlineIsBuiltinClassProfile isPrimitiveFloatProfile) {
986+
static double floatFromLong(@SuppressWarnings("unused") Object cls, long arg,
987+
@SuppressWarnings("unused") @Bind("this") Node inliningTarget,
988+
@SuppressWarnings("unused") @Shared("isFloat") @Cached InlineIsBuiltinClassProfile isPrimitiveFloatProfile) {
993989
return arg;
994990
}
995991

996992
@Specialization(guards = "isPrimitiveFloat(this, cls, isPrimitiveFloatProfile)", limit = "1")
997-
double floatFromBoolean(@SuppressWarnings("unused") Object cls, boolean arg,
998-
@Bind("this") Node inliningTarget,
999-
@Shared("isFloat") @Cached InlineIsBuiltinClassProfile isPrimitiveFloatProfile) {
993+
static double floatFromBoolean(@SuppressWarnings("unused") Object cls, boolean arg,
994+
@SuppressWarnings("unused") @Bind("this") Node inliningTarget,
995+
@SuppressWarnings("unused") @Shared("isFloat") @Cached InlineIsBuiltinClassProfile isPrimitiveFloatProfile) {
1000996
return arg ? 1d : 0d;
1001997
}
1002998

1003999
@Specialization(guards = "isPrimitiveFloat(this, cls, isPrimitiveFloatProfile)", limit = "1")
1004-
double floatFromString(VirtualFrame frame, @SuppressWarnings("unused") Object cls, TruffleString obj,
1005-
@Bind("this") Node inliningTarget,
1006-
@Shared("isFloat") @Cached InlineIsBuiltinClassProfile isPrimitiveFloatProfile,
1000+
static double floatFromString(VirtualFrame frame, @SuppressWarnings("unused") Object cls, TruffleString obj,
1001+
@SuppressWarnings("unused") @Bind("this") Node inliningTarget,
1002+
@SuppressWarnings("unused") @Shared("isFloat") @Cached InlineIsBuiltinClassProfile isPrimitiveFloatProfile,
10071003
@Shared("fromString") @Cached PyFloatFromString fromString) {
10081004
return fromString.execute(frame, obj);
10091005
}
10101006

10111007
@Specialization(guards = {"isPrimitiveFloat(this, cls, isPrimitiveFloatProfile)", "isNoValue(obj)"}, limit = "1")
1012-
double floatFromNoValue(@SuppressWarnings("unused") Object cls, @SuppressWarnings("unused") PNone obj,
1013-
@Bind("this") Node inliningTarget,
1014-
@Shared("isFloat") @Cached InlineIsBuiltinClassProfile isPrimitiveFloatProfile) {
1008+
static double floatFromNoValue(@SuppressWarnings("unused") Object cls, @SuppressWarnings("unused") PNone obj,
1009+
@SuppressWarnings("unused") @Bind("this") Node inliningTarget,
1010+
@SuppressWarnings("unused") @Shared("isFloat") @Cached InlineIsBuiltinClassProfile isPrimitiveFloatProfile) {
10151011
return 0.0;
10161012
}
10171013

10181014
@Specialization(guards = {"isPrimitiveFloat(this, cls, isPrimitiveFloatProfile)", "!isNoValue(obj)"}, //
10191015
replaces = "floatFromString", limit = "1")
1020-
@SuppressWarnings("truffle-static-method")
1021-
double floatFromObject(VirtualFrame frame, @SuppressWarnings("unused") Object cls, Object obj,
1016+
static double floatFromObject(VirtualFrame frame, @SuppressWarnings("unused") Object cls, Object obj,
10221017
@Bind("this") Node inliningTarget,
1023-
@Shared("isFloat") @Cached InlineIsBuiltinClassProfile isPrimitiveFloatProfile,
1018+
@SuppressWarnings("unused") @Shared("isFloat") @Cached InlineIsBuiltinClassProfile isPrimitiveFloatProfile,
10241019
@Cached IsBuiltinObjectProfile stringProfile,
10251020
@Shared("fromString") @Cached PyFloatFromString fromString,
10261021
@Cached PyNumberFloatNode pyNumberFloat) {
@@ -1030,7 +1025,15 @@ protected final boolean isPrimitiveFloat(Node inliningTarget, Object cls, Inline
10301025
return pyNumberFloat.execute(frame, obj);
10311026
}
10321027

1033-
@Specialization(guards = {"!isNativeClass(cls)", "!isPrimitiveFloat(this, cls, isPrimitiveFloatProfile)"}, //
1028+
@Specialization(guards = {"!needsNativeAllocation(cls)", "!isPrimitiveFloat(this, cls, isPrimitiveFloatProfile)", "isNoValue(obj)"}, //
1029+
limit = "1")
1030+
Object floatFromNoneManagedSubclass(Object cls, PNone obj,
1031+
@Bind("this") Node inliningTarget,
1032+
@Shared("isFloat") @Cached InlineIsBuiltinClassProfile isPrimitiveFloatProfile) {
1033+
return factory().createFloat(cls, floatFromNoValue(cls, obj, inliningTarget, isPrimitiveFloatProfile));
1034+
}
1035+
1036+
@Specialization(guards = {"!needsNativeAllocation(cls)", "!isPrimitiveFloat(this, cls, isPrimitiveFloatProfile)"}, //
10341037
limit = "1")
10351038
Object floatFromObjectManagedSubclass(VirtualFrame frame, Object cls, Object obj,
10361039
@SuppressWarnings("unused") @Bind("this") Node inliningTarget,
@@ -1046,8 +1049,13 @@ Object floatFromObjectManagedSubclass(VirtualFrame frame, Object cls, Object obj
10461049
// logic similar to float_subtype_new(PyTypeObject *type, PyObject *x) from CPython
10471050
// floatobject.c we have to first create a temporary float, then fill it into
10481051
// a natively allocated subtype structure
1049-
@Specialization(guards = "isSubtypeOfFloat(frame, isSubtype, cls)", limit = "1")
1050-
static Object floatFromObjectNativeSubclass(VirtualFrame frame, PythonNativeClass cls, Object obj,
1052+
@Specialization(guards = { //
1053+
"needsNativeAllocation(cls)", //
1054+
"!isPrimitiveFloat(this, cls, isPrimitiveFloatProfile)", //
1055+
"isSubtypeOfFloat(frame, isSubtype, cls)"}, limit = "1")
1056+
static Object floatFromObjectNativeSubclass(VirtualFrame frame, Object cls, Object obj,
1057+
@Bind("this") @SuppressWarnings("unused") Node inliningTarget,
1058+
@Shared("isFloat") @Cached @SuppressWarnings("unused") InlineIsBuiltinClassProfile isPrimitiveFloatProfile,
10511059
@Cached @SuppressWarnings("unused") IsSubtypeNode isSubtype,
10521060
@Cached CExtNodes.FloatSubtypeNew subtypeNew,
10531061
@Shared @Cached FloatNode recursiveCallNode) {
@@ -1058,7 +1066,11 @@ static Object floatFromObjectNativeSubclass(VirtualFrame frame, PythonNativeClas
10581066
}
10591067
}
10601068

1061-
protected static boolean isSubtypeOfFloat(VirtualFrame frame, IsSubtypeNode isSubtypeNode, PythonNativeClass cls) {
1069+
protected final boolean isPrimitiveFloat(Node inliningTarget, Object cls, InlineIsBuiltinClassProfile isPrimitiveProfile) {
1070+
return isPrimitiveProfile.profileIsBuiltinClass(inliningTarget, cls, PythonBuiltinClassType.PFloat);
1071+
}
1072+
1073+
protected static boolean isSubtypeOfFloat(VirtualFrame frame, IsSubtypeNode isSubtypeNode, Object cls) {
10621074
return isSubtypeNode.execute(frame, cls, PythonBuiltinClassType.PFloat);
10631075
}
10641076
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/ctypes/CDataObject.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
import com.oracle.graal.python.builtins.objects.cext.capi.PythonNativeWrapper;
4646
import com.oracle.graal.python.builtins.objects.cext.capi.transitions.CApiTransitions;
4747
import com.oracle.graal.python.builtins.objects.object.PythonBuiltinObject;
48+
import com.oracle.graal.python.nodes.util.CastToJavaStringNode;
4849
import com.oracle.graal.python.util.PythonUtils;
4950
import com.oracle.truffle.api.CompilerDirectives;
5051
import com.oracle.truffle.api.dsl.Bind;
@@ -58,7 +59,6 @@
5859
import com.oracle.truffle.api.nodes.Node;
5960
import com.oracle.truffle.api.object.Shape;
6061
import com.oracle.truffle.api.profiles.InlinedConditionProfile;
61-
import com.oracle.truffle.api.strings.TruffleString;
6262
import com.oracle.truffle.llvm.spi.NativeTypeLibrary;
6363

6464
@ExportLibrary(PythonBufferAcquireLibrary.class)
@@ -221,7 +221,7 @@ public CDataObjectWrapper(StgDictObject stgDict, byte[] storage) {
221221
this.nativePointer = null;
222222
}
223223

224-
private int getIndex(String field, TruffleString.ToJavaStringNode toJavaStringNode) {
224+
private int getIndex(String field, CastToJavaStringNode toJavaStringNode) {
225225
String[] fields = getMembers(true, toJavaStringNode);
226226
for (int i = 0; i < fields.length; i++) {
227227
if (fields[i].equals(field)) {
@@ -238,7 +238,7 @@ boolean hasMembers() {
238238

239239
@ExportMessage
240240
String[] getMembers(@SuppressWarnings("unused") boolean includeInternal,
241-
@Shared("ts2js") @Cached TruffleString.ToJavaStringNode toJavaStringNode) {
241+
@Shared @Cached CastToJavaStringNode toJavaStringNode) {
242242
if (members == null) {
243243
members = new String[this.stgDict.fieldsNames.length];
244244
for (int i = 0; i < this.stgDict.fieldsNames.length; i++) {
@@ -250,13 +250,13 @@ String[] getMembers(@SuppressWarnings("unused") boolean includeInternal,
250250

251251
@ExportMessage
252252
boolean isMemberReadable(String member,
253-
@Shared("ts2js") @Cached TruffleString.ToJavaStringNode toJavaStringNode) {
253+
@Shared @Cached CastToJavaStringNode toJavaStringNode) {
254254
return getIndex(member, toJavaStringNode) != -1;
255255
}
256256

257257
@ExportMessage
258258
final boolean isMemberModifiable(String member,
259-
@Shared("ts2js") @Cached TruffleString.ToJavaStringNode toJavaStringNode) {
259+
@Shared @Cached CastToJavaStringNode toJavaStringNode) {
260260
return isMemberReadable(member, toJavaStringNode);
261261
}
262262

@@ -267,7 +267,7 @@ final boolean isMemberInsertable(@SuppressWarnings("unused") String member) {
267267

268268
@ExportMessage
269269
Object readMember(String member,
270-
@Shared("ts2js") @Cached TruffleString.ToJavaStringNode toJavaStringNode) throws UnknownIdentifierException {
270+
@Shared @Cached CastToJavaStringNode toJavaStringNode) throws UnknownIdentifierException {
271271
int idx = getIndex(member, toJavaStringNode);
272272
if (idx != -1) {
273273
return CtypesNodes.getValue(stgDict.fieldsTypes[idx], storage, stgDict.fieldsOffsets[idx]);
@@ -277,7 +277,7 @@ Object readMember(String member,
277277

278278
@ExportMessage
279279
void writeMember(String member, Object value,
280-
@Shared("ts2js") @Cached TruffleString.ToJavaStringNode toJavaStringNode) throws UnknownIdentifierException {
280+
@Shared @Cached CastToJavaStringNode toJavaStringNode) throws UnknownIdentifierException {
281281
int idx = getIndex(member, toJavaStringNode);
282282
if (idx != -1) {
283283
CtypesNodes.setValue(stgDict.fieldsTypes[idx], storage, stgDict.fieldsOffsets[idx], value);

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/ctypes/PyCStructTypeBuiltins.java

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2021, 2022, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved.
33
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44
*
55
* The Universal Permissive License (UPL), Version 1.0
@@ -40,6 +40,7 @@
4040
*/
4141
package com.oracle.graal.python.builtins.modules.ctypes;
4242

43+
import static com.oracle.graal.python.nodes.ErrorMessages.ATTR_NAME_MUST_BE_STRING;
4344
import static com.oracle.graal.python.nodes.SpecialMethodNames.J___NEW__;
4445
import static com.oracle.graal.python.nodes.SpecialMethodNames.J___SETATTR__;
4546
import static com.oracle.graal.python.util.PythonUtils.TS_ENCODING;
@@ -56,15 +57,18 @@
5657
import com.oracle.graal.python.nodes.attributes.WriteAttributeToObjectNode;
5758
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
5859
import com.oracle.graal.python.nodes.function.builtins.PythonTernaryBuiltinNode;
60+
import com.oracle.graal.python.nodes.util.CannotCastException;
61+
import com.oracle.graal.python.nodes.util.CastToTruffleStringNode;
5962
import com.oracle.truffle.api.dsl.Cached;
63+
import com.oracle.truffle.api.dsl.Cached.Shared;
6064
import com.oracle.truffle.api.dsl.GenerateNodeFactory;
6165
import com.oracle.truffle.api.dsl.NodeFactory;
6266
import com.oracle.truffle.api.dsl.Specialization;
6367
import com.oracle.truffle.api.frame.VirtualFrame;
6468
import com.oracle.truffle.api.strings.TruffleString;
6569

6670
@CoreFunctions(extendClasses = PythonBuiltinClassType.PyCStructType)
67-
public class PyCStructTypeBuiltins extends PythonBuiltins {
71+
public final class PyCStructTypeBuiltins extends PythonBuiltins {
6872

6973
@Override
7074
protected List<? extends NodeFactory<? extends PythonBuiltinBaseNode>> getNodeFactories() {
@@ -81,15 +85,30 @@ protected abstract static class NewNode extends StructUnionTypeNewNode {
8185
protected abstract static class SetattrNode extends PythonTernaryBuiltinNode {
8286

8387
@Specialization
84-
protected PNone doStringKey(VirtualFrame frame, Object object, TruffleString key, Object value,
85-
@Cached TruffleString.EqualNode equalNode,
86-
@Cached WriteAttributeToObjectNode writeNode,
87-
@Cached PyCStructUnionTypeUpdateStgDict updateStgDict) {
88+
PNone doStringKey(VirtualFrame frame, Object object, TruffleString key, Object value,
89+
@Shared @Cached TruffleString.EqualNode equalNode,
90+
@Shared @Cached WriteAttributeToObjectNode writeNode,
91+
@Shared @Cached PyCStructUnionTypeUpdateStgDict updateStgDict) {
8892
writeNode.execute(object, key, value);
8993
if (equalNode.execute(key, StructUnionTypeBuiltins.T__fields_, TS_ENCODING)) {
9094
updateStgDict.execute(frame, object, value, true, factory());
9195
}
9296
return PNone.NONE;
9397
}
98+
99+
@Specialization(replaces = "doStringKey")
100+
PNone doGenericKey(VirtualFrame frame, Object object, Object keyObject, Object value,
101+
@Shared @Cached TruffleString.EqualNode equalNode,
102+
@Shared @Cached WriteAttributeToObjectNode writeNode,
103+
@Shared @Cached PyCStructUnionTypeUpdateStgDict updateStgDict,
104+
@Cached CastToTruffleStringNode castKeyToStringNode) {
105+
TruffleString key;
106+
try {
107+
key = castKeyToStringNode.execute(keyObject);
108+
} catch (CannotCastException e) {
109+
throw raise(PythonBuiltinClassType.TypeError, ATTR_NAME_MUST_BE_STRING, keyObject);
110+
}
111+
return doStringKey(frame, object, key, value, equalNode, writeNode, updateStgDict);
112+
}
94113
}
95114
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/ctypes/StgDictObject.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2021, 2022, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved.
33
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44
*
55
* The Universal Permissive License (UPL), Version 1.0
@@ -81,7 +81,7 @@ public final class StgDictObject extends PDict {
8181
int ndim;
8282
int[] shape;
8383

84-
TruffleString[] fieldsNames;
84+
Object[] fieldsNames;
8585
int[] fieldsOffsets;
8686
FFI_TYPES[] fieldsTypes;
8787

0 commit comments

Comments
 (0)