Skip to content

Commit 2632de5

Browse files
committed
[GR-46237] Use cached specializations for NbNumbers calls
PullRequest: graalpython/2793
2 parents dd39d5d + b6e766a commit 2632de5

File tree

2 files changed

+103
-7
lines changed

2 files changed

+103
-7
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/lib/GetMethodsFlagsNode.java

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,25 @@
4242

4343
import com.oracle.graal.python.builtins.PythonBuiltinClassType;
4444
import com.oracle.graal.python.builtins.objects.cext.PythonAbstractNativeObject;
45+
import com.oracle.graal.python.builtins.objects.cext.capi.CExtNodes;
4546
import com.oracle.graal.python.builtins.objects.cext.capi.CExtNodes.PCallCapiFunction;
4647
import com.oracle.graal.python.builtins.objects.cext.capi.NativeCAPISymbol;
48+
import com.oracle.graal.python.builtins.objects.cext.capi.NativeMember;
49+
import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes.HashingStorageGetItem;
50+
import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes.HashingStorageSetItem;
51+
import com.oracle.graal.python.builtins.objects.common.HashingStorageNodesFactory.HashingStorageGetItemNodeGen;
52+
import com.oracle.graal.python.builtins.objects.dict.PDict;
4753
import com.oracle.graal.python.builtins.objects.type.PythonManagedClass;
54+
import com.oracle.graal.python.runtime.PythonContext;
55+
import com.oracle.truffle.api.Assumption;
56+
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
4857
import com.oracle.truffle.api.dsl.Cached;
4958
import com.oracle.truffle.api.dsl.Fallback;
5059
import com.oracle.truffle.api.dsl.GenerateInline;
5160
import com.oracle.truffle.api.dsl.GenerateUncached;
5261
import com.oracle.truffle.api.dsl.Specialization;
5362
import com.oracle.truffle.api.nodes.Node;
63+
import com.oracle.truffle.api.object.HiddenKey;
5464

5565
/**
5666
* Retrieve slots occupation of `cls->tp_as_number`, `cls->tp_as_sequence` and `cls->tp_as_mapping`
@@ -72,13 +82,44 @@ protected static long pythonclasstype(PythonBuiltinClassType cls) {
7282
return cls.getMethodsFlags();
7383
}
7484

75-
@Specialization
76-
static long doNative(PythonAbstractNativeObject cls,
77-
@Cached PCallCapiFunction callCapiFunction) {
78-
Long flags = (Long) callCapiFunction.call(NativeCAPISymbol.FUN_GET_METHODS_FLAGS, cls.getPtr());
85+
public static final HiddenKey METHODS_FLAGS = new HiddenKey("__methods_flags__");
86+
87+
@TruffleBoundary
88+
private static long populateMethodsFlags(PythonAbstractNativeObject cls, PDict dict) {
89+
Long flags = (Long) PCallCapiFunction.getUncached().call(NativeCAPISymbol.FUN_GET_METHODS_FLAGS, cls.getPtr());
90+
HashingStorageSetItem.executeUncached(dict.getDictStorage(), METHODS_FLAGS, flags);
7991
return flags;
8092
}
8193

94+
protected static long getMethodsFlags(PythonAbstractNativeObject cls) {
95+
return doNative(cls, CExtNodes.GetTypeMemberNode.getUncached(), HashingStorageGetItemNodeGen.getUncached());
96+
}
97+
98+
// The assumption should hold unless `PyType_Modified` is called.
99+
protected static Assumption nativeAssumption(PythonAbstractNativeObject cls) {
100+
return PythonContext.get(null).getNativeClassStableAssumption(cls, true).getAssumption();
101+
}
102+
103+
@Specialization(guards = "cachedCls == cls", limit = "5", assumptions = "nativeAssumption(cachedCls)")
104+
static long doNativeCached(@SuppressWarnings("unused") PythonAbstractNativeObject cls,
105+
@SuppressWarnings("unused") @Cached("cls") PythonAbstractNativeObject cachedCls,
106+
@Cached("getMethodsFlags(cls)") long flags) {
107+
return flags;
108+
}
109+
110+
@Specialization(replaces = "doNativeCached")
111+
static long doNative(PythonAbstractNativeObject cls,
112+
@Cached CExtNodes.GetTypeMemberNode getTpDictNode,
113+
@Cached HashingStorageGetItem getItem) {
114+
// classes must have tp_dict since they are set during PyType_Ready
115+
PDict dict = (PDict) getTpDictNode.execute(cls, NativeMember.TP_DICT);
116+
Object f = getItem.execute(null, dict.getDictStorage(), METHODS_FLAGS);
117+
if (f == null) {
118+
return populateMethodsFlags(cls, dict);
119+
}
120+
return (Long) f;
121+
}
122+
82123
@Fallback
83124
protected static long zero(@SuppressWarnings("unused") Object cls) {
84125
return 0;

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/call/special/LookupAndCallNbNumbersBinaryNode.java

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,27 @@ protected abstract static class PyNumberAddNode extends LookupAndCallNbNumbersBi
9090
super(handlerFactory);
9191
}
9292

93-
@Specialization
93+
@Specialization(guards = {"left.getClass() == cachedLeftClass", "right.getClass() == cachedRightClass"}, limit = "5")
94+
Object addC(VirtualFrame frame, Object left, Object right,
95+
@Bind("this") Node node,
96+
@SuppressWarnings("unused") @Cached("left.getClass()") Class<?> cachedLeftClass,
97+
@SuppressWarnings("unused") @Cached("right.getClass()") Class<?> cachedRightClass,
98+
@Cached InlinedGetClassNode getClassNode,
99+
@Cached("create(Add)") LookupSpecialMethodSlotNode getattr,
100+
@Cached GetMethodsFlagsNode getMethodsFlagsNode,
101+
@Cached InlinedGetClassNode getlClassNode,
102+
@Cached InlinedGetClassNode getrClassNode,
103+
@Cached InlinedConditionProfile p1,
104+
@Cached InlinedConditionProfile p2,
105+
@Cached InlinedConditionProfile p3,
106+
@Cached BinaryOp1Node binaryOp1Node,
107+
@Cached Slot1BINFULLNode slot1BINFULLNode) {
108+
return add(frame, left, right, node, getClassNode, getattr, getMethodsFlagsNode,
109+
getlClassNode, getrClassNode, p1, p2, p3,
110+
binaryOp1Node, slot1BINFULLNode);
111+
}
112+
113+
@Specialization(replaces = "addC")
94114
Object add(VirtualFrame frame, Object left, Object right,
95115
@Bind("this") Node node,
96116
@Cached InlinedGetClassNode getClassNode,
@@ -145,7 +165,26 @@ protected abstract static class PyNumberMultiplyNode extends LookupAndCallNbNumb
145165
super(handlerFactory);
146166
}
147167

148-
@Specialization
168+
@Specialization(guards = {"left.getClass() == cachedLeftClass", "right.getClass() == cachedRightClass"}, limit = "5")
169+
Object mulC(VirtualFrame frame, Object left, Object right,
170+
@Bind("this") Node node,
171+
@SuppressWarnings("unused") @Cached("left.getClass()") Class<?> cachedLeftClass,
172+
@SuppressWarnings("unused") @Cached("right.getClass()") Class<?> cachedRightClass,
173+
@Cached InlinedGetClassNode getClassNode,
174+
@Cached("create(Mul)") LookupSpecialMethodSlotNode getattr,
175+
@Cached GetMethodsFlagsNode getMethodsFlagsNode,
176+
@Cached InlinedGetClassNode getlClassNode,
177+
@Cached InlinedGetClassNode getrClassNode,
178+
@Cached InlinedConditionProfile p1,
179+
@Cached InlinedConditionProfile p2,
180+
@Cached InlinedConditionProfile p3,
181+
@Cached BinaryOp1Node binaryOp1Node,
182+
@Cached Slot1BINFULLNode slot1BINFULLNode) {
183+
return mul(frame, left, right, node, getClassNode, getattr, getMethodsFlagsNode, getlClassNode,
184+
getrClassNode, p1, p2, p3, binaryOp1Node, slot1BINFULLNode);
185+
}
186+
187+
@Specialization(replaces = "mulC")
149188
Object mul(VirtualFrame frame, Object left, Object right,
150189
@Bind("this") Node node,
151190
@Cached InlinedGetClassNode getClassNode,
@@ -209,7 +248,23 @@ protected abstract static class BinaryOpNode extends LookupAndCallNbNumbersBinar
209248
this.rslot = rslot;
210249
}
211250

212-
@Specialization
251+
@Specialization(guards = {"left.getClass() == cachedLeftClass", "right.getClass() == cachedRightClass"}, limit = "5")
252+
Object binaryOpC(VirtualFrame frame, Object left, Object right,
253+
@Bind("this") Node node,
254+
@SuppressWarnings("unused") @Cached("left.getClass()") Class<?> cachedLeftClass,
255+
@SuppressWarnings("unused") @Cached("right.getClass()") Class<?> cachedRightClass,
256+
@Cached GetMethodsFlagsNode getMethodsFlagsNode,
257+
@Cached InlinedGetClassNode getlClassNode,
258+
@Cached InlinedGetClassNode getrClassNode,
259+
@Cached InlinedConditionProfile p1,
260+
@Cached InlinedConditionProfile p2,
261+
@Cached BinaryOp1Node binaryOp1Node,
262+
@Cached Slot1BINFULLNode slot1BINFULLNode) {
263+
return binaryOp(frame, left, right, node, getMethodsFlagsNode, getlClassNode, getrClassNode, p1, p2,
264+
binaryOp1Node, slot1BINFULLNode);
265+
}
266+
267+
@Specialization(replaces = "binaryOpC")
213268
Object binaryOp(VirtualFrame frame, Object left, Object right,
214269
@Bind("this") Node node,
215270
@Cached GetMethodsFlagsNode getMethodsFlagsNode,

0 commit comments

Comments
 (0)