Skip to content

Commit b84242a

Browse files
committed
add more specializations to IndexNode so that the generic case of a user can still work for acceptable types that do not have an __index__ method
1 parent f7ed7b7 commit b84242a

File tree

1 file changed

+50
-8
lines changed
  • graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/builtins

1 file changed

+50
-8
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/builtins/ListNodes.java

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
import com.oracle.graal.python.builtins.PythonBuiltinClassType;
4848
import com.oracle.graal.python.builtins.modules.MathGuards;
4949
import com.oracle.graal.python.builtins.objects.PNone;
50+
import com.oracle.graal.python.builtins.objects.floats.PFloat;
51+
import com.oracle.graal.python.builtins.objects.ints.PInt;
5052
import com.oracle.graal.python.builtins.objects.list.PList;
5153
import com.oracle.graal.python.builtins.objects.slice.PSlice;
5254
import com.oracle.graal.python.builtins.objects.tuple.PTuple;
@@ -58,10 +60,12 @@
5860
import com.oracle.graal.python.nodes.builtins.ListNodesFactory.ConstructListNodeGen;
5961
import com.oracle.graal.python.nodes.builtins.ListNodesFactory.CreateListFromIteratorNodeGen;
6062
import com.oracle.graal.python.nodes.builtins.ListNodesFactory.FastConstructListNodeGen;
63+
import com.oracle.graal.python.nodes.builtins.ListNodesFactory.IndexNodeGen;
6164
import com.oracle.graal.python.nodes.call.special.LookupAndCallUnaryNode;
6265
import com.oracle.graal.python.nodes.control.GetIteratorNode;
6366
import com.oracle.graal.python.nodes.control.GetNextNode;
6467
import com.oracle.graal.python.nodes.object.GetClassNode;
68+
import com.oracle.graal.python.nodes.truffle.PythonArithmeticTypes;
6569
import com.oracle.graal.python.runtime.exception.PException;
6670
import com.oracle.graal.python.runtime.exception.PythonErrorType;
6771
import com.oracle.graal.python.runtime.sequence.PSequence;
@@ -72,6 +76,7 @@
7276
import com.oracle.truffle.api.dsl.Fallback;
7377
import com.oracle.truffle.api.dsl.ImportStatic;
7478
import com.oracle.truffle.api.dsl.Specialization;
79+
import com.oracle.truffle.api.dsl.TypeSystemReference;
7580
import com.oracle.truffle.api.profiles.ConditionProfile;
7681

7782
public abstract class ListNodes {
@@ -179,41 +184,78 @@ public static FastConstructListNode create() {
179184
}
180185
}
181186

182-
public static class IndexNode extends PBaseNode {
187+
@TypeSystemReference(PythonArithmeticTypes.class)
188+
public abstract static class IndexNode extends PBaseNode {
183189
private static final String DEFAULT_ERROR_MSG = "list indices must be integers or slices, not %p";
184190
@Child LookupAndCallUnaryNode getIndexNode;
185191
private final CheckType checkType;
186192
private final String errorMessage;
187193

188-
private static enum CheckType {
194+
protected static enum CheckType {
189195
SUBSCRIPT,
190196
INTEGER,
191197
NUMBER;
192198
}
193199

194-
private IndexNode(String message, CheckType type) {
200+
protected IndexNode(String message, CheckType type) {
195201
checkType = type;
196202
getIndexNode = LookupAndCallUnaryNode.create(__INDEX__);
197203
errorMessage = message;
198204
}
199205

200206
public static IndexNode create(String message) {
201-
return new IndexNode(message, CheckType.SUBSCRIPT);
207+
return IndexNodeGen.create(message, CheckType.SUBSCRIPT);
202208
}
203209

204210
public static IndexNode create() {
205-
return new IndexNode(DEFAULT_ERROR_MSG, CheckType.SUBSCRIPT);
211+
return IndexNodeGen.create(DEFAULT_ERROR_MSG, CheckType.SUBSCRIPT);
206212
}
207213

208214
public static IndexNode createInteger(String msg) {
209-
return new IndexNode(msg, CheckType.INTEGER);
215+
return IndexNodeGen.create(msg, CheckType.INTEGER);
210216
}
211217

212218
public static IndexNode createNumber(String msg) {
213-
return new IndexNode(msg, CheckType.NUMBER);
219+
return IndexNodeGen.create(msg, CheckType.NUMBER);
214220
}
215221

216-
public Object execute(Object object) {
222+
public abstract Object execute(Object object);
223+
224+
protected boolean isSubscript() {
225+
return checkType == CheckType.SUBSCRIPT;
226+
}
227+
228+
protected boolean isNumber() {
229+
return checkType == CheckType.NUMBER;
230+
}
231+
232+
@Specialization
233+
long doLong(long slice) {
234+
return slice;
235+
}
236+
237+
@Specialization
238+
PInt doPInt(PInt slice) {
239+
return slice;
240+
}
241+
242+
@Specialization(guards = "isSubscript()")
243+
PSlice doSlice(PSlice slice) {
244+
return slice;
245+
}
246+
247+
@Specialization(guards = "isNumber()")
248+
float doFloat(float slice) {
249+
return slice;
250+
}
251+
252+
@Specialization(guards = "isNumber()")
253+
double doDouble(double slice) {
254+
return slice;
255+
}
256+
257+
@Fallback
258+
Object doGeneric(Object object) {
217259
Object idx = getIndexNode.executeObject(object);
218260
boolean valid = false;
219261
switch (checkType) {

0 commit comments

Comments
 (0)