Skip to content

Commit ed4b7ee

Browse files
committed
share call targets for MayRaiseNode wrapper functions and do not store the original function and error result in the AST
1 parent 1956b82 commit ed4b7ee

File tree

2 files changed

+64
-174
lines changed

2 files changed

+64
-174
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/PythonCextBuiltins.java

Lines changed: 42 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
import static com.oracle.graal.python.runtime.exception.PythonErrorType.OverflowError;
4848

4949
import java.io.PrintWriter;
50+
import java.lang.ref.WeakReference;
5051
import java.math.BigInteger;
5152
import java.nio.ByteBuffer;
5253
import java.nio.CharBuffer;
@@ -61,6 +62,7 @@
6162
import java.util.ArrayList;
6263
import java.util.Arrays;
6364
import java.util.List;
65+
import java.util.WeakHashMap;
6466
import java.util.logging.Level;
6567

6668
import com.oracle.graal.python.PythonLanguage;
@@ -113,10 +115,7 @@
113115
import com.oracle.graal.python.builtins.objects.cext.CExtNodes.FastCallArgsToSulongNode;
114116
import com.oracle.graal.python.builtins.objects.cext.CExtNodes.FastCallWithKeywordsArgsToSulongNode;
115117
import com.oracle.graal.python.builtins.objects.cext.CExtNodes.GetNativeNullNode;
116-
import com.oracle.graal.python.builtins.objects.cext.CExtNodes.MayRaiseBinaryNode;
117118
import com.oracle.graal.python.builtins.objects.cext.CExtNodes.MayRaiseNode;
118-
import com.oracle.graal.python.builtins.objects.cext.CExtNodes.MayRaiseTernaryNode;
119-
import com.oracle.graal.python.builtins.objects.cext.CExtNodes.MayRaiseUnaryNode;
120119
import com.oracle.graal.python.builtins.objects.cext.CExtNodes.ObjectUpcallNode;
121120
import com.oracle.graal.python.builtins.objects.cext.CExtNodes.PCallCapiFunction;
122121
import com.oracle.graal.python.builtins.objects.cext.CExtNodes.PRaiseNativeNode;
@@ -128,9 +127,6 @@
128127
import com.oracle.graal.python.builtins.objects.cext.CExtNodes.TransformExceptionToNativeNode;
129128
import com.oracle.graal.python.builtins.objects.cext.CExtNodes.VoidPtrToJavaNode;
130129
import com.oracle.graal.python.builtins.objects.cext.CExtNodesFactory.CastToNativeLongNodeGen;
131-
import com.oracle.graal.python.builtins.objects.cext.CExtNodesFactory.MayRaiseBinaryNodeGen;
132-
import com.oracle.graal.python.builtins.objects.cext.CExtNodesFactory.MayRaiseTernaryNodeGen;
133-
import com.oracle.graal.python.builtins.objects.cext.CExtNodesFactory.MayRaiseUnaryNodeGen;
134130
import com.oracle.graal.python.builtins.objects.cext.CExtNodesFactory.PRaiseNativeNodeGen;
135131
import com.oracle.graal.python.builtins.objects.cext.CExtNodesFactory.ToJavaNodeGen;
136132
import com.oracle.graal.python.builtins.objects.cext.CExtNodesFactory.TransformExceptionToNativeNodeGen;
@@ -2318,49 +2314,52 @@ static Object doDirect(VirtualFrame frame, @SuppressWarnings("unused") PythonMod
23182314
}
23192315
}
23202316

2317+
/*
2318+
* We are creating a special PFunction as a wrapper here - that PFunction has a reference to the
2319+
* wrapped function's CallTarget. Since the wrapped function is a PFunction anyway, we'll have
2320+
* to do the full call logic at some point. But instead of doing it when dispatching to the
2321+
* wrapped function, we copy all relevant bits (signature, mostly) and thus the caller of the
2322+
* wrapper will already do all that work. The root node embedded in the wrapper call target (a
2323+
* MayRaiseNode) then just does a direct call with the frame arguments, without doing anything
2324+
* else anymore. Thus, while there is an extra call, there are really only those Java frames in
2325+
* between that are caused by the Truffle machinery for calls.
2326+
*/
23212327
@Builtin(name = "make_may_raise_wrapper", minNumOfPositionalArgs = 1, maxNumOfPositionalArgs = 2)
23222328
@GenerateNodeFactory
23232329
abstract static class MakeMayRaiseWrapperNode extends PythonBuiltinNode {
2324-
private static final Builtin unaryBuiltin = MayRaiseUnaryNode.class.getAnnotation(Builtin.class);
2325-
private static final Builtin binaryBuiltin = MayRaiseBinaryNode.class.getAnnotation(Builtin.class);
2326-
private static final Builtin ternaryBuiltin = MayRaiseTernaryNode.class.getAnnotation(Builtin.class);
2327-
private static final Builtin varargsBuiltin = MayRaiseNode.class.getAnnotation(Builtin.class);
2330+
private static final WeakHashMap<RootCallTarget, WeakReference<RootCallTarget>> weakCallTargetMap = new WeakHashMap<>();
23282331

2329-
@Specialization
2330-
@TruffleBoundary
2331-
Object make(PFunction func, Object errorResult,
2332-
@Exclusive @CachedLanguage PythonLanguage lang) {
2332+
private static final RootCallTarget createWrapperCt(PFunction func, Object errorResult) {
23332333
CompilerDirectives.transferToInterpreter();
2334-
RootNode rootNode = null;
2335-
Signature funcSignature = func.getSignature();
2336-
if (funcSignature.takesPositionalOnly()) {
2337-
switch (funcSignature.getMaxNumOfPositionalArgs()) {
2338-
case 1:
2339-
rootNode = new BuiltinFunctionRootNode(lang, unaryBuiltin,
2340-
new StandaloneBuiltinFactory<PythonUnaryBuiltinNode>(MayRaiseUnaryNodeGen.create(func, errorResult)),
2341-
true);
2342-
break;
2343-
case 2:
2344-
rootNode = new BuiltinFunctionRootNode(lang, binaryBuiltin,
2345-
new StandaloneBuiltinFactory<PythonBinaryBuiltinNode>(MayRaiseBinaryNodeGen.create(func, errorResult)),
2346-
true);
2347-
break;
2348-
case 3:
2349-
rootNode = new BuiltinFunctionRootNode(lang, ternaryBuiltin,
2350-
new StandaloneBuiltinFactory<PythonTernaryBuiltinNode>(MayRaiseTernaryNodeGen.create(func, errorResult)),
2351-
true);
2352-
break;
2353-
default:
2354-
break;
2355-
}
2356-
}
2357-
if (rootNode == null) {
2358-
rootNode = new BuiltinFunctionRootNode(lang, varargsBuiltin,
2359-
new StandaloneBuiltinFactory<PythonBuiltinNode>(new MayRaiseNode(func, errorResult)),
2360-
true);
2361-
}
2334+
assert errorResult instanceof Integer || errorResult instanceof Long || errorResult instanceof Double || errorResult == PNone.NONE || InteropLibrary.getUncached().isNull(errorResult) : "invalid wrap";
2335+
PythonLanguage lang = PythonLanguage.getCurrent();
2336+
RootNode rootNode = new MayRaiseNode(lang, func.getSignature(), func.getCallTarget(), errorResult);
2337+
return PythonUtils.getOrCreateCallTarget(rootNode);
2338+
}
23622339

2363-
return factory().createBuiltinFunction(func.getName(), null, 0, PythonUtils.getOrCreateCallTarget(rootNode));
2340+
@Specialization
2341+
@TruffleBoundary
2342+
Object make(PFunction func, Object errorResult) {
2343+
RootCallTarget wrappedCt = func.getCallTarget();
2344+
WeakReference<RootCallTarget> wrapperCtRef = weakCallTargetMap.get(wrappedCt);
2345+
RootCallTarget wrapperCt = null;
2346+
if (wrapperCtRef != null) {
2347+
wrapperCt = wrapperCtRef.get();
2348+
}
2349+
if (wrapperCt == null) {
2350+
wrapperCt = createWrapperCt(func, errorResult);
2351+
weakCallTargetMap.put(wrappedCt, new WeakReference<>(wrapperCt));
2352+
}
2353+
PCode wrappedCode = func.getCode();
2354+
PCode wrapperCode = factory().createCode(PythonBuiltinClassType.PCode, wrapperCt, func.getSignature(),
2355+
0, 0, 0,
2356+
new byte[0], new Object[0], new Object[0],
2357+
new Object[0], new Object[0], new Object[0],
2358+
wrappedCode.getFilename(), wrappedCode.getName(), 0,
2359+
new byte[0]);
2360+
return factory().createFunction(func.getName(), func.getQualname(), func.getEnclosingClassName(),
2361+
wrapperCode, func.getGlobals(), func.getDefaults(), func.getKwDefaults(),
2362+
func.getClosure(), func.getCodeStableAssumption(), func.getDefaultsStableAssumption());
23642363
}
23652364
}
23662365

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/cext/CExtNodes.java

Lines changed: 22 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@
8888
import com.oracle.graal.python.builtins.objects.floats.PFloat;
8989
import com.oracle.graal.python.builtins.objects.function.PFunction;
9090
import com.oracle.graal.python.builtins.objects.function.PKeyword;
91+
import com.oracle.graal.python.builtins.objects.function.Signature;
9192
import com.oracle.graal.python.builtins.objects.getsetdescriptor.DescriptorDeleteMarker;
9293
import com.oracle.graal.python.builtins.objects.ints.PInt;
9394
import com.oracle.graal.python.builtins.objects.module.PythonModule;
@@ -103,21 +104,21 @@
103104
import com.oracle.graal.python.nodes.PGuards;
104105
import com.oracle.graal.python.nodes.PNodeWithContext;
105106
import com.oracle.graal.python.nodes.PRaiseNode;
107+
import com.oracle.graal.python.nodes.PRootNode;
106108
import com.oracle.graal.python.nodes.SpecialMethodNames;
107109
import com.oracle.graal.python.nodes.argument.CreateArgumentsNode;
108110
import com.oracle.graal.python.nodes.argument.ReadArgumentNode;
109111
import com.oracle.graal.python.nodes.argument.ReadVarArgsNode;
110112
import com.oracle.graal.python.nodes.attributes.ReadAttributeFromObjectNode;
111113
import com.oracle.graal.python.nodes.call.CallNode;
114+
import com.oracle.graal.python.nodes.call.CallTargetInvokeNode;
112115
import com.oracle.graal.python.nodes.call.FunctionInvokeNode;
113116
import com.oracle.graal.python.nodes.call.special.CallBinaryMethodNode;
114117
import com.oracle.graal.python.nodes.call.special.CallTernaryMethodNode;
115118
import com.oracle.graal.python.nodes.call.special.CallUnaryMethodNode;
116119
import com.oracle.graal.python.nodes.call.special.LookupAndCallUnaryNode.LookupAndCallUnaryDynamicNode;
117120
import com.oracle.graal.python.nodes.classes.IsSubtypeNode;
118121
import com.oracle.graal.python.nodes.frame.GetCurrentFrameRef;
119-
import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
120-
import com.oracle.graal.python.nodes.function.builtins.PythonBinaryBuiltinNode;
121122
import com.oracle.graal.python.nodes.function.builtins.PythonTernaryBuiltinNode;
122123
import com.oracle.graal.python.nodes.function.builtins.PythonUnaryBuiltinNode;
123124
import com.oracle.graal.python.nodes.object.GetClassNode;
@@ -137,6 +138,9 @@
137138
import com.oracle.truffle.api.Assumption;
138139
import com.oracle.truffle.api.CompilerAsserts;
139140
import com.oracle.truffle.api.CompilerDirectives;
141+
import com.oracle.truffle.api.RootCallTarget;
142+
import com.oracle.truffle.api.Truffle;
143+
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
140144
import com.oracle.truffle.api.TruffleLanguage.ContextReference;
141145
import com.oracle.truffle.api.TruffleLogger;
142146
import com.oracle.truffle.api.dsl.Cached;
@@ -157,6 +161,7 @@
157161
import com.oracle.truffle.api.interop.UnsupportedMessageException;
158162
import com.oracle.truffle.api.interop.UnsupportedTypeException;
159163
import com.oracle.truffle.api.library.CachedLibrary;
164+
import com.oracle.truffle.api.nodes.DirectCallNode;
160165
import com.oracle.truffle.api.nodes.ExplodeLoop;
161166
import com.oracle.truffle.api.nodes.InvalidAssumptionException;
162167
import com.oracle.truffle.api.nodes.Node;
@@ -2527,145 +2532,26 @@ public static PCallCapiFunction getUncached() {
25272532
}
25282533
}
25292534

2530-
// -----------------------------------------------------------------------------------------------------------------
2531-
@Builtin(minNumOfPositionalArgs = 1)
2532-
public abstract static class MayRaiseUnaryNode extends PythonUnaryBuiltinNode {
2533-
@Child private CreateArgumentsNode createArgsNode;
2534-
@Child private FunctionInvokeNode invokeNode;
2535-
@Child private TransformExceptionToNativeNode transformExceptionToNativeNode;
2536-
2537-
private final PFunction func;
2538-
private final Object errorResult;
2539-
2540-
public MayRaiseUnaryNode(PFunction func, Object errorResult) {
2541-
this.createArgsNode = CreateArgumentsNode.create();
2542-
this.func = func;
2543-
this.invokeNode = FunctionInvokeNode.create(func);
2544-
this.errorResult = errorResult;
2545-
}
2546-
2547-
@Specialization
2548-
Object doit(VirtualFrame frame, Object argument) {
2549-
try {
2550-
Object[] arguments = createArgsNode.execute(func, new Object[]{argument});
2551-
return invokeNode.execute(frame, arguments);
2552-
} catch (PException e) {
2553-
// transformExceptionToNativeNode acts as a branch profile
2554-
ensureTransformExceptionToNativeNode().execute(frame, e);
2555-
return errorResult;
2556-
}
2557-
}
2558-
2559-
private TransformExceptionToNativeNode ensureTransformExceptionToNativeNode() {
2560-
if (transformExceptionToNativeNode == null) {
2561-
CompilerDirectives.transferToInterpreterAndInvalidate();
2562-
transformExceptionToNativeNode = insert(TransformExceptionToNativeNodeGen.create());
2563-
}
2564-
return transformExceptionToNativeNode;
2565-
}
2566-
}
2567-
2568-
// -----------------------------------------------------------------------------------------------------------------
2569-
@Builtin(minNumOfPositionalArgs = 2)
2570-
public abstract static class MayRaiseBinaryNode extends PythonBinaryBuiltinNode {
2571-
@Child private CreateArgumentsNode createArgsNode;
2572-
@Child private FunctionInvokeNode invokeNode;
2573-
@Child private TransformExceptionToNativeNode transformExceptionToNativeNode;
2574-
2575-
private final PFunction func;
2576-
private final Object errorResult;
2577-
2578-
public MayRaiseBinaryNode(PFunction func, Object errorResult) {
2579-
this.createArgsNode = CreateArgumentsNode.create();
2580-
this.func = func;
2581-
this.invokeNode = FunctionInvokeNode.create(func);
2582-
this.errorResult = errorResult;
2583-
}
2584-
2585-
@Specialization
2586-
Object doit(VirtualFrame frame, Object arg1, Object arg2) {
2587-
try {
2588-
Object[] arguments = createArgsNode.execute(func, new Object[]{arg1, arg2});
2589-
return invokeNode.execute(frame, arguments);
2590-
} catch (PException e) {
2591-
// transformExceptionToNativeNode acts as a branch profile
2592-
ensureTransformExceptionToNativeNode().execute(frame, e);
2593-
return errorResult;
2594-
}
2595-
}
2596-
2597-
private TransformExceptionToNativeNode ensureTransformExceptionToNativeNode() {
2598-
if (transformExceptionToNativeNode == null) {
2599-
CompilerDirectives.transferToInterpreterAndInvalidate();
2600-
transformExceptionToNativeNode = insert(TransformExceptionToNativeNodeGen.create());
2601-
}
2602-
return transformExceptionToNativeNode;
2603-
}
2604-
}
2605-
2606-
// -----------------------------------------------------------------------------------------------------------------
2607-
@Builtin(minNumOfPositionalArgs = 3)
2608-
public abstract static class MayRaiseTernaryNode extends PythonTernaryBuiltinNode {
2609-
@Child private CreateArgumentsNode createArgsNode;
2610-
@Child private FunctionInvokeNode invokeNode;
2611-
@Child private TransformExceptionToNativeNode transformExceptionToNativeNode;
2612-
2613-
private final PFunction func;
2614-
private final Object errorResult;
2615-
2616-
public MayRaiseTernaryNode(PFunction func, Object errorResult) {
2617-
this.createArgsNode = CreateArgumentsNode.create();
2618-
this.func = func;
2619-
this.invokeNode = FunctionInvokeNode.create(func);
2620-
this.errorResult = errorResult;
2621-
}
2622-
2623-
@Specialization
2624-
Object doit(VirtualFrame frame, Object arg1, Object arg2, Object arg3) {
2625-
try {
2626-
Object[] arguments = createArgsNode.execute(func, new Object[]{arg1, arg2, arg3});
2627-
return invokeNode.execute(frame, arguments);
2628-
} catch (PException e) {
2629-
// transformExceptionToNativeNode acts as a branch profile
2630-
ensureTransformExceptionToNativeNode().execute(frame, e);
2631-
return errorResult;
2632-
}
2633-
}
2634-
2635-
private TransformExceptionToNativeNode ensureTransformExceptionToNativeNode() {
2636-
if (transformExceptionToNativeNode == null) {
2637-
CompilerDirectives.transferToInterpreterAndInvalidate();
2638-
transformExceptionToNativeNode = insert(TransformExceptionToNativeNodeGen.create());
2639-
}
2640-
return transformExceptionToNativeNode;
2641-
}
2642-
}
2643-
26442535
// -----------------------------------------------------------------------------------------------------------------
26452536
@Builtin(takesVarArgs = true)
2646-
public static class MayRaiseNode extends PythonBuiltinNode {
2647-
@Child private FunctionInvokeNode invokeNode;
2648-
@Child private ReadVarArgsNode readVarargsNode;
2649-
@Child private CreateArgumentsNode createArgsNode;
2537+
public static class MayRaiseNode extends PRootNode {
2538+
@Child private DirectCallNode callNode;
26502539
@Child private TransformExceptionToNativeNode transformExceptionToNativeNode;
26512540

2652-
private final PFunction func;
2541+
private final Signature signature;
26532542
private final Object errorResult;
26542543

2655-
public MayRaiseNode(PFunction callable, Object errorResult) {
2656-
this.readVarargsNode = ReadVarArgsNode.create(0, true);
2657-
this.createArgsNode = CreateArgumentsNode.create();
2658-
this.func = callable;
2659-
this.invokeNode = FunctionInvokeNode.create(callable);
2544+
public MayRaiseNode(PythonLanguage lang, Signature sign, RootCallTarget ct, Object errorResult) {
2545+
super(lang);
2546+
this.signature = sign;
2547+
this.callNode = Truffle.getRuntime().createDirectCallNode(ct);
26602548
this.errorResult = errorResult;
26612549
}
26622550

26632551
@Override
26642552
public final Object execute(VirtualFrame frame) {
2665-
Object[] args = readVarargsNode.executeObjectArray(frame);
26662553
try {
2667-
Object[] arguments = createArgsNode.execute(func, args);
2668-
return invokeNode.execute(frame, arguments);
2554+
return callNode.call(frame.getArguments());
26692555
} catch (PException e) {
26702556
// transformExceptionToNativeNode acts as a branch profile
26712557
ensureTransformExceptionToNativeNode().execute(frame, e);
@@ -2682,8 +2568,13 @@ private TransformExceptionToNativeNode ensureTransformExceptionToNativeNode() {
26822568
}
26832569

26842570
@Override
2685-
protected ReadArgumentNode[] getArguments() {
2686-
throw new IllegalAccessError();
2571+
public Signature getSignature() {
2572+
return signature;
2573+
}
2574+
2575+
@Override
2576+
public boolean isPythonInternal() {
2577+
return true;
26872578
}
26882579
}
26892580

0 commit comments

Comments
 (0)