Skip to content

Commit 4e77634

Browse files
committed
Avoid transfer-to-interpreter in with node and correctly handle
exceptions.
1 parent dfe641b commit 4e77634

File tree

2 files changed

+119
-65
lines changed
  • graalpython
    • com.oracle.graal.python.test/src/tests
    • com.oracle.graal.python/src/com/oracle/graal/python/nodes/statement

2 files changed

+119
-65
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_with.py

Lines changed: 83 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -27,38 +27,103 @@
2727
a = 5
2828

2929
LOG = []
30+
LOG1 = []
31+
LOG2 = []
32+
LOG3 = []
33+
34+
35+
class Context:
36+
37+
def __init__(self, log, suppress_exception, raise_exception):
38+
self._log = log
39+
self._suppress = suppress_exception
40+
self._raise = raise_exception
3041

31-
class Sample:
3242
def __enter__(self):
33-
LOG.append("__enter__")
43+
self._log.append("__enter__")
3444
return self
3545

3646
def __exit__(self, type, value, trace):
37-
LOG.append("type: %s" % type)
38-
LOG.append("value: %s" % value)
47+
self._log.append("type: %s" % type)
48+
self._log.append("value: %s" % value)
3949
# LOG.append("trace: %s" % trace) # trace back is not supported yet
40-
return False
50+
return self._suppress
4151

4252
def do_something(self):
43-
bar = 1/0
53+
self._log.append("do_something")
54+
bar = 1
55+
if self._raise:
56+
bar = bar / 0
4457
return bar + 10
4558

46-
def test_with():
59+
60+
def payload(log, suppress_exception, raise_exception, do_return):
61+
a = 5
4762
try:
48-
with Sample() as sample:
49-
a = 5
50-
sample.do_something()
63+
with Context(log, suppress_exception, raise_exception) as sample:
64+
if do_return:
65+
a = sample.do_something()
66+
return a
67+
else:
68+
a = sample.do_something()
5169
except ZeroDivisionError:
52-
LOG.append("Exception has been thrown correctly")
70+
log.append("Exception has been thrown correctly")
5371

5472
else:
55-
LOG.append("This is not correct!!")
73+
log.append("no exception or exception suppressed")
5674

5775
finally:
58-
LOG.append("a = %s" % a)
76+
log.append("a = %s" % a)
77+
78+
return a
79+
80+
81+
def test_with_dont_suppress():
82+
payload(LOG, False, True, False)
83+
assert LOG == [
84+
"__enter__" ,
85+
"do_something" ,
86+
"type: <class 'ZeroDivisionError'>" ,
87+
"value: division by zero" ,
88+
"Exception has been thrown correctly" ,
89+
"a = 5"
90+
], "was: " + str(LOG)
91+
92+
93+
def test_with_suppress():
94+
payload(LOG1, True, True, False)
95+
assert LOG1 == [ "__enter__" ,
96+
"do_something" ,
97+
"type: <class 'ZeroDivisionError'>" ,
98+
"value: division by zero" ,
99+
"no exception or exception suppressed" ,
100+
"a = 5"
101+
], "was: " + str(LOG1)
102+
103+
104+
def with_return(ctx):
105+
with ctx as sample:
106+
return ctx.do_something()
107+
return None
108+
109+
110+
def test_with_return():
111+
result = payload(LOG2, False, False, True)
112+
assert result == 11
113+
assert LOG2 == [ "__enter__",
114+
"do_something",
115+
"type: None",
116+
"value: None",
117+
"a = 11",
118+
], "was: " + str(LOG2)
119+
59120

60-
assert LOG[0] == "__enter__"
61-
assert LOG[1] == "type: <class 'ZeroDivisionError'>"
62-
assert LOG[2] == "value: division by zero"
63-
assert LOG[3] == "Exception has been thrown correctly"
64-
assert LOG[4] == "a = 5"
121+
def test_with_return_and_exception():
122+
result = payload(LOG3, True, False, True)
123+
assert result == 11
124+
assert LOG3 == [ "__enter__",
125+
"do_something",
126+
"type: None",
127+
"value: None",
128+
"a = 11",
129+
], "was: " + str(LOG3)

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/statement/WithNode.java

Lines changed: 36 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -26,25 +26,23 @@
2626
package com.oracle.graal.python.nodes.statement;
2727

2828
import static com.oracle.graal.python.runtime.exception.PythonErrorType.TypeError;
29-
import static com.oracle.graal.python.runtime.exception.PythonErrorType.ZeroDivisionError;
3029

3130
import com.oracle.graal.python.builtins.objects.PNone;
3231
import com.oracle.graal.python.builtins.objects.function.PKeyword;
33-
import com.oracle.graal.python.builtins.objects.function.PythonCallable;
3432
import com.oracle.graal.python.builtins.objects.object.PythonObject;
3533
import com.oracle.graal.python.nodes.PNode;
3634
import com.oracle.graal.python.nodes.argument.CreateArgumentsNode;
3735
import com.oracle.graal.python.nodes.attributes.GetAttributeNode;
3836
import com.oracle.graal.python.nodes.call.CallDispatchNode;
37+
import com.oracle.graal.python.nodes.datamodel.IsCallableNode;
3938
import com.oracle.graal.python.nodes.expression.CastToBooleanNode;
4039
import com.oracle.graal.python.nodes.frame.WriteNode;
4140
import com.oracle.graal.python.runtime.exception.PException;
42-
import com.oracle.truffle.api.CompilerDirectives;
41+
import com.oracle.truffle.api.dsl.Cached;
4342
import com.oracle.truffle.api.dsl.NodeChild;
4443
import com.oracle.truffle.api.dsl.NodeChildren;
4544
import com.oracle.truffle.api.dsl.Specialization;
4645
import com.oracle.truffle.api.frame.VirtualFrame;
47-
import com.oracle.truffle.api.nodes.UnexpectedResultException;
4846

4947
@NodeChildren({@NodeChild(value = "withContext", type = PNode.class)})
5048
public abstract class WithNode extends StatementNode {
@@ -55,7 +53,7 @@ public abstract class WithNode extends StatementNode {
5553
@Child private GetAttributeNode exitGetter;
5654
@Child private CallDispatchNode enterDispatch;
5755
@Child private CallDispatchNode exitDispatch;
58-
@Child private CastToBooleanNode exitResultIsTrueish;
56+
@Child private CastToBooleanNode toBooleanNode;
5957
@Child private CreateArgumentsNode createArgs;
6058

6159
protected WithNode(PNode targetNode, PNode body) {
@@ -65,7 +63,7 @@ protected WithNode(PNode targetNode, PNode body) {
6563
this.exitGetter = GetAttributeNode.create();
6664
this.enterDispatch = CallDispatchNode.create("__enter__");
6765
this.exitDispatch = CallDispatchNode.create("__enter__");
68-
this.exitResultIsTrueish = CastToBooleanNode.createIfTrueNode();
66+
this.toBooleanNode = CastToBooleanNode.createIfTrueNode();
6967
this.createArgs = CreateArgumentsNode.create();
7068
}
7169

@@ -93,62 +91,53 @@ public PNode getTargetNode() {
9391
}
9492

9593
@Specialization
96-
protected Object runWith(VirtualFrame frame, PythonObject withObject) {
94+
protected Object runWith(VirtualFrame frame, PythonObject withObject,
95+
@Cached("create()") IsCallableNode isCallableNode,
96+
@Cached("create()") IsCallableNode isExitCallableNode) {
97+
9798
boolean gotException = false;
9899
Object enterCallable = enterGetter.execute(withObject, "__enter__");
99100
Object exitCallable = exitGetter.execute(withObject, "__exit__");
100-
try {
101-
applyValues(frame, enterDispatch.executeCall(PythonCallable.expect(enterCallable), createArgs.execute(withObject), new PKeyword[0]));
102-
} catch (UnexpectedResultException e1) {
103-
CompilerDirectives.transferToInterpreter();
104-
throw raise(TypeError, "%s is not callable", e1.getResult());
101+
102+
if (isCallableNode.execute(enterCallable)) {
103+
applyValues(frame, enterDispatch.executeCall(enterCallable, createArgs.execute(withObject), new PKeyword[0]));
104+
} else {
105+
throw raise(TypeError, "%p is not callable", enterCallable);
105106
}
107+
106108
try {
107109
body.execute(frame);
108-
} catch (RuntimeException exception) {
109-
CompilerDirectives.transferToInterpreter();
110+
} catch (PException exception) {
110111
gotException = true;
111-
return handleException(withObject, exception);
112+
return handleException(withObject, exception, isExitCallableNode);
112113
} finally {
113114
if (!gotException) {
114-
try {
115-
return exitDispatch.executeCall(PythonCallable.expect(exitCallable), createArgs.execute(withObject, PNone.NONE, PNone.NONE, PNone.NONE),
116-
new PKeyword[0]);
117-
} catch (UnexpectedResultException e1) {
118-
CompilerDirectives.transferToInterpreter();
119-
throw raise(TypeError, "%s is not callable", e1.getResult());
115+
if (isExitCallableNode.execute(exitCallable)) {
116+
exitDispatch.executeCall(exitCallable, createArgs.execute(withObject, PNone.NONE, PNone.NONE, PNone.NONE), new PKeyword[0]);
117+
} else {
118+
throw raise(TypeError, "%p is not callable", exitCallable);
120119
}
121120
}
122121
}
123-
assert false;
124-
return null;
122+
return PNone.NONE;
125123
}
126124

127-
private Object handleException(PythonObject withObject, RuntimeException e) {
128-
RuntimeException exception = e;
129-
PythonCallable exitCallable = null;
130-
try {
131-
exitCallable = PythonCallable.expect(exitGetter.execute(withObject, "__exit__"));
132-
} catch (UnexpectedResultException e1) {
133-
CompilerDirectives.transferToInterpreter();
134-
throw raise(TypeError, "%s is not callable", e1.getResult());
135-
}
136-
if (exception instanceof ArithmeticException && exception.getMessage().endsWith("divide by zero")) {
137-
// TODO: no ArithmeticExceptions should propagate outside of operations
138-
exception = raise(ZeroDivisionError, "division by zero");
125+
private Object handleException(PythonObject withObject, PException e, IsCallableNode isExitCallableNode) {
126+
Object exitCallable = exitGetter.execute(withObject, "__exit__");
127+
if (!isExitCallableNode.execute(exitCallable)) {
128+
throw raise(TypeError, "%p is not callable", exitCallable);
139129
}
140-
if (exception instanceof PException) {
141-
PException pException = (PException) exception;
142-
pException.getExceptionObject().reifyException();
143-
Object type = pException.getType();
144-
Object value = pException.getExceptionObject();
145-
Object trace = pException.getExceptionObject().getTraceback(factory());
146-
Object returnValue = exitDispatch.executeCall(exitCallable, createArgs.execute(withObject, type, value, trace), new PKeyword[0]);
147-
// Corner cases:
148-
if (exitResultIsTrueish.executeWith(returnValue)) {
149-
return returnValue;
150-
}
130+
131+
e.getExceptionObject().reifyException();
132+
Object type = e.getType();
133+
Object value = e.getExceptionObject();
134+
Object trace = e.getExceptionObject().getTraceback(factory());
135+
Object returnValue = exitDispatch.executeCall(exitCallable, createArgs.execute(withObject, type, value, trace), new PKeyword[0]);
136+
// If exit handler returns 'true', suppress
137+
if (toBooleanNode.executeWith(returnValue)) {
138+
return PNone.NONE;
151139
}
152-
throw exception;
140+
// else re-raise exception
141+
throw e;
153142
}
154143
}

0 commit comments

Comments
 (0)