Skip to content

Commit 06bf257

Browse files
committed
[GR-27398] Do proper callee enter in MayRaiseNode.
PullRequest: graalpython/1385
2 parents 469090a + 222ba16 commit 06bf257

File tree

4 files changed

+98
-69
lines changed

4 files changed

+98
-69
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/GraalPythonModuleBuiltins.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,7 @@ public boolean visit(Node node) {
461461
}
462462
return true;
463463
}
464-
});
464+
}, x -> x);
465465

466466
String name = func.getName();
467467
builtinFunc = factory().createFunction(name, func.getEnclosingClassName(),

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/PythonCextBuiltins.java

Lines changed: 46 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@
121121
import com.oracle.graal.python.builtins.objects.cext.capi.CExtNodes.FastCallWithKeywordsArgsToSulongNode;
122122
import com.oracle.graal.python.builtins.objects.cext.capi.CExtNodes.FromCharPointerNode;
123123
import com.oracle.graal.python.builtins.objects.cext.capi.CExtNodes.GetNativeNullNode;
124+
import com.oracle.graal.python.builtins.objects.cext.capi.CExtNodes.MayRaiseErrorResult;
124125
import com.oracle.graal.python.builtins.objects.cext.capi.CExtNodes.MayRaiseNode;
125126
import com.oracle.graal.python.builtins.objects.cext.capi.CExtNodes.ObjectUpcallNode;
126127
import com.oracle.graal.python.builtins.objects.cext.capi.CExtNodes.PCallCapiFunction;
@@ -220,6 +221,7 @@
220221
import com.oracle.graal.python.nodes.classes.IsSubtypeNodeGen;
221222
import com.oracle.graal.python.nodes.expression.BinaryComparisonNode;
222223
import com.oracle.graal.python.nodes.frame.GetCurrentFrameRef;
224+
import com.oracle.graal.python.nodes.function.FunctionRootNode;
223225
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
224226
import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
225227
import com.oracle.graal.python.nodes.function.builtins.PythonBinaryBuiltinNode;
@@ -278,7 +280,6 @@
278280
import com.oracle.truffle.api.library.CachedLibrary;
279281
import com.oracle.truffle.api.nodes.ControlFlowException;
280282
import com.oracle.truffle.api.nodes.Node;
281-
import com.oracle.truffle.api.nodes.RootNode;
282283
import com.oracle.truffle.api.nodes.UnexpectedResultException;
283284
import com.oracle.truffle.api.object.HiddenKey;
284285
import com.oracle.truffle.api.profiles.BranchProfile;
@@ -2409,53 +2410,59 @@ static Object doDirect(VirtualFrame frame, @SuppressWarnings("unused") PythonMod
24092410
}
24102411
}
24112412

2412-
/*
2413-
* We are creating a special PFunction as a wrapper here - that PFunction has a reference to the
2414-
* wrapped function's CallTarget. Since the wrapped function is a PFunction anyway, we'll have
2415-
* to do the full call logic at some point. But instead of doing it when dispatching to the
2416-
* wrapped function, we copy all relevant bits (signature, mostly) and thus the caller of the
2417-
* wrapper will already do all that work. The root node embedded in the wrapper call target (a
2418-
* MayRaiseNode) then just does a direct call with the frame arguments, without doing anything
2419-
* else anymore. Thus, while there is an extra call, there are really only those Java frames in
2420-
* between that are caused by the Truffle machinery for calls.
2413+
/**
2414+
* Inserts a {@link MayRaiseNode} that wraps the body of the function. This will return a new
2415+
* function object with a rewritten AST. However, we use a cache for the call targets and thus
2416+
* the rewritten-ASTs will also be shared if appropriate.
24212417
*/
24222418
@Builtin(name = "make_may_raise_wrapper", minNumOfPositionalArgs = 1, maxNumOfPositionalArgs = 2)
24232419
@GenerateNodeFactory
24242420
abstract static class MakeMayRaiseWrapperNode extends PythonBuiltinNode {
24252421
private static final WeakHashMap<RootCallTarget, WeakReference<RootCallTarget>> weakCallTargetMap = new WeakHashMap<>();
24262422

2427-
private static final RootCallTarget createWrapperCt(PFunction func, Object errorResult) {
2428-
CompilerDirectives.transferToInterpreter();
2429-
assert errorResult instanceof Integer || errorResult instanceof Long || errorResult instanceof Double || errorResult == PNone.NONE ||
2430-
InteropLibrary.getUncached().isNull(errorResult) : "invalid wrap";
2431-
PythonLanguage lang = PythonLanguage.getCurrent();
2432-
RootNode rootNode = new MayRaiseNode(lang, func.getSignature(), func.getCallTarget(), errorResult);
2433-
return PythonUtils.getOrCreateCallTarget(rootNode);
2434-
}
2435-
24362423
@Specialization
24372424
@TruffleBoundary
2438-
Object make(PFunction func, Object errorResult) {
2439-
RootCallTarget wrappedCt = func.getCallTarget();
2440-
WeakReference<RootCallTarget> wrapperCtRef = weakCallTargetMap.get(wrappedCt);
2441-
RootCallTarget wrapperCt = null;
2425+
Object make(PFunction func, Object errorResultObj) {
2426+
RootCallTarget originalCallTarget = func.getCallTarget();
2427+
2428+
WeakReference<RootCallTarget> wrapperCtRef = weakCallTargetMap.get(originalCallTarget);
2429+
RootCallTarget wrapperCallTarget = null;
24422430
if (wrapperCtRef != null) {
2443-
wrapperCt = wrapperCtRef.get();
2444-
}
2445-
if (wrapperCt == null) {
2446-
wrapperCt = createWrapperCt(func, errorResult);
2447-
weakCallTargetMap.put(wrappedCt, new WeakReference<>(wrapperCt));
2448-
}
2449-
PCode wrappedCode = func.getCode();
2450-
PCode wrapperCode = factory().createCode(PythonBuiltinClassType.PCode, wrapperCt, func.getSignature(),
2451-
0, 0, 0,
2452-
new byte[0], new Object[0], new Object[0],
2453-
new Object[0], new Object[0], new Object[0],
2454-
wrappedCode.getName(), wrappedCode.getName(), 0,
2455-
new byte[0]);
2456-
return factory().createFunction(func.getName(), func.getQualname(), func.getEnclosingClassName(),
2457-
wrapperCode, func.getGlobals(), func.getDefaults(), func.getKwDefaults(),
2458-
func.getClosure(), func.getCodeStableAssumption(), func.getDefaultsStableAssumption());
2431+
wrapperCallTarget = wrapperCtRef.get();
2432+
}
2433+
if (wrapperCallTarget == null) {
2434+
final MayRaiseErrorResult errorResult = convertToEnum(errorResultObj);
2435+
FunctionRootNode functionRootNode = (FunctionRootNode) func.getFunctionRootNode();
2436+
2437+
// Replace the first expression node with the MayRaiseNode
2438+
functionRootNode = functionRootNode.rewriteWithNewSignature(func.getSignature(), node -> false, body -> MayRaiseNode.create(body, errorResult));
2439+
wrapperCallTarget = PythonUtils.getOrCreateCallTarget(functionRootNode);
2440+
weakCallTargetMap.put(originalCallTarget, new WeakReference<>(wrapperCallTarget));
2441+
}
2442+
2443+
// Although we could theoretically re-use the old function instance, we create a new one
2444+
// to be on the safe side.
2445+
return factory().createFunction(func.getName(), func.getQualname(), func.getEnclosingClassName(), factory().createCode(wrapperCallTarget), func.getGlobals(), func.getDefaults(),
2446+
func.getKwDefaults(), func.getClosure(), func.getCodeStableAssumption(), func.getCodeStableAssumption());
2447+
}
2448+
2449+
private MayRaiseErrorResult convertToEnum(Object object) {
2450+
if (PGuards.isNone(object)) {
2451+
return MayRaiseErrorResult.NONE;
2452+
} else if (object instanceof Integer) {
2453+
int i = (int) object;
2454+
if (i == -1) {
2455+
return MayRaiseErrorResult.INT;
2456+
}
2457+
} else if (object instanceof Double) {
2458+
double i = (double) object;
2459+
if (i == -1.0) {
2460+
return MayRaiseErrorResult.FLOAT;
2461+
}
2462+
} else if (object instanceof PythonNativeNull || PGuards.isNoValue(object)) {
2463+
return MayRaiseErrorResult.NATIVE_NULL;
2464+
}
2465+
throw raise(PythonErrorType.TypeError, "invalid error result value");
24592466
}
24602467
}
24612468

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/cext/capi/CExtNodes.java

Lines changed: 48 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@
5151
import static com.oracle.graal.python.runtime.exception.PythonErrorType.TypeError;
5252

5353
import com.oracle.graal.python.PythonLanguage;
54-
import com.oracle.graal.python.builtins.Builtin;
5554
import com.oracle.graal.python.builtins.PythonBuiltinClassType;
5655
import com.oracle.graal.python.builtins.modules.BuiltinFunctions.GetAttrNode;
5756
import com.oracle.graal.python.builtins.modules.PythonCextBuiltins;
@@ -73,6 +72,7 @@
7372
import com.oracle.graal.python.builtins.objects.cext.capi.CExtNodesFactory.DirectUpcallNodeGen;
7473
import com.oracle.graal.python.builtins.objects.cext.capi.CExtNodesFactory.FastCallArgsToSulongNodeGen;
7574
import com.oracle.graal.python.builtins.objects.cext.capi.CExtNodesFactory.FastCallWithKeywordsArgsToSulongNodeGen;
75+
import com.oracle.graal.python.builtins.objects.cext.capi.CExtNodesFactory.GetNativeNullNodeGen;
7676
import com.oracle.graal.python.builtins.objects.cext.capi.CExtNodesFactory.GetTypeMemberNodeGen;
7777
import com.oracle.graal.python.builtins.objects.cext.capi.CExtNodesFactory.IsPointerNodeGen;
7878
import com.oracle.graal.python.builtins.objects.cext.capi.CExtNodesFactory.ObjectUpcallNodeGen;
@@ -93,7 +93,6 @@
9393
import com.oracle.graal.python.builtins.objects.complex.PComplex;
9494
import com.oracle.graal.python.builtins.objects.floats.PFloat;
9595
import com.oracle.graal.python.builtins.objects.function.PKeyword;
96-
import com.oracle.graal.python.builtins.objects.function.Signature;
9796
import com.oracle.graal.python.builtins.objects.getsetdescriptor.DescriptorDeleteMarker;
9897
import com.oracle.graal.python.builtins.objects.ints.PInt;
9998
import com.oracle.graal.python.builtins.objects.module.PythonModule;
@@ -109,7 +108,6 @@
109108
import com.oracle.graal.python.nodes.PGuards;
110109
import com.oracle.graal.python.nodes.PNodeWithContext;
111110
import com.oracle.graal.python.nodes.PRaiseNode;
112-
import com.oracle.graal.python.nodes.PRootNode;
113111
import com.oracle.graal.python.nodes.SpecialMethodNames;
114112
import com.oracle.graal.python.nodes.attributes.ReadAttributeFromObjectNode;
115113
import com.oracle.graal.python.nodes.call.CallNode;
@@ -118,6 +116,7 @@
118116
import com.oracle.graal.python.nodes.call.special.CallUnaryMethodNode;
119117
import com.oracle.graal.python.nodes.call.special.LookupAndCallUnaryNode.LookupAndCallUnaryDynamicNode;
120118
import com.oracle.graal.python.nodes.classes.IsSubtypeNode;
119+
import com.oracle.graal.python.nodes.expression.ExpressionNode;
121120
import com.oracle.graal.python.nodes.frame.GetCurrentFrameRef;
122121
import com.oracle.graal.python.nodes.object.GetClassNode;
123122
import com.oracle.graal.python.nodes.object.IsBuiltinClassProfile;
@@ -137,8 +136,6 @@
137136
import com.oracle.truffle.api.CompilerAsserts;
138137
import com.oracle.truffle.api.CompilerDirectives;
139138
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
140-
import com.oracle.truffle.api.RootCallTarget;
141-
import com.oracle.truffle.api.Truffle;
142139
import com.oracle.truffle.api.TruffleLanguage.ContextReference;
143140
import com.oracle.truffle.api.TruffleLogger;
144141
import com.oracle.truffle.api.dsl.Cached;
@@ -159,7 +156,6 @@
159156
import com.oracle.truffle.api.interop.UnsupportedMessageException;
160157
import com.oracle.truffle.api.interop.UnsupportedTypeException;
161158
import com.oracle.truffle.api.library.CachedLibrary;
162-
import com.oracle.truffle.api.nodes.DirectCallNode;
163159
import com.oracle.truffle.api.nodes.ExplodeLoop;
164160
import com.oracle.truffle.api.nodes.InvalidAssumptionException;
165161
import com.oracle.truffle.api.nodes.Node;
@@ -2473,30 +2469,47 @@ public static PCallCapiFunction getUncached() {
24732469
}
24742470
}
24752471

2476-
// -----------------------------------------------------------------------------------------------------------------
2477-
@Builtin(takesVarArgs = true)
2478-
public static class MayRaiseNode extends PRootNode {
2479-
@Child private DirectCallNode callNode;
2472+
/**
2473+
* Simple enum to abstract over common error indication values used in C extensions. We use this
2474+
* enum instead of concrete values to be able to safely share them between contexts.
2475+
*/
2476+
public enum MayRaiseErrorResult {
2477+
NATIVE_NULL,
2478+
NONE,
2479+
INT,
2480+
FLOAT
2481+
}
2482+
2483+
/**
2484+
* A fake-expression node that wraps an expression node with a {@code try-catch} and any catched
2485+
* Python exception will be transformed to native and the pre-defined error result (specified
2486+
* with enum {@link MayRaiseErrorResult}) will be returned.
2487+
*/
2488+
public static final class MayRaiseNode extends ExpressionNode {
2489+
@Child private ExpressionNode wrappedBody;
24802490
@Child private TransformExceptionToNativeNode transformExceptionToNativeNode;
24812491

2482-
private final Signature signature;
2483-
private final Object errorResult;
2492+
@Child private GetNativeNullNode getNativeNullNode;
2493+
2494+
private final MayRaiseErrorResult errorResult;
24842495

2485-
public MayRaiseNode(PythonLanguage lang, Signature sign, RootCallTarget ct, Object errorResult) {
2486-
super(lang);
2487-
this.signature = sign;
2488-
this.callNode = Truffle.getRuntime().createDirectCallNode(ct);
2496+
MayRaiseNode(ExpressionNode wrappedBody, MayRaiseErrorResult errorResult) {
2497+
this.wrappedBody = wrappedBody;
24892498
this.errorResult = errorResult;
24902499
}
24912500

2501+
public static MayRaiseNode create(ExpressionNode nodeToWrap, MayRaiseErrorResult errorResult) {
2502+
return new MayRaiseNode(nodeToWrap, errorResult);
2503+
}
2504+
24922505
@Override
2493-
public final Object execute(VirtualFrame frame) {
2506+
public Object execute(VirtualFrame frame) {
24942507
try {
2495-
return callNode.call(frame.getArguments());
2508+
return wrappedBody.execute(frame);
24962509
} catch (PException e) {
24972510
// transformExceptionToNativeNode acts as a branch profile
24982511
ensureTransformExceptionToNativeNode().execute(frame, e);
2499-
return errorResult;
2512+
return getErrorResult();
25002513
}
25012514
}
25022515

@@ -2508,14 +2521,22 @@ private TransformExceptionToNativeNode ensureTransformExceptionToNativeNode() {
25082521
return transformExceptionToNativeNode;
25092522
}
25102523

2511-
@Override
2512-
public Signature getSignature() {
2513-
return signature;
2514-
}
2515-
2516-
@Override
2517-
public boolean isPythonInternal() {
2518-
return true;
2524+
private Object getErrorResult() {
2525+
switch (errorResult) {
2526+
case INT:
2527+
return -1;
2528+
case FLOAT:
2529+
return -1.0;
2530+
case NONE:
2531+
return PNone.NONE;
2532+
case NATIVE_NULL:
2533+
if (getNativeNullNode == null) {
2534+
CompilerDirectives.transferToInterpreterAndInvalidate();
2535+
getNativeNullNode = insert(GetNativeNullNodeGen.create());
2536+
}
2537+
return getNativeNullNode.execute();
2538+
}
2539+
throw CompilerDirectives.shouldNotReachHere();
25192540
}
25202541
}
25212542

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/function/FunctionRootNode.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import com.oracle.graal.python.parser.ExecutionCellSlots;
3636
import com.oracle.graal.python.runtime.ExecutionContext.CalleeContext;
3737
import com.oracle.graal.python.runtime.PythonContext;
38+
import com.oracle.graal.python.util.Function;
3839
import com.oracle.truffle.api.CompilerAsserts;
3940
import com.oracle.truffle.api.CompilerDirectives;
4041
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
@@ -116,8 +117,8 @@ protected RootNode cloneUninitialized() {
116117
* Returns a new function that has its signature replaced and whose body has been modified by
117118
* the given node visitor.
118119
*/
119-
public FunctionRootNode rewriteWithNewSignature(Signature newSignature, NodeVisitor nodeVisitor) {
120-
ExpressionNode newUninitializedBody = NodeUtil.cloneNode(uninitializedBody);
120+
public FunctionRootNode rewriteWithNewSignature(Signature newSignature, NodeVisitor nodeVisitor, Function<ExpressionNode, ExpressionNode> bodyFun) {
121+
ExpressionNode newUninitializedBody = bodyFun.apply(NodeUtil.cloneNode(uninitializedBody));
121122
newUninitializedBody.accept(nodeVisitor);
122123
return new FunctionRootNode(PythonLanguage.getCurrent(), getSourceSection(), functionName, isGenerator, true, getFrameDescriptor(), newUninitializedBody, executionCellSlots,
123124
newSignature);

0 commit comments

Comments
 (0)