Skip to content

Commit 222ba16

Browse files
committed
Insert MayRaiseNode in function body instead of wrapping the call
1 parent 8fb6c02 commit 222ba16

File tree

2 files changed

+58
-90
lines changed

2 files changed

+58
-90
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/PythonCextBuiltins.java

Lines changed: 29 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@
221221
import com.oracle.graal.python.nodes.classes.IsSubtypeNodeGen;
222222
import com.oracle.graal.python.nodes.expression.BinaryComparisonNode;
223223
import com.oracle.graal.python.nodes.frame.GetCurrentFrameRef;
224+
import com.oracle.graal.python.nodes.function.FunctionRootNode;
224225
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
225226
import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
226227
import com.oracle.graal.python.nodes.function.builtins.PythonBinaryBuiltinNode;
@@ -279,7 +280,6 @@
279280
import com.oracle.truffle.api.library.CachedLibrary;
280281
import com.oracle.truffle.api.nodes.ControlFlowException;
281282
import com.oracle.truffle.api.nodes.Node;
282-
import com.oracle.truffle.api.nodes.RootNode;
283283
import com.oracle.truffle.api.nodes.UnexpectedResultException;
284284
import com.oracle.truffle.api.object.HiddenKey;
285285
import com.oracle.truffle.api.profiles.BranchProfile;
@@ -2410,57 +2410,44 @@ static Object doDirect(VirtualFrame frame, @SuppressWarnings("unused") PythonMod
24102410
}
24112411
}
24122412

2413-
/*
2414-
* We are creating a special PFunction as a wrapper here - that PFunction has a reference to the
2415-
* wrapped function's CallTarget. Since the wrapped function is a PFunction anyway, we'll have
2416-
* to do the full call logic at some point. But instead of doing it when dispatching to the
2417-
* wrapped function, we copy all relevant bits (signature, mostly) and thus the caller of the
2418-
* wrapper will already do all that work. The root node embedded in the wrapper call target (a
2419-
* MayRaiseNode) then just does a direct call with the frame arguments, without doing anything
2420-
* else anymore. Thus, while there is an extra call, there are really only those Java frames in
2421-
* between that are caused by the Truffle machinery for calls.
2413+
/**
2414+
* Inserts a {@link MayRaiseNode} that wraps the body of the function. This will return a new
2415+
* function object with a rewritten AST. However, we use a cache for the call targets and thus
2416+
* the rewritten-ASTs will also be shared if appropriate.
24222417
*/
24232418
@Builtin(name = "make_may_raise_wrapper", minNumOfPositionalArgs = 1, maxNumOfPositionalArgs = 2)
24242419
@GenerateNodeFactory
24252420
abstract static class MakeMayRaiseWrapperNode extends PythonBuiltinNode {
24262421
private static final WeakHashMap<RootCallTarget, WeakReference<RootCallTarget>> weakCallTargetMap = new WeakHashMap<>();
24272422

2428-
private static RootCallTarget createWrapperCt(PFunction func, MayRaiseErrorResult errorResult) {
2429-
CompilerDirectives.transferToInterpreter();
2430-
PythonLanguage lang = PythonLanguage.getCurrent();
2431-
RootNode rootNode = new MayRaiseNode(lang, func.getSignature(), func.getCallTarget(), errorResult);
2432-
return PythonUtils.getOrCreateCallTarget(rootNode);
2433-
}
2434-
24352423
@Specialization
24362424
@TruffleBoundary
24372425
Object make(PFunction func, Object errorResultObj) {
2438-
MayRaiseErrorResult errorResult = convertToEnum(errorResultObj);
2439-
2440-
RootCallTarget wrappedCt = func.getCallTarget();
2441-
WeakReference<RootCallTarget> wrapperCtRef = weakCallTargetMap.get(wrappedCt);
2442-
RootCallTarget wrapperCt = null;
2426+
RootCallTarget originalCallTarget = func.getCallTarget();
2427+
2428+
WeakReference<RootCallTarget> wrapperCtRef = weakCallTargetMap.get(originalCallTarget);
2429+
RootCallTarget wrapperCallTarget = null;
24432430
if (wrapperCtRef != null) {
2444-
wrapperCt = wrapperCtRef.get();
2445-
}
2446-
if (wrapperCt == null) {
2447-
wrapperCt = createWrapperCt(func, errorResult);
2448-
weakCallTargetMap.put(wrappedCt, new WeakReference<>(wrapperCt));
2449-
}
2450-
PCode wrappedCode = func.getCode();
2451-
PCode wrapperCode = factory().createCode(PythonBuiltinClassType.PCode, wrapperCt, func.getSignature(),
2452-
0, 0, 0,
2453-
new byte[0], new Object[0], new Object[0],
2454-
new Object[0], new Object[0], new Object[0],
2455-
wrappedCode.getName(), wrappedCode.getName(), 0,
2456-
new byte[0]);
2457-
return factory().createFunction(func.getName(), func.getQualname(), func.getEnclosingClassName(),
2458-
wrapperCode, func.getGlobals(), func.getDefaults(), func.getKwDefaults(),
2459-
func.getClosure(), func.getCodeStableAssumption(), func.getDefaultsStableAssumption());
2460-
}
2461-
2431+
wrapperCallTarget = wrapperCtRef.get();
2432+
}
2433+
if (wrapperCallTarget == null) {
2434+
final MayRaiseErrorResult errorResult = convertToEnum(errorResultObj);
2435+
FunctionRootNode functionRootNode = (FunctionRootNode) func.getFunctionRootNode();
2436+
2437+
// Replace the first expression node with the MayRaiseNode
2438+
functionRootNode = functionRootNode.rewriteWithNewSignature(func.getSignature(), node -> false, body -> MayRaiseNode.create(body, errorResult));
2439+
wrapperCallTarget = PythonUtils.getOrCreateCallTarget(functionRootNode);
2440+
weakCallTargetMap.put(originalCallTarget, new WeakReference<>(wrapperCallTarget));
2441+
}
2442+
2443+
// Although we could theoretically re-use the old function instance, we create a new one
2444+
// to be on the safe side.
2445+
return factory().createFunction(func.getName(), func.getQualname(), func.getEnclosingClassName(), factory().createCode(wrapperCallTarget), func.getGlobals(), func.getDefaults(),
2446+
func.getKwDefaults(), func.getClosure(), func.getCodeStableAssumption(), func.getCodeStableAssumption());
2447+
}
2448+
24622449
private MayRaiseErrorResult convertToEnum(Object object) {
2463-
if (PGuards.isNone(object) ) {
2450+
if (PGuards.isNone(object)) {
24642451
return MayRaiseErrorResult.NONE;
24652452
} else if (object instanceof Integer) {
24662453
int i = (int) object;
@@ -2472,7 +2459,7 @@ private MayRaiseErrorResult convertToEnum(Object object) {
24722459
if (i == -1.0) {
24732460
return MayRaiseErrorResult.FLOAT;
24742461
}
2475-
} else if (object instanceof PythonNativeNull) {
2462+
} else if (object instanceof PythonNativeNull || PGuards.isNoValue(object)) {
24762463
return MayRaiseErrorResult.NATIVE_NULL;
24772464
}
24782465
throw raise(PythonErrorType.TypeError, "invalid error result value");

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/cext/capi/CExtNodes.java

Lines changed: 29 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@
5151
import static com.oracle.graal.python.runtime.exception.PythonErrorType.TypeError;
5252

5353
import com.oracle.graal.python.PythonLanguage;
54-
import com.oracle.graal.python.builtins.Builtin;
5554
import com.oracle.graal.python.builtins.PythonBuiltinClassType;
5655
import com.oracle.graal.python.builtins.modules.BuiltinFunctions.GetAttrNode;
5756
import com.oracle.graal.python.builtins.modules.PythonCextBuiltins;
@@ -93,9 +92,7 @@
9392
import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes;
9493
import com.oracle.graal.python.builtins.objects.complex.PComplex;
9594
import com.oracle.graal.python.builtins.objects.floats.PFloat;
96-
import com.oracle.graal.python.builtins.objects.function.PArguments;
9795
import com.oracle.graal.python.builtins.objects.function.PKeyword;
98-
import com.oracle.graal.python.builtins.objects.function.Signature;
9996
import com.oracle.graal.python.builtins.objects.getsetdescriptor.DescriptorDeleteMarker;
10097
import com.oracle.graal.python.builtins.objects.ints.PInt;
10198
import com.oracle.graal.python.builtins.objects.module.PythonModule;
@@ -111,23 +108,21 @@
111108
import com.oracle.graal.python.nodes.PGuards;
112109
import com.oracle.graal.python.nodes.PNodeWithContext;
113110
import com.oracle.graal.python.nodes.PRaiseNode;
114-
import com.oracle.graal.python.nodes.PRootNode;
115111
import com.oracle.graal.python.nodes.SpecialMethodNames;
116112
import com.oracle.graal.python.nodes.attributes.ReadAttributeFromObjectNode;
117113
import com.oracle.graal.python.nodes.call.CallNode;
118-
import com.oracle.graal.python.nodes.call.CallTargetInvokeNode;
119114
import com.oracle.graal.python.nodes.call.special.CallBinaryMethodNode;
120115
import com.oracle.graal.python.nodes.call.special.CallTernaryMethodNode;
121116
import com.oracle.graal.python.nodes.call.special.CallUnaryMethodNode;
122117
import com.oracle.graal.python.nodes.call.special.LookupAndCallUnaryNode.LookupAndCallUnaryDynamicNode;
123118
import com.oracle.graal.python.nodes.classes.IsSubtypeNode;
119+
import com.oracle.graal.python.nodes.expression.ExpressionNode;
124120
import com.oracle.graal.python.nodes.frame.GetCurrentFrameRef;
125121
import com.oracle.graal.python.nodes.object.GetClassNode;
126122
import com.oracle.graal.python.nodes.object.IsBuiltinClassProfile;
127123
import com.oracle.graal.python.nodes.truffle.PythonTypes;
128124
import com.oracle.graal.python.nodes.util.CannotCastException;
129125
import com.oracle.graal.python.nodes.util.CastToJavaLongLossyNode;
130-
import com.oracle.graal.python.runtime.ExecutionContext.CalleeContext;
131126
import com.oracle.graal.python.runtime.PythonContext;
132127
import com.oracle.graal.python.runtime.PythonCore;
133128
import com.oracle.graal.python.runtime.PythonOptions;
@@ -140,9 +135,7 @@
140135
import com.oracle.truffle.api.Assumption;
141136
import com.oracle.truffle.api.CompilerAsserts;
142137
import com.oracle.truffle.api.CompilerDirectives;
143-
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
144138
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
145-
import com.oracle.truffle.api.RootCallTarget;
146139
import com.oracle.truffle.api.TruffleLanguage.ContextReference;
147140
import com.oracle.truffle.api.TruffleLogger;
148141
import com.oracle.truffle.api.dsl.Cached;
@@ -2475,50 +2468,48 @@ public static PCallCapiFunction getUncached() {
24752468
return CExtNodesFactory.PCallCapiFunctionNodeGen.getUncached();
24762469
}
24772470
}
2478-
2471+
2472+
/**
2473+
* Simple enum to abstract over common error indication values used in C extensions. We use this
2474+
* enum instead of concrete values to be able to safely share them between contexts.
2475+
*/
24792476
public enum MayRaiseErrorResult {
2480-
NATIVE_NULL, NONE, INT, FLOAT
2477+
NATIVE_NULL,
2478+
NONE,
2479+
INT,
2480+
FLOAT
24812481
}
24822482

2483-
// -----------------------------------------------------------------------------------------------------------------
2484-
@Builtin(takesVarArgs = true)
2485-
public static class MayRaiseNode extends PRootNode {
2486-
@Child private CallTargetInvokeNode callTargetInvokeNode;
2483+
/**
2484+
* A fake-expression node that wraps an expression node with a {@code try-catch} and any catched
2485+
* Python exception will be transformed to native and the pre-defined error result (specified
2486+
* with enum {@link MayRaiseErrorResult}) will be returned.
2487+
*/
2488+
public static final class MayRaiseNode extends ExpressionNode {
2489+
@Child private ExpressionNode wrappedBody;
24872490
@Child private TransformExceptionToNativeNode transformExceptionToNativeNode;
2488-
@Child private CalleeContext calleeContext;
2489-
2491+
24902492
@Child private GetNativeNullNode getNativeNullNode;
24912493

2492-
private final Signature signature;
24932494
private final MayRaiseErrorResult errorResult;
24942495

2495-
public MayRaiseNode(PythonLanguage lang, Signature sign, RootCallTarget ct, MayRaiseErrorResult errorResult) {
2496-
super(lang);
2497-
this.signature = sign;
2498-
this.callTargetInvokeNode = CallTargetInvokeNode.create(ct, false, false);
2499-
this.calleeContext = CalleeContext.create();
2496+
MayRaiseNode(ExpressionNode wrappedBody, MayRaiseErrorResult errorResult) {
2497+
this.wrappedBody = wrappedBody;
25002498
this.errorResult = errorResult;
25012499
}
25022500

2501+
public static MayRaiseNode create(ExpressionNode nodeToWrap, MayRaiseErrorResult errorResult) {
2502+
return new MayRaiseNode(nodeToWrap, errorResult);
2503+
}
2504+
25032505
@Override
2504-
public final Object execute(VirtualFrame frame) {
2505-
Object[] arguments = frame.getArguments();
2506-
int userArgumentLength = PArguments.getUserArgumentLength(arguments);
2507-
Object[] newArguments = PArguments.create(userArgumentLength);
2508-
// just copy user arguments, varargs and kwargs
2509-
System.arraycopy(arguments, PArguments.USER_ARGUMENTS_OFFSET, newArguments, PArguments.USER_ARGUMENTS_OFFSET, userArgumentLength);
2510-
PArguments.setVariableArguments(newArguments, PArguments.getVariableArguments(arguments));
2511-
PArguments.setKeywordArguments(newArguments, PArguments.getKeywordArguments(arguments));
2512-
2513-
calleeContext.enter(frame);
2506+
public Object execute(VirtualFrame frame) {
25142507
try {
2515-
return callTargetInvokeNode.execute(frame, null, PArguments.getGlobals(arguments), PArguments.getClosure(arguments), newArguments);
2508+
return wrappedBody.execute(frame);
25162509
} catch (PException e) {
25172510
// transformExceptionToNativeNode acts as a branch profile
25182511
ensureTransformExceptionToNativeNode().execute(frame, e);
25192512
return getErrorResult();
2520-
} finally {
2521-
calleeContext.exit(frame, this);
25222513
}
25232514
}
25242515

@@ -2529,15 +2520,15 @@ private TransformExceptionToNativeNode ensureTransformExceptionToNativeNode() {
25292520
}
25302521
return transformExceptionToNativeNode;
25312522
}
2532-
2523+
25332524
private Object getErrorResult() {
2534-
switch(errorResult) {
2525+
switch (errorResult) {
25352526
case INT:
25362527
return -1;
25372528
case FLOAT:
25382529
return -1.0;
25392530
case NONE:
2540-
return PNone.NONE;
2531+
return PNone.NONE;
25412532
case NATIVE_NULL:
25422533
if (getNativeNullNode == null) {
25432534
CompilerDirectives.transferToInterpreterAndInvalidate();
@@ -2547,16 +2538,6 @@ private Object getErrorResult() {
25472538
}
25482539
throw CompilerDirectives.shouldNotReachHere();
25492540
}
2550-
2551-
@Override
2552-
public Signature getSignature() {
2553-
return signature;
2554-
}
2555-
2556-
@Override
2557-
public boolean isPythonInternal() {
2558-
return true;
2559-
}
25602541
}
25612542

25622543
// -----------------------------------------------------------------------------------------------------------------

0 commit comments

Comments
 (0)