Skip to content

Commit 039b10a

Browse files
committed
compare the bases' __new__ to the current new like CPython does
1 parent 30cd4b3 commit 039b10a

File tree

3 files changed

+115
-85
lines changed

3 files changed

+115
-85
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/function/builtins/BuiltinCallNode.java

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,15 @@
4242

4343
import com.oracle.graal.python.builtins.objects.function.PKeyword;
4444
import com.oracle.graal.python.nodes.argument.ReadArgumentNode;
45+
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
4546
import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
4647
import com.oracle.truffle.api.frame.VirtualFrame;
4748
import com.oracle.truffle.api.nodes.Node;
4849

4950
public abstract class BuiltinCallNode extends Node {
5051
public abstract Object execute(VirtualFrame frame);
5152

52-
protected abstract Node getNode();
53+
protected abstract PythonBuiltinBaseNode getNode();
5354

5455
public static final class BuiltinAnyCallNode extends BuiltinCallNode {
5556
@Child PythonBuiltinNode node;
@@ -64,7 +65,7 @@ public Object execute(VirtualFrame frame) {
6465
}
6566

6667
@Override
67-
protected Node getNode() {
68+
protected PythonBuiltinBaseNode getNode() {
6869
return node;
6970
}
7071
}
@@ -84,7 +85,7 @@ public Object execute(VirtualFrame frame) {
8485
}
8586

8687
@Override
87-
protected Node getNode() {
88+
protected PythonBuiltinBaseNode getNode() {
8889
return node;
8990
}
9091
}
@@ -106,7 +107,7 @@ public Object execute(VirtualFrame frame) {
106107
}
107108

108109
@Override
109-
protected Node getNode() {
110+
protected PythonBuiltinBaseNode getNode() {
110111
return node;
111112
}
112113
}
@@ -130,7 +131,7 @@ public Object execute(VirtualFrame frame) {
130131
}
131132

132133
@Override
133-
protected Node getNode() {
134+
protected PythonBuiltinBaseNode getNode() {
134135
return node;
135136
}
136137
}
@@ -156,7 +157,7 @@ public Object execute(VirtualFrame frame) {
156157
}
157158

158159
@Override
159-
protected Node getNode() {
160+
protected PythonBuiltinBaseNode getNode() {
160161
return node;
161162
}
162163
}
@@ -180,7 +181,7 @@ public Object execute(VirtualFrame frame) {
180181
}
181182

182183
@Override
183-
protected Node getNode() {
184+
protected PythonBuiltinBaseNode getNode() {
184185
return node;
185186
}
186187
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/function/builtins/SlotWrapper.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@
4040
*/
4141
package com.oracle.graal.python.nodes.function.builtins;
4242

43+
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
4344
import com.oracle.truffle.api.frame.VirtualFrame;
44-
import com.oracle.truffle.api.nodes.Node;
4545

4646
public abstract class SlotWrapper extends BuiltinCallNode {
4747
@Child BuiltinCallNode func;
@@ -56,7 +56,7 @@ public Object execute(VirtualFrame frame) {
5656
}
5757

5858
@Override
59-
protected Node getNode() {
59+
protected PythonBuiltinBaseNode getNode() {
6060
return func.getNode();
6161
}
6262
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/function/builtins/WrapTpNew.java

Lines changed: 105 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,20 @@
4444
import com.oracle.graal.python.builtins.PythonBuiltinClassType;
4545
import com.oracle.graal.python.builtins.objects.cext.PythonNativeClass;
4646
import com.oracle.graal.python.builtins.objects.function.PArguments;
47+
import com.oracle.graal.python.builtins.objects.function.PBuiltinFunction;
4748
import com.oracle.graal.python.builtins.objects.type.PythonAbstractClass;
4849
import com.oracle.graal.python.builtins.objects.type.PythonBuiltinClass;
49-
import com.oracle.graal.python.builtins.objects.type.TypeNodes.GetMroNode;
50+
import com.oracle.graal.python.builtins.objects.type.TypeNodes.GetBaseClassesNode;
5051
import com.oracle.graal.python.builtins.objects.type.TypeNodes.IsTypeNode;
52+
import com.oracle.graal.python.builtins.objects.type.TypeNodesFactory.GetBaseClassesNodeGen;
5153
import com.oracle.graal.python.nodes.PRaiseNode;
54+
import com.oracle.graal.python.nodes.SpecialMethodNames;
55+
import com.oracle.graal.python.nodes.attributes.LookupAttributeInMRONode;
5256
import com.oracle.graal.python.nodes.classes.IsSubtypeNode;
57+
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
5358
import com.oracle.truffle.api.CompilerDirectives;
5459
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
60+
import com.oracle.truffle.api.dsl.NodeFactory;
5561
import com.oracle.truffle.api.frame.VirtualFrame;
5662
import com.oracle.truffle.api.nodes.ExplodeLoop;
5763
import com.oracle.truffle.api.nodes.NodeCost;
@@ -62,15 +68,16 @@
6268
public final class WrapTpNew extends SlotWrapper {
6369
@Child IsTypeNode isType;
6470
@Child IsSubtypeNode isSubtype;
65-
@Child GetMroNode getMro;
71+
@Child GetBaseClassesNode getBases;
6672
@Child PRaiseNode raiseNode;
73+
@Child LookupAttributeInMRONode lookupNewNode;
6774
@CompilationFinal byte state = 0;
6875
@CompilationFinal PythonBuiltinClassType owner;
6976

7077
private static final short NOT_SUBTP_STATE = 0b10000000;
7178
private static final short NOT_CLASS_STATE = 0b01000000;
7279
private static final short IS_UNSAFE_STATE = 0b00100000;
73-
private static final short NOTCONSTANT_MRO = 0b00010000;
80+
private static final short NONCONSTANT_MRO = 0b00010000;
7481
private static final short MRO_LENGTH_MASK = 0b00001111;
7582

7683
public WrapTpNew(BuiltinCallNode func) {
@@ -86,108 +93,130 @@ public Object execute(VirtualFrame frame) {
8693
CompilerDirectives.transferToInterpreterAndInvalidate();
8794
throw new IllegalStateException(getOwner().getName() + ".__new__ called without arguments");
8895
}
89-
if (isType == null) {
90-
CompilerDirectives.transferToInterpreterAndInvalidate();
91-
reportPolymorphicSpecialize();
92-
isType = insert(IsTypeNode.create());
93-
}
94-
if (!isType.execute(arg0)) {
95-
if ((state & NOT_CLASS_STATE) == 0) {
96+
if (arg0 != getOwner()) {
97+
if (isType == null) {
9698
CompilerDirectives.transferToInterpreterAndInvalidate();
9799
reportPolymorphicSpecialize();
98-
state |= NOT_CLASS_STATE;
100+
isType = insert(IsTypeNode.create());
99101
}
100-
throw getRaiseNode().raise(PythonBuiltinClassType.TypeError,
101-
"%s.__new__(X): X is not a type object (%N)", getOwner().getName(), arg0);
102-
}
103-
if (isSubtype == null) {
104-
CompilerDirectives.transferToInterpreterAndInvalidate();
105-
reportPolymorphicSpecialize();
106-
isSubtype = insert(IsSubtypeNode.create());
107-
}
108-
if (!isSubtype.execute(arg0, getOwner())) {
109-
if ((state & NOT_SUBTP_STATE) == 0) {
102+
if (!isType.execute(arg0)) {
103+
if ((state & NOT_CLASS_STATE) == 0) {
104+
CompilerDirectives.transferToInterpreterAndInvalidate();
105+
reportPolymorphicSpecialize();
106+
state |= NOT_CLASS_STATE;
107+
}
108+
throw getRaiseNode().raise(PythonBuiltinClassType.TypeError,
109+
"%s.__new__(X): X is not a type object (%N)", getOwner().getName(), arg0);
110+
}
111+
if (isSubtype == null) {
110112
CompilerDirectives.transferToInterpreterAndInvalidate();
111113
reportPolymorphicSpecialize();
112-
state |= NOT_SUBTP_STATE;
114+
isSubtype = insert(IsSubtypeNode.create());
113115
}
114-
throw getRaiseNode().raise(PythonBuiltinClassType.TypeError,
115-
"%s.__new__(%N): %N is not a subtype of %s",
116-
getOwner().getName(), arg0, arg0, getOwner().getName());
117-
}
118-
if (getMro == null) {
119-
CompilerDirectives.transferToInterpreterAndInvalidate();
120-
reportPolymorphicSpecialize();
121-
getMro = insert(GetMroNode.create());
122-
}
123-
// TODO (tfel): not quite correct, since we should just be walking the bases, not the entire
124-
// MRO
125-
PythonAbstractClass[] mro = getMro.execute(arg0);
126-
if ((state & MRO_LENGTH_MASK) == 0) {
127-
CompilerDirectives.transferToInterpreterAndInvalidate();
128-
int length = mro.length;
129-
if (length < MRO_LENGTH_MASK) {
130-
state |= length;
131-
} else {
132-
state |= MRO_LENGTH_MASK;
116+
if (!isSubtype.execute(arg0, getOwner())) {
117+
if ((state & NOT_SUBTP_STATE) == 0) {
118+
CompilerDirectives.transferToInterpreterAndInvalidate();
119+
reportPolymorphicSpecialize();
120+
state |= NOT_SUBTP_STATE;
121+
}
122+
throw getRaiseNode().raise(PythonBuiltinClassType.TypeError,
123+
"%s.__new__(%N): %N is not a subtype of %s",
124+
getOwner().getName(), arg0, arg0, getOwner().getName());
133125
}
134-
}
135-
boolean isSafeNew = true;
136-
if ((state & MRO_LENGTH_MASK) == mro.length) {
137-
// cached mro, explode loop
138-
isSafeNew = checkSafeNew(mro, state & MRO_LENGTH_MASK);
139-
} else {
140-
if ((state & NOTCONSTANT_MRO) == 0) {
126+
if (getBases == null) {
141127
CompilerDirectives.transferToInterpreterAndInvalidate();
142128
reportPolymorphicSpecialize();
143-
state |= NOTCONSTANT_MRO;
129+
getBases = insert(GetBaseClassesNodeGen.create());
144130
}
145-
// mro too long to cache or different from the cached one, no explode loop
146-
isSafeNew = checkSafeNew(mro);
147-
}
148-
if (!isSafeNew) {
149-
if ((state & IS_UNSAFE_STATE) == 0) {
131+
PythonAbstractClass[] bases = getBases.execute(arg0);
132+
if ((state & MRO_LENGTH_MASK) == 0) {
150133
CompilerDirectives.transferToInterpreterAndInvalidate();
151-
reportPolymorphicSpecialize();
152-
state |= IS_UNSAFE_STATE;
134+
int length = bases.length;
135+
if (length < MRO_LENGTH_MASK) {
136+
state |= length;
137+
} else {
138+
state |= MRO_LENGTH_MASK;
139+
}
140+
}
141+
boolean isSafeNew = true;
142+
if ((state & MRO_LENGTH_MASK) == bases.length) {
143+
// cached mro, explode loop
144+
isSafeNew = checkSafeNew(bases, state & MRO_LENGTH_MASK);
145+
} else {
146+
if ((state & NONCONSTANT_MRO) == 0) {
147+
CompilerDirectives.transferToInterpreterAndInvalidate();
148+
reportPolymorphicSpecialize();
149+
state |= NONCONSTANT_MRO;
150+
}
151+
// mro too long to cache or different from the cached one, no explode loop
152+
isSafeNew = checkSafeNew(bases);
153+
}
154+
if (!isSafeNew) {
155+
if ((state & IS_UNSAFE_STATE) == 0) {
156+
CompilerDirectives.transferToInterpreterAndInvalidate();
157+
reportPolymorphicSpecialize();
158+
state |= IS_UNSAFE_STATE;
159+
}
160+
throw getRaiseNode().raise(PythonBuiltinClassType.TypeError,
161+
"%s.__new__(%N) is not safe, use %N.__new__()",
162+
getOwner().getName(), arg0, arg0);
153163
}
154-
throw getRaiseNode().raise(PythonBuiltinClassType.TypeError,
155-
"%s.__new__(%N) is not safe, use %N.__new__()",
156-
getOwner().getName(), arg0, arg0);
157164
}
158165
return super.execute(frame);
159166
}
160167

161168
@ExplodeLoop
162-
private boolean checkSafeNew(PythonAbstractClass[] mro, int length) {
169+
private boolean checkSafeNew(PythonAbstractClass[] bases, int length) {
163170
for (int i = 0; i < length; i++) {
164-
PythonAbstractClass base = mro[i];
165-
if (base instanceof PythonBuiltinClass) {
166-
// TODO: tfel not correct, since the base may not be overriding __new__
167-
return ((PythonBuiltinClass) base).getType() == getOwner();
168-
} else if (PythonNativeClass.isInstance(base)) {
169-
// should have called the native tp_new in any case
170-
return false;
171+
byte safe = isSafe(bases, i);
172+
if (safe != -1) {
173+
return safe == 0 ? false : true;
171174
}
172175
}
173176
CompilerDirectives.transferToInterpreterAndInvalidate();
174177
throw new IllegalStateException("there is no non-heap type in the mro, broken class");
175178
}
176179

177-
private boolean checkSafeNew(PythonAbstractClass[] mro) {
178-
for (int i = 0; i < mro.length; i++) {
179-
PythonAbstractClass base = mro[i];
180-
if (base instanceof PythonBuiltinClass) {
181-
return ((PythonBuiltinClass) base).getType() == getOwner();
182-
} else if (PythonNativeClass.isInstance(base)) {
183-
// should have called the native tp_new in any case
184-
return false;
180+
private boolean checkSafeNew(PythonAbstractClass[] bases) {
181+
for (int i = 0; i < bases.length; i++) {
182+
byte safe = isSafe(bases, i);
183+
if (safe != -1) {
184+
return safe == 0 ? false : true;
185185
}
186186
}
187187
CompilerDirectives.transferToInterpreterAndInvalidate();
188188
throw new IllegalStateException("there is no non-heap type in the mro, broken class");
189189
}
190190

191+
private byte isSafe(PythonAbstractClass[] mro, int i) {
192+
PythonAbstractClass base = mro[i];
193+
if (base instanceof PythonBuiltinClass) {
194+
if (((PythonBuiltinClass) base).getType() == getOwner()) {
195+
return 1;
196+
} else {
197+
if (lookupNewNode == null) {
198+
CompilerDirectives.transferToInterpreterAndInvalidate();
199+
lookupNewNode = insert(LookupAttributeInMRONode.create(SpecialMethodNames.__NEW__));
200+
}
201+
Object newMethod = lookupNewNode.execute(base);
202+
if (newMethod instanceof PBuiltinFunction) {
203+
NodeFactory<? extends PythonBuiltinBaseNode> factory = ((PBuiltinFunction) newMethod).getBuiltinNodeFactory();
204+
if (factory != null) {
205+
return factory.getNodeClass().isAssignableFrom(getNode().getClass()) ? (byte)1 : (byte)0;
206+
}
207+
}
208+
// we explicitly allow non-Java builtin functions to pass, since a
209+
// PythonBuiltinClass with a non-java function is explicitly written in the core to
210+
// allow this
211+
return 1;
212+
}
213+
} else if (PythonNativeClass.isInstance(base)) {
214+
// should have called the native tp_new in any case
215+
return 0;
216+
}
217+
return -1;
218+
}
219+
191220
private final PRaiseNode getRaiseNode() {
192221
if (raiseNode == null) {
193222
CompilerDirectives.transferToInterpreterAndInvalidate();
@@ -228,7 +257,7 @@ public NodeCost getCost() {
228257
} else if ((state & ~MRO_LENGTH_MASK) == 0) {
229258
// no error states, single mro
230259
return NodeCost.MONOMORPHIC;
231-
} else if (((state & ~MRO_LENGTH_MASK) & NOTCONSTANT_MRO) == NOTCONSTANT_MRO) {
260+
} else if (((state & ~MRO_LENGTH_MASK) & NONCONSTANT_MRO) == NONCONSTANT_MRO) {
232261
// no error states, multiple mros
233262
return NodeCost.POLYMORPHIC;
234263
} else {

0 commit comments

Comments
 (0)