|
121 | 121 | import com.oracle.graal.python.builtins.objects.cext.capi.CExtNodes.FastCallWithKeywordsArgsToSulongNode;
|
122 | 122 | import com.oracle.graal.python.builtins.objects.cext.capi.CExtNodes.FromCharPointerNode;
|
123 | 123 | import com.oracle.graal.python.builtins.objects.cext.capi.CExtNodes.GetNativeNullNode;
|
| 124 | +import com.oracle.graal.python.builtins.objects.cext.capi.CExtNodes.MayRaiseErrorResult; |
124 | 125 | import com.oracle.graal.python.builtins.objects.cext.capi.CExtNodes.MayRaiseNode;
|
125 | 126 | import com.oracle.graal.python.builtins.objects.cext.capi.CExtNodes.ObjectUpcallNode;
|
126 | 127 | import com.oracle.graal.python.builtins.objects.cext.capi.CExtNodes.PCallCapiFunction;
|
|
220 | 221 | import com.oracle.graal.python.nodes.classes.IsSubtypeNodeGen;
|
221 | 222 | import com.oracle.graal.python.nodes.expression.BinaryComparisonNode;
|
222 | 223 | import com.oracle.graal.python.nodes.frame.GetCurrentFrameRef;
|
| 224 | +import com.oracle.graal.python.nodes.function.FunctionRootNode; |
223 | 225 | import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
|
224 | 226 | import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
|
225 | 227 | import com.oracle.graal.python.nodes.function.builtins.PythonBinaryBuiltinNode;
|
|
278 | 280 | import com.oracle.truffle.api.library.CachedLibrary;
|
279 | 281 | import com.oracle.truffle.api.nodes.ControlFlowException;
|
280 | 282 | import com.oracle.truffle.api.nodes.Node;
|
281 |
| -import com.oracle.truffle.api.nodes.RootNode; |
282 | 283 | import com.oracle.truffle.api.nodes.UnexpectedResultException;
|
283 | 284 | import com.oracle.truffle.api.object.HiddenKey;
|
284 | 285 | import com.oracle.truffle.api.profiles.BranchProfile;
|
@@ -2409,53 +2410,59 @@ static Object doDirect(VirtualFrame frame, @SuppressWarnings("unused") PythonMod
|
2409 | 2410 | }
|
2410 | 2411 | }
|
2411 | 2412 |
|
2412 |
| - /* |
2413 |
| - * We are creating a special PFunction as a wrapper here - that PFunction has a reference to the |
2414 |
| - * wrapped function's CallTarget. Since the wrapped function is a PFunction anyway, we'll have |
2415 |
| - * to do the full call logic at some point. But instead of doing it when dispatching to the |
2416 |
| - * wrapped function, we copy all relevant bits (signature, mostly) and thus the caller of the |
2417 |
| - * wrapper will already do all that work. The root node embedded in the wrapper call target (a |
2418 |
| - * MayRaiseNode) then just does a direct call with the frame arguments, without doing anything |
2419 |
| - * else anymore. Thus, while there is an extra call, there are really only those Java frames in |
2420 |
| - * between that are caused by the Truffle machinery for calls. |
| 2413 | + /** |
| 2414 | + * Inserts a {@link MayRaiseNode} that wraps the body of the function. This will return a new |
| 2415 | + * function object with a rewritten AST. However, we use a cache for the call targets and thus |
| 2416 | + * the rewritten-ASTs will also be shared if appropriate. |
2421 | 2417 | */
|
2422 | 2418 | @Builtin(name = "make_may_raise_wrapper", minNumOfPositionalArgs = 1, maxNumOfPositionalArgs = 2)
|
2423 | 2419 | @GenerateNodeFactory
|
2424 | 2420 | abstract static class MakeMayRaiseWrapperNode extends PythonBuiltinNode {
|
2425 | 2421 | private static final WeakHashMap<RootCallTarget, WeakReference<RootCallTarget>> weakCallTargetMap = new WeakHashMap<>();
|
2426 | 2422 |
|
2427 |
| - private static final RootCallTarget createWrapperCt(PFunction func, Object errorResult) { |
2428 |
| - CompilerDirectives.transferToInterpreter(); |
2429 |
| - assert errorResult instanceof Integer || errorResult instanceof Long || errorResult instanceof Double || errorResult == PNone.NONE || |
2430 |
| - InteropLibrary.getUncached().isNull(errorResult) : "invalid wrap"; |
2431 |
| - PythonLanguage lang = PythonLanguage.getCurrent(); |
2432 |
| - RootNode rootNode = new MayRaiseNode(lang, func.getSignature(), func.getCallTarget(), errorResult); |
2433 |
| - return PythonUtils.getOrCreateCallTarget(rootNode); |
2434 |
| - } |
2435 |
| - |
2436 | 2423 | @Specialization
|
2437 | 2424 | @TruffleBoundary
|
2438 |
| - Object make(PFunction func, Object errorResult) { |
2439 |
| - RootCallTarget wrappedCt = func.getCallTarget(); |
2440 |
| - WeakReference<RootCallTarget> wrapperCtRef = weakCallTargetMap.get(wrappedCt); |
2441 |
| - RootCallTarget wrapperCt = null; |
| 2425 | + Object make(PFunction func, Object errorResultObj) { |
| 2426 | + RootCallTarget originalCallTarget = func.getCallTarget(); |
| 2427 | + |
| 2428 | + WeakReference<RootCallTarget> wrapperCtRef = weakCallTargetMap.get(originalCallTarget); |
| 2429 | + RootCallTarget wrapperCallTarget = null; |
2442 | 2430 | if (wrapperCtRef != null) {
|
2443 |
| - wrapperCt = wrapperCtRef.get(); |
2444 |
| - } |
2445 |
| - if (wrapperCt == null) { |
2446 |
| - wrapperCt = createWrapperCt(func, errorResult); |
2447 |
| - weakCallTargetMap.put(wrappedCt, new WeakReference<>(wrapperCt)); |
2448 |
| - } |
2449 |
| - PCode wrappedCode = func.getCode(); |
2450 |
| - PCode wrapperCode = factory().createCode(PythonBuiltinClassType.PCode, wrapperCt, func.getSignature(), |
2451 |
| - 0, 0, 0, |
2452 |
| - new byte[0], new Object[0], new Object[0], |
2453 |
| - new Object[0], new Object[0], new Object[0], |
2454 |
| - wrappedCode.getName(), wrappedCode.getName(), 0, |
2455 |
| - new byte[0]); |
2456 |
| - return factory().createFunction(func.getName(), func.getQualname(), func.getEnclosingClassName(), |
2457 |
| - wrapperCode, func.getGlobals(), func.getDefaults(), func.getKwDefaults(), |
2458 |
| - func.getClosure(), func.getCodeStableAssumption(), func.getDefaultsStableAssumption()); |
| 2431 | + wrapperCallTarget = wrapperCtRef.get(); |
| 2432 | + } |
| 2433 | + if (wrapperCallTarget == null) { |
| 2434 | + final MayRaiseErrorResult errorResult = convertToEnum(errorResultObj); |
| 2435 | + FunctionRootNode functionRootNode = (FunctionRootNode) func.getFunctionRootNode(); |
| 2436 | + |
| 2437 | + // Replace the first expression node with the MayRaiseNode |
| 2438 | + functionRootNode = functionRootNode.rewriteWithNewSignature(func.getSignature(), node -> false, body -> MayRaiseNode.create(body, errorResult)); |
| 2439 | + wrapperCallTarget = PythonUtils.getOrCreateCallTarget(functionRootNode); |
| 2440 | + weakCallTargetMap.put(originalCallTarget, new WeakReference<>(wrapperCallTarget)); |
| 2441 | + } |
| 2442 | + |
| 2443 | + // Although we could theoretically re-use the old function instance, we create a new one |
| 2444 | + // to be on the safe side. |
| 2445 | + return factory().createFunction(func.getName(), func.getQualname(), func.getEnclosingClassName(), factory().createCode(wrapperCallTarget), func.getGlobals(), func.getDefaults(), |
| 2446 | + func.getKwDefaults(), func.getClosure(), func.getCodeStableAssumption(), func.getCodeStableAssumption()); |
| 2447 | + } |
| 2448 | + |
| 2449 | + private MayRaiseErrorResult convertToEnum(Object object) { |
| 2450 | + if (PGuards.isNone(object)) { |
| 2451 | + return MayRaiseErrorResult.NONE; |
| 2452 | + } else if (object instanceof Integer) { |
| 2453 | + int i = (int) object; |
| 2454 | + if (i == -1) { |
| 2455 | + return MayRaiseErrorResult.INT; |
| 2456 | + } |
| 2457 | + } else if (object instanceof Double) { |
| 2458 | + double i = (double) object; |
| 2459 | + if (i == -1.0) { |
| 2460 | + return MayRaiseErrorResult.FLOAT; |
| 2461 | + } |
| 2462 | + } else if (object instanceof PythonNativeNull || PGuards.isNoValue(object)) { |
| 2463 | + return MayRaiseErrorResult.NATIVE_NULL; |
| 2464 | + } |
| 2465 | + throw raise(PythonErrorType.TypeError, "invalid error result value"); |
2459 | 2466 | }
|
2460 | 2467 | }
|
2461 | 2468 |
|
|
0 commit comments