Skip to content

Commit 14a6944

Browse files
committed
Preserve capsule name pointer identity
1 parent 2b64f61 commit 14a6944

17 files changed

+178
-185
lines changed

graalpython/com.oracle.graal.python.test/src/tests/cpyext/test_capsule.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,43 @@ class TestPyCapsule(CPyExtTestCase):
8080
cmpfunc=unhandled_error_compare
8181
)
8282

83+
test_PyCapsule_GetName = CPyExtFunction(
84+
lambda args: True,
85+
lambda: (
86+
("hello",),
87+
),
88+
# Test that the returned name is pointer-identical, pybind11 relies on that
89+
code='''int wrap_PyCapsule_Check(char * name) {
90+
PyObject* capsule = PyCapsule_New((void *)1, name, NULL);
91+
return PyCapsule_GetName(capsule) == name;
92+
}
93+
''',
94+
resultspec="i",
95+
argspec='s',
96+
arguments=["char* name"],
97+
callfunction="wrap_PyCapsule_Check",
98+
cmpfunc=unhandled_error_compare
99+
)
100+
101+
test_PyCapsule_SetName = CPyExtFunction(
102+
lambda args: True,
103+
lambda: (
104+
("hello",),
105+
),
106+
# Test that the returned name is pointer-identical, pybind11 relies on that
107+
code='''int wrap_PyCapsule_Check(char * name) {
108+
PyObject* capsule = PyCapsule_New((void *)1, NULL, NULL);
109+
PyCapsule_SetName(capsule, name);
110+
return PyCapsule_GetName(capsule) == name;
111+
}
112+
''',
113+
resultspec="i",
114+
argspec='s',
115+
arguments=["char* name"],
116+
callfunction="wrap_PyCapsule_Check",
117+
cmpfunc=unhandled_error_compare
118+
)
119+
83120
test_PyCapsule_GetContext = CPyExtFunction(
84121
lambda args: True,
85122
lambda: (

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/cext/PythonCextCapsuleBuiltins.java

Lines changed: 31 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -44,73 +44,69 @@
4444
import static com.oracle.graal.python.builtins.PythonBuiltinClassType.ValueError;
4545
import static com.oracle.graal.python.builtins.modules.cext.PythonCextBuiltins.CApiCallPath.Direct;
4646
import static com.oracle.graal.python.builtins.objects.cext.capi.transitions.ArgDescriptor.ConstCharPtr;
47-
import static com.oracle.graal.python.builtins.objects.cext.capi.transitions.ArgDescriptor.ConstCharPtrAsTruffleString;
4847
import static com.oracle.graal.python.builtins.objects.cext.capi.transitions.ArgDescriptor.Int;
4948
import static com.oracle.graal.python.builtins.objects.cext.capi.transitions.ArgDescriptor.PY_CAPSULE_DESTRUCTOR;
5049
import static com.oracle.graal.python.builtins.objects.cext.capi.transitions.ArgDescriptor.Pointer;
5150
import static com.oracle.graal.python.builtins.objects.cext.capi.transitions.ArgDescriptor.PyObject;
5251
import static com.oracle.graal.python.builtins.objects.cext.capi.transitions.ArgDescriptor.PyObjectTransfer;
52+
import static com.oracle.graal.python.nodes.ErrorMessages.CALLED_WITH_INCORRECT_NAME;
5353
import static com.oracle.graal.python.nodes.ErrorMessages.CALLED_WITH_INVALID_PY_CAPSULE_OBJECT;
54+
import static com.oracle.graal.python.nodes.ErrorMessages.CALLED_WITH_NULL_POINTER;
5455
import static com.oracle.graal.python.nodes.ErrorMessages.PY_CAPSULE_IMPORT_S_IS_NOT_VALID;
5556
import static com.oracle.graal.python.util.PythonUtils.TS_ENCODING;
5657

5758
import com.oracle.graal.python.builtins.modules.cext.PythonCextBuiltins.CApiBinaryBuiltinNode;
5859
import com.oracle.graal.python.builtins.modules.cext.PythonCextBuiltins.CApiBuiltin;
5960
import com.oracle.graal.python.builtins.modules.cext.PythonCextBuiltins.CApiTernaryBuiltinNode;
6061
import com.oracle.graal.python.builtins.modules.cext.PythonCextBuiltins.CApiUnaryBuiltinNode;
61-
import com.oracle.graal.python.builtins.objects.PNone;
6262
import com.oracle.graal.python.builtins.objects.capsule.PyCapsule;
6363
import com.oracle.graal.python.builtins.objects.capsule.PyCapsuleNameMatchesNode;
64-
import com.oracle.graal.python.builtins.objects.cext.common.CArrayWrappers.CStringWrapper;
64+
import com.oracle.graal.python.builtins.objects.cext.capi.transitions.CApiTransitions;
6565
import com.oracle.graal.python.nodes.PRaiseNode;
6666
import com.oracle.graal.python.nodes.StringLiterals;
6767
import com.oracle.graal.python.nodes.attributes.ReadAttributeFromObjectNode;
6868
import com.oracle.graal.python.nodes.statement.AbstractImportNode;
6969
import com.oracle.graal.python.runtime.object.PythonObjectFactory;
7070
import com.oracle.truffle.api.dsl.Bind;
7171
import com.oracle.truffle.api.dsl.Cached;
72-
import com.oracle.truffle.api.dsl.Cached.Shared;
7372
import com.oracle.truffle.api.dsl.Fallback;
7473
import com.oracle.truffle.api.dsl.Specialization;
7574
import com.oracle.truffle.api.interop.InteropLibrary;
7675
import com.oracle.truffle.api.library.CachedLibrary;
7776
import com.oracle.truffle.api.nodes.Node;
7877
import com.oracle.truffle.api.strings.TruffleString;
79-
import com.oracle.truffle.api.strings.TruffleString.Encoding;
80-
import com.oracle.truffle.api.strings.TruffleString.GetInternalNativePointerNode;
8178

8279
public final class PythonCextCapsuleBuiltins {
8380

84-
@CApiBuiltin(ret = PyObjectTransfer, args = {Pointer, ConstCharPtrAsTruffleString, PY_CAPSULE_DESTRUCTOR}, call = Direct)
81+
@CApiBuiltin(ret = PyObjectTransfer, args = {Pointer, ConstCharPtr, PY_CAPSULE_DESTRUCTOR}, call = Direct)
8582
abstract static class PyCapsule_New extends CApiTernaryBuiltinNode {
8683
@Specialization
87-
static Object doGeneric(Object pointer, Object name, Object destructor,
84+
static Object doGeneric(Object pointer, Object namePtr, Object destructor,
8885
@Bind("this") Node inliningTarget,
89-
@CachedLibrary(limit = "2") InteropLibrary interopLibrary,
86+
@CachedLibrary(limit = "1") InteropLibrary interopLibrary,
9087
@Cached PythonObjectFactory factory,
9188
@Cached PRaiseNode.Lazy raiseNode) {
9289
if (interopLibrary.isNull(pointer)) {
9390
throw raiseNode.get(inliningTarget).raise(ValueError, CALLED_WITH_INVALID_PY_CAPSULE_OBJECT);
9491
}
95-
Object n = interopLibrary.isNull(name) ? null : name;
96-
PyCapsule capsule = factory.createCapsule(pointer, n);
92+
PyCapsule capsule = factory.createCapsuleNativeName(pointer, interopLibrary.isNull(namePtr) ? null : namePtr);
9793
if (!interopLibrary.isNull(destructor)) {
9894
capsule.registerDestructor(destructor);
9995
}
10096
return capsule;
10197
}
10298
}
10399

104-
@CApiBuiltin(ret = Int, args = {PyObject, ConstCharPtrAsTruffleString}, call = Direct)
105-
public abstract static class PyCapsule_IsValid extends CApiBinaryBuiltinNode {
100+
@CApiBuiltin(ret = Int, args = {PyObject, ConstCharPtr}, call = Direct)
101+
abstract static class PyCapsule_IsValid extends CApiBinaryBuiltinNode {
106102
@Specialization
107-
public static int doCapsule(PyCapsule o, TruffleString name,
103+
static int doCapsule(PyCapsule o, Object namePtr,
108104
@Bind("this") Node inliningTarget,
109105
@Cached PyCapsuleNameMatchesNode nameMatchesNode) {
110106
if (o.getPointer() == null) {
111107
return 0;
112108
}
113-
if (!nameMatchesNode.execute(inliningTarget, name, o.getName())) {
109+
if (!nameMatchesNode.execute(inliningTarget, namePtr, o.getNamePtr())) {
114110
return 0;
115111
}
116112
return 1;
@@ -122,7 +118,7 @@ static Object doError(@SuppressWarnings("unused") Object o, @SuppressWarnings("u
122118
}
123119
}
124120

125-
@CApiBuiltin(ret = Pointer, args = {PyObject, ConstCharPtrAsTruffleString}, call = Direct)
121+
@CApiBuiltin(ret = Pointer, args = {PyObject, ConstCharPtr}, call = Direct)
126122
abstract static class PyCapsule_GetPointer extends CApiBinaryBuiltinNode {
127123
@Specialization
128124
static Object doCapsule(PyCapsule o, Object name,
@@ -132,8 +128,8 @@ static Object doCapsule(PyCapsule o, Object name,
132128
if (o.getPointer() == null) {
133129
throw raiseNode.get(inliningTarget).raise(ValueError, CALLED_WITH_INVALID_PY_CAPSULE_OBJECT, "PyCapsule_GetPointer");
134130
}
135-
if (!nameMatchesNode.execute(inliningTarget, name, o.getName())) {
136-
throw raiseNode.get(inliningTarget).raise(ValueError, PY_CAPSULE_IMPORT_S_IS_NOT_VALID);
131+
if (!nameMatchesNode.execute(inliningTarget, name, o.getNamePtr())) {
132+
throw raiseNode.get(inliningTarget).raise(ValueError, CALLED_WITH_INCORRECT_NAME, "PyCapsule_GetPointer");
137133
}
138134
return o.getPointer();
139135
}
@@ -147,53 +143,15 @@ static Object doError(@SuppressWarnings("unused") Object o, @SuppressWarnings("u
147143

148144
@CApiBuiltin(ret = ConstCharPtr, args = {PyObject}, call = Direct)
149145
abstract static class PyCapsule_GetName extends CApiUnaryBuiltinNode {
150-
private static void checkLegalCapsule(Node inliningTarget, PyCapsule capsule, PRaiseNode.Lazy raiseNode) {
151-
if (capsule.getPointer() == null) {
152-
throw raiseNode.get(inliningTarget).raise(ValueError, CALLED_WITH_INVALID_PY_CAPSULE_OBJECT, "PyCapsule_GetName");
153-
}
154-
}
155-
156-
private static Object tsToNative(TruffleString tname, GetInternalNativePointerNode getInternalNativePointerNode) {
157-
if (tname.isNative()) {
158-
/*
159-
* We assume encoding UTF-8 because it's the most common one and also specified in
160-
* HPy. However, CPython does not actually specify an encoding.
161-
*/
162-
return getInternalNativePointerNode.execute(tname, Encoding.UTF_8);
163-
}
164-
return new CStringWrapper(tname);
165-
}
166146

167-
@Specialization(guards = "isTruffleString(name)")
168-
static Object doTruffleString(PyCapsule o,
169-
@Bind("this") Node inliningTarget,
170-
@Bind("o.getName()") Object name,
171-
@Shared("a") @Cached GetInternalNativePointerNode getInternalNativePointerNode,
172-
@Shared @Cached PRaiseNode.Lazy raiseNode) {
173-
checkLegalCapsule(inliningTarget, o, raiseNode);
174-
175-
// cast to TruffleString guaranteed by the guard
176-
return tsToNative((TruffleString) name, getInternalNativePointerNode);
177-
}
178-
179-
@Specialization(replaces = "doTruffleString")
180-
Object doGeneric(PyCapsule o,
147+
@Specialization
148+
Object get(PyCapsule o,
181149
@Bind("this") Node inliningTarget,
182-
@Bind("o.getName()") Object name,
183-
@Shared("a") @Cached GetInternalNativePointerNode getInternalNativePointerNode,
184-
@Shared @Cached PRaiseNode.Lazy raiseNode) {
185-
checkLegalCapsule(inliningTarget, o, raiseNode);
186-
if (name == null) {
187-
return getNULL();
188-
}
189-
if (name instanceof TruffleString) {
190-
return tsToNative((TruffleString) name, getInternalNativePointerNode);
150+
@Cached PRaiseNode.Lazy raiseNode) {
151+
if (o.getPointer() == null) {
152+
throw raiseNode.get(inliningTarget).raise(ValueError, CALLED_WITH_INVALID_PY_CAPSULE_OBJECT, "PyCapsule_GetName");
191153
}
192-
/*
193-
* If 'name' is not a TruffleString, we assume it is a native pointer and return it
194-
* without further conversion.
195-
*/
196-
return name;
154+
return o.getNamePtr() == null ? getNULL() : o.getNamePtr();
197155
}
198156

199157
@Fallback
@@ -255,7 +213,7 @@ static int doCapsule(PyCapsule o, Object pointer,
255213
@CachedLibrary(limit = "2") InteropLibrary interopLibrary,
256214
@Cached PRaiseNode.Lazy raiseNode) {
257215
if (interopLibrary.isNull(pointer)) {
258-
throw raiseNode.get(inliningTarget).raise(ValueError, PY_CAPSULE_IMPORT_S_IS_NOT_VALID);
216+
throw raiseNode.get(inliningTarget).raise(ValueError, CALLED_WITH_NULL_POINTER, "PyCapsule_SetPointer");
259217
}
260218

261219
if (o.getPointer() == null) {
@@ -273,27 +231,17 @@ static Object doError(@SuppressWarnings("unused") Object o, @SuppressWarnings("u
273231
}
274232
}
275233

276-
@CApiBuiltin(ret = Int, args = {PyObject, ConstCharPtrAsTruffleString}, call = Direct)
234+
@CApiBuiltin(ret = Int, args = {PyObject, ConstCharPtr}, call = Direct)
277235
abstract static class PyCapsule_SetName extends CApiBinaryBuiltinNode {
278236
@Specialization
279-
static int doCapsuleTruffleString(PyCapsule o, TruffleString name,
237+
static int set(PyCapsule o, Object namePtr,
280238
@Bind("this") Node inliningTarget,
281-
@Shared @Cached PRaiseNode.Lazy raiseNode) {
282-
if (o.getPointer() == null) {
283-
throw raiseNode.get(inliningTarget).raise(ValueError, CALLED_WITH_INVALID_PY_CAPSULE_OBJECT, "PyCapsule_SetName");
284-
}
285-
o.setName(name);
286-
return 0;
287-
}
288-
289-
@Specialization(guards = "isNoValue(name)")
290-
static int doCapsuleNone(PyCapsule o, @SuppressWarnings("unused") PNone name,
291-
@Bind("this") Node inliningTarget,
292-
@Shared @Cached PRaiseNode.Lazy raiseNode) {
239+
@CachedLibrary(limit = "1") InteropLibrary lib,
240+
@Cached PRaiseNode.Lazy raiseNode) {
293241
if (o.getPointer() == null) {
294242
throw raiseNode.get(inliningTarget).raise(ValueError, CALLED_WITH_INVALID_PY_CAPSULE_OBJECT, "PyCapsule_SetName");
295243
}
296-
o.setName(null);
244+
o.setNamePtr(lib.isNull(namePtr) ? null : namePtr);
297245
return 0;
298246
}
299247

@@ -345,17 +293,19 @@ static Object doError(@SuppressWarnings("unused") Object o, @SuppressWarnings("u
345293
}
346294
}
347295

348-
@CApiBuiltin(ret = Pointer, args = {ConstCharPtrAsTruffleString, Int}, call = Direct)
296+
@CApiBuiltin(ret = Pointer, args = {ConstCharPtr, Int}, call = Direct)
349297
abstract static class PyCapsule_Import extends CApiBinaryBuiltinNode {
350298
@Specialization
351-
static Object doGeneric(TruffleString name, @SuppressWarnings("unused") int noBlock,
299+
static Object doGeneric(Object namePtr, @SuppressWarnings("unused") int noBlock,
352300
@Bind("this") Node inliningTarget,
301+
@Cached CApiTransitions.CharPtrToPythonNode charPtrToPythonNode,
353302
@Cached PyCapsuleNameMatchesNode nameMatchesNode,
354303
@Cached TruffleString.CodePointLengthNode codePointLengthNode,
355304
@Cached TruffleString.IndexOfStringNode indexOfStringNode,
356305
@Cached TruffleString.SubstringNode substringNode,
357306
@Cached ReadAttributeFromObjectNode getAttrNode,
358307
@Cached PRaiseNode.Lazy raiseNode) {
308+
TruffleString name = (TruffleString) charPtrToPythonNode.execute(namePtr);
359309
TruffleString trace = name;
360310
Object object = null;
361311
while (trace != null) {
@@ -377,7 +327,7 @@ static Object doGeneric(TruffleString name, @SuppressWarnings("unused") int noBl
377327

378328
/* compare attribute name to module.name by hand */
379329
PyCapsule capsule = object instanceof PyCapsule ? (PyCapsule) object : null;
380-
if (capsule != null && PyCapsule_IsValid.doCapsule(capsule, name, inliningTarget, nameMatchesNode) == 1) {
330+
if (capsule != null && PyCapsule_IsValid.doCapsule(capsule, namePtr, inliningTarget, nameMatchesNode) == 1) {
381331
return capsule.getPointer();
382332
} else {
383333
throw raiseNode.get(inliningTarget).raise(AttributeError, PY_CAPSULE_IMPORT_S_IS_NOT_VALID, name);

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/cjkcodecs/CodecsCNModuleBuiltins.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@
6060
import com.oracle.graal.python.builtins.modules.cjkcodecs.DBCSMap.MappingType;
6161
import com.oracle.graal.python.builtins.modules.cjkcodecs.MultibyteCodec.CodecType;
6262
import com.oracle.graal.python.builtins.objects.capsule.PyCapsule;
63-
import com.oracle.graal.python.builtins.objects.capsule.PyCapsuleNameMatchesNode;
6463
import com.oracle.graal.python.builtins.objects.module.PythonModule;
6564
import com.oracle.graal.python.lib.PyUnicodeCheckNode;
6665
import com.oracle.graal.python.nodes.PRaiseNode;
@@ -125,7 +124,6 @@ static Object getcodec(Object encoding,
125124
@Cached TruffleString.EqualNode isEqual,
126125
@Cached PyUnicodeCheckNode unicodeCheckNode,
127126
@Cached CastToTruffleStringNode asUTF8Node,
128-
@Cached PyCapsuleNameMatchesNode nameMatchesNode,
129127
@Cached PythonObjectFactory factory,
130128
@Cached PRaiseNode.Lazy raiseNode) {
131129

@@ -138,8 +136,8 @@ static Object getcodec(Object encoding,
138136
throw raiseNode.get(inliningTarget).raise(LookupError, NO_SUCH_CODEC_IS_SUPPORTED);
139137
}
140138

141-
PyCapsule codecobj = factory.createCapsule(codec, PyMultibyteCodec_CAPSULE_NAME);
142-
return createCodec(inliningTarget, codecobj, nameMatchesNode, factory, raiseNode);
139+
PyCapsule codecobj = factory.createCapsuleJavaName(codec, PyMultibyteCodec_CAPSULE_NAME);
140+
return createCodec(inliningTarget, codecobj, factory, raiseNode);
143141
}
144142
}
145143

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/cjkcodecs/CodecsHKModuleBuiltins.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@
6060
import com.oracle.graal.python.builtins.modules.cjkcodecs.DBCSMap.MappingType;
6161
import com.oracle.graal.python.builtins.modules.cjkcodecs.MultibyteCodec.CodecType;
6262
import com.oracle.graal.python.builtins.objects.capsule.PyCapsule;
63-
import com.oracle.graal.python.builtins.objects.capsule.PyCapsuleNameMatchesNode;
6463
import com.oracle.graal.python.builtins.objects.module.PythonModule;
6564
import com.oracle.graal.python.lib.PyUnicodeCheckNode;
6665
import com.oracle.graal.python.nodes.PRaiseNode;
@@ -117,7 +116,6 @@ static Object getcodec(Object encoding,
117116
@Cached TruffleString.EqualNode isEqual,
118117
@Cached PyUnicodeCheckNode unicodeCheckNode,
119118
@Cached CastToTruffleStringNode asUTF8Node,
120-
@Cached PyCapsuleNameMatchesNode nameMatchesNode,
121119
@Cached PythonObjectFactory factory,
122120
@Cached PRaiseNode.Lazy raiseNode) {
123121

@@ -130,8 +128,8 @@ static Object getcodec(Object encoding,
130128
throw raiseNode.get(inliningTarget).raise(LookupError, NO_SUCH_CODEC_IS_SUPPORTED);
131129
}
132130

133-
PyCapsule codecobj = factory.createCapsule(codec, PyMultibyteCodec_CAPSULE_NAME);
134-
return createCodec(inliningTarget, codecobj, nameMatchesNode, factory, raiseNode);
131+
PyCapsule codecobj = factory.createCapsuleJavaName(codec, PyMultibyteCodec_CAPSULE_NAME);
132+
return createCodec(inliningTarget, codecobj, factory, raiseNode);
135133
}
136134
}
137135

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/cjkcodecs/CodecsISO2022ModuleBuiltins.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@
5959
import com.oracle.graal.python.builtins.PythonBuiltins;
6060
import com.oracle.graal.python.builtins.modules.cjkcodecs.MultibyteCodec.CodecType;
6161
import com.oracle.graal.python.builtins.objects.capsule.PyCapsule;
62-
import com.oracle.graal.python.builtins.objects.capsule.PyCapsuleNameMatchesNode;
6362
import com.oracle.graal.python.builtins.objects.module.PythonModule;
6463
import com.oracle.graal.python.lib.PyUnicodeCheckNode;
6564
import com.oracle.graal.python.nodes.PRaiseNode;
@@ -127,7 +126,6 @@ static Object getcodec(Object encoding,
127126
@Cached TruffleString.EqualNode isEqual,
128127
@Cached PyUnicodeCheckNode unicodeCheckNode,
129128
@Cached CastToTruffleStringNode asUTF8Node,
130-
@Cached PyCapsuleNameMatchesNode nameMatchesNode,
131129
@Cached PythonObjectFactory factory,
132130
@Cached PRaiseNode.Lazy raiseNode) {
133131

@@ -140,8 +138,8 @@ static Object getcodec(Object encoding,
140138
throw raiseNode.get(inliningTarget).raise(LookupError, NO_SUCH_CODEC_IS_SUPPORTED);
141139
}
142140

143-
PyCapsule codecobj = factory.createCapsule(codec, PyMultibyteCodec_CAPSULE_NAME);
144-
return createCodec(inliningTarget, codecobj, nameMatchesNode, factory, raiseNode);
141+
PyCapsule codecobj = factory.createCapsuleJavaName(codec, PyMultibyteCodec_CAPSULE_NAME);
142+
return createCodec(inliningTarget, codecobj, factory, raiseNode);
145143
}
146144
}
147145

0 commit comments

Comments
 (0)