Skip to content

Commit 045d78b

Browse files
committed
[GR-10590] Avoid transfer-to-interpreter in with statement.
PullRequest: graalpython/96
2 parents dfe641b + 1492952 commit 045d78b

File tree

7 files changed

+202
-87
lines changed

7 files changed

+202
-87
lines changed

graalpython/com.oracle.graal.python.cext/src/capi.c

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,20 @@
3838
*/
3939
#include "capi.h"
4040

41-
#define FORCE_TO_NATIVE(__obj__) (polyglot_invoke(PY_TRUFFLE_CEXT, "PyTruffle_Set_Ptr", (__obj__), truffle_is_handle_to_managed((__obj__)) ? (__obj__) : truffle_deref_handle_for_managed(__obj__)))
41+
42+
MUST_INLINE static void force_to_native(void* obj) {
43+
if (polyglot_is_value(obj)) {
44+
polyglot_invoke(PY_TRUFFLE_CEXT, "PyTruffle_Set_Ptr", obj, truffle_deref_handle_for_managed(obj));
45+
}
46+
}
4247

4348
static void initialize_type_structure(PyTypeObject* structure, const char* typname) {
4449
// explicit type cast is required because the type flags are not yet initialized !
4550
PyTypeObject* ptype = polyglot_as__typeobject(UPCALL_CEXT_O("PyTruffle_Type", polyglot_from_string(typname, SRC_CS)));
4651

4752
// We eagerly create a native pointer for all builtin types. This is necessary for pointer comparisons to work correctly.
4853
// TODO Remove this as soon as this is properly supported.
49-
FORCE_TO_NATIVE(ptype);
54+
force_to_native(ptype);
5055

5156
unsigned long original_flags = structure->tp_flags;
5257
Py_ssize_t basicsize = structure->tp_basicsize;
@@ -59,37 +64,37 @@ static void initialize_type_structure(PyTypeObject* structure, const char* typna
5964
static void initialize_globals() {
6065
// None
6166
PyObject* jnone = UPCALL_CEXT_O("Py_None");
62-
FORCE_TO_NATIVE(jnone);
67+
force_to_native(jnone);
6368
truffle_assign_managed(&_Py_NoneStruct, jnone);
6469

6570
// NotImplemented
6671
void *jnotimpl = UPCALL_CEXT_O("Py_NotImplemented");
67-
FORCE_TO_NATIVE(jnotimpl);
72+
force_to_native(jnotimpl);
6873
truffle_assign_managed(&_Py_NotImplementedStruct, jnotimpl);
6974

7075
// Ellipsis
7176
void *jellipsis = UPCALL_CEXT_O("Py_Ellipsis");
72-
FORCE_TO_NATIVE(jellipsis);
77+
force_to_native(jellipsis);
7378
truffle_assign_managed(&_Py_EllipsisObject, jellipsis);
7479

7580
// True, False
7681
void *jtrue = UPCALL_CEXT_O("Py_True");
77-
FORCE_TO_NATIVE(jtrue);
82+
force_to_native(jtrue);
7883
truffle_assign_managed(&_Py_TrueStruct, polyglot_as__longobject(jtrue));
7984
void *jfalse = UPCALL_CEXT_O("Py_False");
80-
FORCE_TO_NATIVE(jfalse);
85+
force_to_native(jfalse);
8186
truffle_assign_managed(&_Py_FalseStruct, polyglot_as__longobject(jfalse));
8287

8388
// error marker
8489
void *jerrormarker = UPCALL_CEXT_O("Py_ErrorHandler");
85-
FORCE_TO_NATIVE(jerrormarker);
90+
force_to_native(jerrormarker);
8691
truffle_assign_managed(&marker_struct, jerrormarker);
8792
}
8893

8994
static void initialize_bufferprocs() {
90-
polyglot_invoke(PY_TRUFFLE_CEXT, "PyTruffle_SetBufferProcs", to_java((PyObject*)&PyBytes_Type), (getbufferproc)bytes_buffer_getbuffer, (releasebufferproc)NULL);
91-
polyglot_invoke(PY_TRUFFLE_CEXT, "PyTruffle_SetBufferProcs", to_java((PyObject*)&PyByteArray_Type), (getbufferproc)NULL, (releasebufferproc)NULL);
92-
polyglot_invoke(PY_TRUFFLE_CEXT, "PyTruffle_SetBufferProcs", to_java((PyObject*)&PyBuffer_Type), (getbufferproc)bufferdecorator_getbuffer, (releasebufferproc)NULL);
95+
polyglot_invoke(PY_TRUFFLE_CEXT, "PyTruffle_SetBufferProcs", native_to_java((PyObject*)&PyBytes_Type), (getbufferproc)bytes_buffer_getbuffer, (releasebufferproc)NULL);
96+
polyglot_invoke(PY_TRUFFLE_CEXT, "PyTruffle_SetBufferProcs", native_to_java((PyObject*)&PyByteArray_Type), (getbufferproc)NULL, (releasebufferproc)NULL);
97+
polyglot_invoke(PY_TRUFFLE_CEXT, "PyTruffle_SetBufferProcs", native_to_java((PyObject*)&PyBuffer_Type), (getbufferproc)bufferdecorator_getbuffer, (releasebufferproc)NULL);
9398
}
9499

95100
__attribute__((constructor))
@@ -205,7 +210,7 @@ PyObject* to_sulong(void *o) {
205210

206211
/** to be used from Java code only; reads native 'ob_type' field */
207212
void* get_ob_type(PyObject* obj) {
208-
return native_to_java(obj->ob_type);
213+
return native_to_java((PyObject*)obj->ob_type);
209214
}
210215

211216
typedef struct PyObjectHandle {

graalpython/com.oracle.graal.python.test/src/tests/test_with.py

Lines changed: 83 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -27,38 +27,103 @@
2727
a = 5
2828

2929
LOG = []
30+
LOG1 = []
31+
LOG2 = []
32+
LOG3 = []
33+
34+
35+
class Context:
36+
37+
def __init__(self, log, suppress_exception, raise_exception):
38+
self._log = log
39+
self._suppress = suppress_exception
40+
self._raise = raise_exception
3041

31-
class Sample:
3242
def __enter__(self):
33-
LOG.append("__enter__")
43+
self._log.append("__enter__")
3444
return self
3545

3646
def __exit__(self, type, value, trace):
37-
LOG.append("type: %s" % type)
38-
LOG.append("value: %s" % value)
47+
self._log.append("type: %s" % type)
48+
self._log.append("value: %s" % value)
3949
# LOG.append("trace: %s" % trace) # trace back is not supported yet
40-
return False
50+
return self._suppress
4151

4252
def do_something(self):
43-
bar = 1/0
53+
self._log.append("do_something")
54+
bar = 1
55+
if self._raise:
56+
bar = bar / 0
4457
return bar + 10
4558

46-
def test_with():
59+
60+
def payload(log, suppress_exception, raise_exception, do_return):
61+
a = 5
4762
try:
48-
with Sample() as sample:
49-
a = 5
50-
sample.do_something()
63+
with Context(log, suppress_exception, raise_exception) as sample:
64+
if do_return:
65+
a = sample.do_something()
66+
return a
67+
else:
68+
a = sample.do_something()
5169
except ZeroDivisionError:
52-
LOG.append("Exception has been thrown correctly")
70+
log.append("Exception has been thrown correctly")
5371

5472
else:
55-
LOG.append("This is not correct!!")
73+
log.append("no exception or exception suppressed")
5674

5775
finally:
58-
LOG.append("a = %s" % a)
76+
log.append("a = %s" % a)
77+
78+
return a
79+
80+
81+
def test_with_dont_suppress():
82+
payload(LOG, False, True, False)
83+
assert LOG == [
84+
"__enter__" ,
85+
"do_something" ,
86+
"type: <class 'ZeroDivisionError'>" ,
87+
"value: division by zero" ,
88+
"Exception has been thrown correctly" ,
89+
"a = 5"
90+
], "was: " + str(LOG)
91+
92+
93+
def test_with_suppress():
94+
payload(LOG1, True, True, False)
95+
assert LOG1 == [ "__enter__" ,
96+
"do_something" ,
97+
"type: <class 'ZeroDivisionError'>" ,
98+
"value: division by zero" ,
99+
"no exception or exception suppressed" ,
100+
"a = 5"
101+
], "was: " + str(LOG1)
102+
103+
104+
def with_return(ctx):
105+
with ctx as sample:
106+
return ctx.do_something()
107+
return None
108+
109+
110+
def test_with_return():
111+
result = payload(LOG2, False, False, True)
112+
assert result == 11
113+
assert LOG2 == [ "__enter__",
114+
"do_something",
115+
"type: None",
116+
"value: None",
117+
"a = 11",
118+
], "was: " + str(LOG2)
119+
59120

60-
assert LOG[0] == "__enter__"
61-
assert LOG[1] == "type: <class 'ZeroDivisionError'>"
62-
assert LOG[2] == "value: division by zero"
63-
assert LOG[3] == "Exception has been thrown correctly"
64-
assert LOG[4] == "a = 5"
121+
def test_with_return_and_exception():
122+
result = payload(LOG3, True, False, True)
123+
assert result == 11
124+
assert LOG3 == [ "__enter__",
125+
"do_something",
126+
"type: None",
127+
"value: None",
128+
"a = 11",
129+
], "was: " + str(LOG3)

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/TruffleCextBuiltins.java

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1199,9 +1199,18 @@ private GetClassNode getClassNode() {
11991199
abstract static class PyTruffle_Set_Ptr extends NativeBuiltin {
12001200

12011201
@Specialization
1202-
PythonObjectNativeWrapper doPythonObject(PythonObjectNativeWrapper nativeWrapper, TruffleObject ptr) {
1203-
nativeWrapper.setNativePointer(ptr);
1204-
return nativeWrapper;
1202+
int doPythonObject(PythonAbstractObject nativeWrapper, TruffleObject ptr) {
1203+
return doNativeWrapper(nativeWrapper.getNativeWrapper(), ptr);
1204+
}
1205+
1206+
@Specialization
1207+
int doNativeWrapper(PythonObjectNativeWrapper nativeWrapper, TruffleObject ptr) {
1208+
if (nativeWrapper.isNative()) {
1209+
PythonContext.getSingleNativeContextAssumption().invalidate();
1210+
} else {
1211+
nativeWrapper.setNativePointer(ptr);
1212+
}
1213+
return 0;
12051214
}
12061215
}
12071216

@@ -1210,12 +1219,16 @@ PythonObjectNativeWrapper doPythonObject(PythonObjectNativeWrapper nativeWrapper
12101219
abstract static class PyTruffle_SetBufferProcs extends NativeBuiltin {
12111220

12121221
@Specialization
1213-
Object doPythonObject(PythonClass obj, Object getBufferProc, Object releaseBufferProc) {
1214-
PythonClassNativeWrapper nativeWrapper = obj.getNativeWrapper();
1222+
Object doNativeWrapper(PythonClassNativeWrapper nativeWrapper, Object getBufferProc, Object releaseBufferProc) {
12151223
nativeWrapper.setGetBufferProc(getBufferProc);
12161224
nativeWrapper.setReleaseBufferProc(releaseBufferProc);
12171225
return PNone.NO_VALUE;
12181226
}
1227+
1228+
@Specialization
1229+
Object doPythonObject(PythonClass obj, Object getBufferProc, Object releaseBufferProc) {
1230+
return doNativeWrapper(obj.getNativeWrapper(), getBufferProc, releaseBufferProc);
1231+
}
12191232
}
12201233

12211234
@Builtin(name = "PyTruffle_ThreadState_GetDict", fixedNumOfArguments = 0)

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/cext/NativeWrappers.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ public Object getNativePointer() {
6262

6363
public void setNativePointer(Object nativePointer) {
6464
// we should set the pointer just once
65-
assert this.nativePointer == null || this.nativePointer.equals(nativePointer);
65+
assert this.nativePointer == null || this.nativePointer.equals(nativePointer) || nativePointer == null;
6666
this.nativePointer = nativePointer;
6767
}
6868

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/cext/PythonObjectNativeWrapperMR.java

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
import com.oracle.graal.python.builtins.objects.cext.NativeWrappers.PySequenceArrayWrapper;
5252
import com.oracle.graal.python.builtins.objects.cext.NativeWrappers.PythonNativeWrapper;
5353
import com.oracle.graal.python.builtins.objects.cext.NativeWrappers.PythonObjectNativeWrapper;
54+
import com.oracle.graal.python.builtins.objects.cext.PythonObjectNativeWrapperMRFactory.PAsPointerNodeGen;
5455
import com.oracle.graal.python.builtins.objects.cext.PythonObjectNativeWrapperMRFactory.ReadNativeMemberNodeGen;
5556
import com.oracle.graal.python.builtins.objects.cext.PythonObjectNativeWrapperMRFactory.ToPyObjectNodeGen;
5657
import com.oracle.graal.python.builtins.objects.cext.PythonObjectNativeWrapperMRFactory.WriteNativeMemberNodeGen;
@@ -74,8 +75,10 @@
7475
import com.oracle.graal.python.nodes.call.special.LookupAndCallUnaryNode;
7576
import com.oracle.graal.python.nodes.object.GetClassNode;
7677
import com.oracle.graal.python.nodes.truffle.PythonArithmeticTypes;
78+
import com.oracle.graal.python.runtime.PythonContext;
7779
import com.oracle.graal.python.runtime.interop.PythonMessageResolution;
7880
import com.oracle.graal.python.runtime.sequence.PSequence;
81+
import com.oracle.truffle.api.Assumption;
7982
import com.oracle.truffle.api.CompilerDirectives;
8083
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
8184
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
@@ -529,7 +532,7 @@ public Object access(Object object) {
529532

530533
@Resolve(message = "TO_NATIVE")
531534
abstract static class ToNativeNode extends Node {
532-
@Child private ToPyObjectNode toPyObjectNode = ToPyObjectNodeGen.create();
535+
@Child private ToPyObjectNode toPyObjectNode = ToPyObjectNode.create();
533536

534537
Object access(PythonNativeWrapper obj) {
535538
if (!obj.isNative()) {
@@ -543,7 +546,7 @@ Object access(PythonNativeWrapper obj) {
543546
abstract static class IsPointerNode extends Node {
544547
@Child private Node isPointerNode;
545548

546-
Object access(PythonNativeWrapper obj) {
549+
boolean access(PythonNativeWrapper obj) {
547550
return obj.isNative() && (!(obj.getNativePointer() instanceof TruffleObject) || ForeignAccess.sendIsPointer(getIsPointerNode(), (TruffleObject) obj.getNativePointer()));
548551
}
549552

@@ -558,11 +561,26 @@ private Node getIsPointerNode() {
558561

559562
@Resolve(message = "AS_POINTER")
560563
abstract static class AsPointerNode extends Node {
561-
@Child private Node asPointerNode;
564+
@Child private PAsPointerNode pAsPointerNode = PAsPointerNode.create();
562565

563566
long access(PythonNativeWrapper obj) {
567+
return pAsPointerNode.execute(obj);
568+
}
569+
}
570+
571+
abstract static class PAsPointerNode extends PBaseNode {
572+
@Child private Node asPointerNode;
573+
574+
public abstract long execute(PythonNativeWrapper o);
575+
576+
@Specialization(assumptions = "getSingleNativeContextAssumption()")
577+
long doFast(PythonNativeWrapper obj) {
564578
// the native pointer object must either be a TruffleObject or a primitive
565579
Object nativePointer = obj.getNativePointer();
580+
return ensureLong(nativePointer);
581+
}
582+
583+
private long ensureLong(Object nativePointer) {
566584
if (nativePointer instanceof TruffleObject) {
567585
if (asPointerNode == null) {
568586
CompilerDirectives.transferToInterpreterAndInvalidate();
@@ -575,8 +593,22 @@ long access(PythonNativeWrapper obj) {
575593
}
576594
}
577595
return (long) nativePointer;
596+
}
578597

598+
@Specialization(replaces = "doFast")
599+
long doSlow(PythonNativeWrapper obj,
600+
@Cached("create()") ToPyObjectNode toPyObjectNode) {
601+
return ensureLong(toPyObjectNode.execute(obj));
579602
}
603+
604+
protected Assumption getSingleNativeContextAssumption() {
605+
return PythonContext.getSingleNativeContextAssumption();
606+
}
607+
608+
public static PAsPointerNode create() {
609+
return PAsPointerNodeGen.create();
610+
}
611+
580612
}
581613

582614
abstract static class ToPyObjectNode extends TransformToNativeNode {
@@ -662,5 +694,9 @@ private CExtNodes.ToSulongNode getToSulongNode() {
662694
}
663695
return toSulongNode;
664696
}
697+
698+
public static ToPyObjectNode create() {
699+
return ToPyObjectNodeGen.create();
700+
}
665701
}
666702
}

0 commit comments

Comments
 (0)