Skip to content

Commit e322120

Browse files
committed
Add test for invalid pointer sharing for primitive wrappers
1 parent 5d6ec27 commit e322120

File tree

3 files changed

+76
-1
lines changed

3 files changed

+76
-1
lines changed

graalpython/com.oracle.graal.python.cext/src/capi.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -876,6 +876,11 @@ int PyTruffle_Debug(void *arg) {
876876
return 0;
877877
}
878878

879+
int PyTruffle_ToNative(void *arg) {
880+
polyglot_invoke(PY_TRUFFLE_CEXT, "PyTruffle_ToNative", arg);
881+
return 0;
882+
}
883+
879884
int truffle_ptr_compare(void* x, void* y, int op) {
880885
switch (op) {
881886
case Py_LT:

graalpython/com.oracle.graal.python.test/src/tests/cpyext/test_misc.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,46 @@ def compile_module(self, name):
228228
cmpfunc=unhandled_error_compare
229229
)
230230

231+
# Tests if wrapped Java primitive values do not share the same
232+
# native pointer.
233+
test_primitive_sharing = CPyExtFunction(
234+
lambda args: True,
235+
lambda: (
236+
(123.0, ),
237+
),
238+
code="""
239+
// internal function defined in 'capi.c'
240+
int PyTruffle_ToNative(void *);
241+
242+
PyObject* primitive_sharing(PyObject* val) {
243+
Py_ssize_t val_refcnt = Py_REFCNT(val);
244+
// assume val's refcnt is X > 0
245+
Py_INCREF(val);
246+
// val's refcnt should now be X+1
247+
248+
double dval = PyFloat_AsDouble(val);
249+
250+
PyTruffle_ToNative(val);
251+
252+
// a fresh object with the same value
253+
PyObject *val1 = PyFloat_FromDouble(dval);
254+
PyTruffle_ToNative(val1);
255+
256+
// now, kill it
257+
Py_DECREF(val1);
258+
259+
// reset val's refcnt to X
260+
Py_DECREF(val);
261+
262+
return val_refcnt == Py_REFCNT(val) ? Py_True : Py_False;
263+
}
264+
""",
265+
resultspec="O",
266+
argspec="O",
267+
arguments=["PyObject* val"],
268+
cmpfunc=unhandled_error_compare
269+
)
270+
231271
test_PyOS_double_to_string = CPyExtFunction(
232272
_reference_format_float,
233273
lambda: (

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/cext/PythonCextBuiltins.java

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@
212212
import com.oracle.truffle.api.CompilerDirectives;
213213
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
214214
import com.oracle.truffle.api.RootCallTarget;
215+
import com.oracle.truffle.api.TruffleLanguage.Env;
215216
import com.oracle.truffle.api.TruffleLogger;
216217
import com.oracle.truffle.api.dsl.Cached;
217218
import com.oracle.truffle.api.dsl.Cached.Exclusive;
@@ -232,6 +233,7 @@
232233
import com.oracle.truffle.api.interop.UnsupportedMessageException;
233234
import com.oracle.truffle.api.interop.UnsupportedTypeException;
234235
import com.oracle.truffle.api.library.CachedLibrary;
236+
import com.oracle.truffle.api.nodes.LanguageInfo;
235237
import com.oracle.truffle.api.nodes.Node;
236238
import com.oracle.truffle.api.nodes.RootNode;
237239
import com.oracle.truffle.api.object.DynamicObjectLibrary;
@@ -241,6 +243,7 @@
241243
import com.oracle.truffle.api.profiles.LoopConditionProfile;
242244
import com.oracle.truffle.api.profiles.ValueProfile;
243245
import com.oracle.truffle.api.utilities.CyclicAssumption;
246+
import com.oracle.truffle.llvm.api.Toolchain;
244247

245248
@CoreFunctions(defineModule = PythonCextBuiltins.PYTHON_CEXT)
246249
@GenerateNodeFactory
@@ -269,6 +272,15 @@ public void initialize(Python3Core core) {
269272
builtinConstants.put("PyGILState_Release", new PyGILStateRelease());
270273
}
271274

275+
@Override
276+
public void postInitialize(Python3Core core) {
277+
super.postInitialize(core);
278+
if (!core.getContext().getOption(PythonOptions.EnableDebuggingBuiltins)) {
279+
PythonModule mod = core.lookupBuiltinModule(PYTHON_CEXT);
280+
mod.setAttribute("PyTruffle_ToNative", PNone.NO_VALUE);
281+
}
282+
}
283+
272284
@FunctionalInterface
273285
public interface TernaryFunction<T1, T2, T3, R> {
274286
R apply(T1 arg0, T2 arg1, T3 arg2);
@@ -2376,15 +2388,33 @@ Object tssDelete(Object key,
23762388
}
23772389
}
23782390

2391+
// directly called without landing function
23792392
@Builtin(name = "PyTruffle_Debug", takesVarArgs = true)
23802393
@GenerateNodeFactory
23812394
public abstract static class PyTruffleDebugNode extends PythonBuiltinNode {
23822395
@Specialization
23832396
@TruffleBoundary
2384-
public Object doIt(Object[] args,
2397+
static Object doIt(Object[] args,
23852398
@Cached DebugNode debugNode) {
23862399
debugNode.execute(args);
23872400
return PNone.NONE;
23882401
}
23892402
}
2403+
2404+
// directly called without landing function
2405+
@Builtin(name = "PyTruffle_ToNative", minNumOfPositionalArgs = 1)
2406+
@GenerateNodeFactory
2407+
public abstract static class PyTruffleToNativeNode extends PythonUnaryBuiltinNode {
2408+
@Specialization
2409+
@TruffleBoundary
2410+
Object doIt(Object object) {
2411+
Env env = getContext().getEnv();
2412+
LanguageInfo llvmInfo = env.getInternalLanguages().get(PythonLanguage.LLVM_LANGUAGE);
2413+
Toolchain toolchain = env.lookup(llvmInfo, Toolchain.class);
2414+
if ("native".equals(toolchain.getIdentifier())) {
2415+
InteropLibrary.getUncached().toNative(object);
2416+
}
2417+
return PNone.NONE;
2418+
}
2419+
}
23902420
}

0 commit comments

Comments
 (0)