Skip to content

Commit 338b352

Browse files
committed
simplify safety check for tp_new wrapper
1 parent 039b10a commit 338b352

File tree

2 files changed

+47
-98
lines changed

2 files changed

+47
-98
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/attributes/LookupAttributeInMRONode.java

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import com.oracle.graal.python.builtins.PythonBuiltinClassType;
4545
import com.oracle.graal.python.builtins.objects.PNone;
4646
import com.oracle.graal.python.builtins.objects.type.PythonAbstractClass;
47+
import com.oracle.graal.python.builtins.objects.type.PythonClass;
4748
import com.oracle.graal.python.builtins.objects.type.TypeNodes;
4849
import com.oracle.graal.python.builtins.objects.type.TypeNodes.GetMroStorageNode;
4950
import com.oracle.graal.python.builtins.objects.type.TypeNodesFactory.IsSameTypeNodeGen;
@@ -106,7 +107,7 @@ protected Object lookup(PythonBuiltinClassType klass, Object key) {
106107
protected Object lookup(Object klass, Object key,
107108
@Cached("create()") GetMroStorageNode getMroNode,
108109
@Cached("createForceType()") ReadAttributeFromObjectNode readAttrNode) {
109-
return lookupSlow(klass, key, getMroNode, readAttrNode);
110+
return lookupSlow(klass, key, getMroNode, readAttrNode, false);
110111
}
111112
}
112113

@@ -119,14 +120,15 @@ public Object execute(Object klass, Object key) {
119120
if (klass instanceof PythonBuiltinClassType) {
120121
return findAttr(PythonLanguage.getCore(), (PythonBuiltinClassType) klass, key);
121122
} else if (klass instanceof PythonAbstractClass) {
122-
return lookupSlow(klass, key, getMroNode, readAttrNode);
123+
return lookupSlow(klass, key, getMroNode, readAttrNode, false);
123124
} else {
124125
CompilerDirectives.transferToInterpreter();
125126
throw new RuntimeException("not implemented: lookup inherited attribute from non-PythonClass");
126127
}
127128
}
128129
}
129130

131+
private final boolean skipPythonClasses;
130132
protected final String key;
131133
@CompilationFinal private ContextReference<PythonContext> contextRef;
132134
@Child private TypeNodes.IsSameTypeNode isSameTypeNode = IsSameTypeNodeGen.create();
@@ -140,12 +142,21 @@ protected PythonCore getCore() {
140142
return contextRef.get().getCore();
141143
}
142144

143-
public LookupAttributeInMRONode(String key) {
145+
public LookupAttributeInMRONode(String key, boolean skipPythonClasses) {
144146
this.key = key;
147+
this.skipPythonClasses = skipPythonClasses;
145148
}
146149

147150
public static LookupAttributeInMRONode create(String key) {
148-
return LookupAttributeInMRONodeGen.create(key);
151+
return LookupAttributeInMRONodeGen.create(key, false);
152+
}
153+
154+
/**
155+
* Specific case to facilitate lookup on native and built-in classes only. This is useful for
156+
* certain slot wrappers.
157+
*/
158+
public static LookupAttributeInMRONode createForLookupOfUnmanagedClasses(String key) {
159+
return LookupAttributeInMRONodeGen.create(key, true);
149160
}
150161

151162
/**
@@ -201,7 +212,9 @@ protected PythonClassAssumptionPair findAttrClassAndAssumptionInMRO(Object klass
201212
assert clsObj != klass : "MRO chain is incorrect: '" + klass + "' was found at position " + i;
202213
getMro(clsObj).addAttributeInMROFinalAssumption(key, attrAssumption);
203214
}
204-
215+
if (skipPythonClasses && clsObj instanceof PythonClass) {
216+
continue;
217+
}
205218
Object value = ReadAttributeFromObjectNode.getUncachedForceType().execute(clsObj, key);
206219
if (value != PNone.NO_VALUE) {
207220
return new PythonClassAssumptionPair(attrAssumption, value);
@@ -239,6 +252,9 @@ protected Object lookupConstantMRO(@SuppressWarnings("unused") Object klass,
239252
@Cached("create(mroLength)") ReadAttributeFromObjectNode[] readAttrNodes) {
240253
for (int i = 0; i < mroLength; i++) {
241254
Object kls = mro.getItemNormalized(i);
255+
if (skipPythonClasses && kls instanceof PythonClass) {
256+
continue;
257+
}
242258
Object value = readAttrNodes[i].execute(kls, key);
243259
if (value != PNone.NO_VALUE) {
244260
return value;
@@ -250,7 +266,7 @@ protected Object lookupConstantMRO(@SuppressWarnings("unused") Object klass,
250266
@Specialization(replaces = {"lookupConstantMROCached", "lookupConstantMRO"})
251267
protected Object lookup(Object klass,
252268
@Cached("createForceType()") ReadAttributeFromObjectNode readAttrNode) {
253-
return lookupSlow(klass, key, getMroNode, readAttrNode);
269+
return lookupSlow(klass, key, getMroNode, readAttrNode, skipPythonClasses);
254270
}
255271

256272
protected GetMroStorageNode ensureGetMroNode() {
@@ -265,10 +281,13 @@ protected MroSequenceStorage getMro(Object clazz) {
265281
return ensureGetMroNode().execute(clazz);
266282
}
267283

268-
private static Object lookupSlow(Object klass, Object key, GetMroStorageNode getMroNode, ReadAttributeFromObjectNode readAttrNode) {
284+
private static Object lookupSlow(Object klass, Object key, GetMroStorageNode getMroNode, ReadAttributeFromObjectNode readAttrNode, boolean skipPythonClasses) {
269285
MroSequenceStorage mro = getMroNode.execute(klass);
270286
for (int i = 0; i < mro.length(); i++) {
271287
Object kls = mro.getItemNormalized(i);
288+
if (skipPythonClasses && kls instanceof PythonClass) {
289+
continue;
290+
}
272291
Object value = readAttrNode.execute(kls, key);
273292
if (value != PNone.NO_VALUE) {
274293
return value;

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/function/builtins/WrapTpNew.java

Lines changed: 21 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,10 @@
4242

4343
import com.oracle.graal.python.builtins.Builtin;
4444
import com.oracle.graal.python.builtins.PythonBuiltinClassType;
45-
import com.oracle.graal.python.builtins.objects.cext.PythonNativeClass;
4645
import com.oracle.graal.python.builtins.objects.function.PArguments;
4746
import com.oracle.graal.python.builtins.objects.function.PBuiltinFunction;
48-
import com.oracle.graal.python.builtins.objects.type.PythonAbstractClass;
49-
import com.oracle.graal.python.builtins.objects.type.PythonBuiltinClass;
5047
import com.oracle.graal.python.builtins.objects.type.TypeNodes.GetBaseClassesNode;
5148
import com.oracle.graal.python.builtins.objects.type.TypeNodes.IsTypeNode;
52-
import com.oracle.graal.python.builtins.objects.type.TypeNodesFactory.GetBaseClassesNodeGen;
5349
import com.oracle.graal.python.nodes.PRaiseNode;
5450
import com.oracle.graal.python.nodes.SpecialMethodNames;
5551
import com.oracle.graal.python.nodes.attributes.LookupAttributeInMRONode;
@@ -59,7 +55,6 @@
5955
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
6056
import com.oracle.truffle.api.dsl.NodeFactory;
6157
import com.oracle.truffle.api.frame.VirtualFrame;
62-
import com.oracle.truffle.api.nodes.ExplodeLoop;
6358
import com.oracle.truffle.api.nodes.NodeCost;
6459

6560
/**
@@ -123,98 +118,33 @@ public Object execute(VirtualFrame frame) {
123118
"%s.__new__(%N): %N is not a subtype of %s",
124119
getOwner().getName(), arg0, arg0, getOwner().getName());
125120
}
126-
if (getBases == null) {
121+
// CPython walks the bases and checks that the first non-heaptype base has the new that
122+
// we're in. We have our optimizations for this lookup that the compiler can then
123+
// (hopefully) merge with the initial lookup of the new method before entering it.
124+
if (lookupNewNode == null) {
127125
CompilerDirectives.transferToInterpreterAndInvalidate();
128-
reportPolymorphicSpecialize();
129-
getBases = insert(GetBaseClassesNodeGen.create());
130-
}
131-
PythonAbstractClass[] bases = getBases.execute(arg0);
132-
if ((state & MRO_LENGTH_MASK) == 0) {
133-
CompilerDirectives.transferToInterpreterAndInvalidate();
134-
int length = bases.length;
135-
if (length < MRO_LENGTH_MASK) {
136-
state |= length;
137-
} else {
138-
state |= MRO_LENGTH_MASK;
139-
}
126+
lookupNewNode = insert(LookupAttributeInMRONode.createForLookupOfUnmanagedClasses(SpecialMethodNames.__NEW__));
140127
}
141-
boolean isSafeNew = true;
142-
if ((state & MRO_LENGTH_MASK) == bases.length) {
143-
// cached mro, explode loop
144-
isSafeNew = checkSafeNew(bases, state & MRO_LENGTH_MASK);
145-
} else {
146-
if ((state & NONCONSTANT_MRO) == 0) {
147-
CompilerDirectives.transferToInterpreterAndInvalidate();
148-
reportPolymorphicSpecialize();
149-
state |= NONCONSTANT_MRO;
150-
}
151-
// mro too long to cache or different from the cached one, no explode loop
152-
isSafeNew = checkSafeNew(bases);
153-
}
154-
if (!isSafeNew) {
155-
if ((state & IS_UNSAFE_STATE) == 0) {
156-
CompilerDirectives.transferToInterpreterAndInvalidate();
157-
reportPolymorphicSpecialize();
158-
state |= IS_UNSAFE_STATE;
159-
}
160-
throw getRaiseNode().raise(PythonBuiltinClassType.TypeError,
161-
"%s.__new__(%N) is not safe, use %N.__new__()",
162-
getOwner().getName(), arg0, arg0);
163-
}
164-
}
165-
return super.execute(frame);
166-
}
167-
168-
@ExplodeLoop
169-
private boolean checkSafeNew(PythonAbstractClass[] bases, int length) {
170-
for (int i = 0; i < length; i++) {
171-
byte safe = isSafe(bases, i);
172-
if (safe != -1) {
173-
return safe == 0 ? false : true;
174-
}
175-
}
176-
CompilerDirectives.transferToInterpreterAndInvalidate();
177-
throw new IllegalStateException("there is no non-heap type in the mro, broken class");
178-
}
179-
180-
private boolean checkSafeNew(PythonAbstractClass[] bases) {
181-
for (int i = 0; i < bases.length; i++) {
182-
byte safe = isSafe(bases, i);
183-
if (safe != -1) {
184-
return safe == 0 ? false : true;
185-
}
186-
}
187-
CompilerDirectives.transferToInterpreterAndInvalidate();
188-
throw new IllegalStateException("there is no non-heap type in the mro, broken class");
189-
}
190-
191-
private byte isSafe(PythonAbstractClass[] mro, int i) {
192-
PythonAbstractClass base = mro[i];
193-
if (base instanceof PythonBuiltinClass) {
194-
if (((PythonBuiltinClass) base).getType() == getOwner()) {
195-
return 1;
196-
} else {
197-
if (lookupNewNode == null) {
198-
CompilerDirectives.transferToInterpreterAndInvalidate();
199-
lookupNewNode = insert(LookupAttributeInMRONode.create(SpecialMethodNames.__NEW__));
200-
}
201-
Object newMethod = lookupNewNode.execute(base);
202-
if (newMethod instanceof PBuiltinFunction) {
203-
NodeFactory<? extends PythonBuiltinBaseNode> factory = ((PBuiltinFunction) newMethod).getBuiltinNodeFactory();
204-
if (factory != null) {
205-
return factory.getNodeClass().isAssignableFrom(getNode().getClass()) ? (byte)1 : (byte)0;
128+
Object newMethod = lookupNewNode.execute(arg0);
129+
if (newMethod instanceof PBuiltinFunction) {
130+
NodeFactory<? extends PythonBuiltinBaseNode> factory = ((PBuiltinFunction) newMethod).getBuiltinNodeFactory();
131+
if (factory != null) {
132+
if (!factory.getNodeClass().isAssignableFrom(getNode().getClass())) {
133+
if ((state & IS_UNSAFE_STATE) == 0) {
134+
CompilerDirectives.transferToInterpreterAndInvalidate();
135+
reportPolymorphicSpecialize();
136+
state |= IS_UNSAFE_STATE;
137+
}
138+
throw getRaiseNode().raise(PythonBuiltinClassType.TypeError,
139+
"%s.__new__(%N) is not safe, use %N.__new__()",
140+
getOwner().getName(), arg0, arg0);
206141
}
207142
}
208-
// we explicitly allow non-Java builtin functions to pass, since a
209-
// PythonBuiltinClass with a non-java function is explicitly written in the core to
210-
// allow this
211-
return 1;
143+
// we explicitly allow non-Java functions to pass here, since a PythonBuiltinClass
144+
// with a non-java function is explicitly written in the core to allow this
212145
}
213-
} else if (PythonNativeClass.isInstance(base)) {
214-
// should have called the native tp_new in any case
215-
return 0;
216146
}
217-
return -1;
147+
return super.execute(frame);
218148
}
219149

220150
private final PRaiseNode getRaiseNode() {

0 commit comments

Comments
 (0)