Skip to content

Commit 6599715

Browse files
committed
refactor calling __INDEX__ from various list builtins through a common IndexNode
1 parent a2c22e0 commit 6599715

File tree

3 files changed

+132
-95
lines changed

3 files changed

+132
-95
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/MathGuards.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ public static boolean fitLong(double value) {
5151
}
5252

5353
public static boolean isNumber(Object value) {
54-
return value instanceof Integer || value instanceof Long || value instanceof Float || value instanceof Double || value instanceof PInt || value instanceof PFloat || value instanceof Boolean;
54+
return isInteger(value) || value instanceof Float || value instanceof Double || value instanceof PFloat;
55+
}
56+
57+
public static boolean isInteger(Object value) {
58+
return value instanceof Integer || value instanceof Long || value instanceof PInt || value instanceof Boolean;
5559
}
5660
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/list/ListBuiltins.java

Lines changed: 59 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@
5151

5252
import java.math.BigInteger;
5353
import java.util.List;
54-
import java.util.function.Supplier;
5554

5655
import com.oracle.graal.python.builtins.Builtin;
5756
import com.oracle.graal.python.builtins.CoreFunctions;
@@ -71,9 +70,9 @@
7170
import com.oracle.graal.python.nodes.PGuards;
7271
import com.oracle.graal.python.nodes.PNode;
7372
import com.oracle.graal.python.nodes.builtins.ListNodes;
73+
import com.oracle.graal.python.nodes.builtins.ListNodes.IndexNode;
7474
import com.oracle.graal.python.nodes.call.special.LookupAndCallBinaryNode;
7575
import com.oracle.graal.python.nodes.call.special.LookupAndCallUnaryNode;
76-
import com.oracle.graal.python.nodes.call.special.LookupAndCallUnaryNode.NoAttributeHandler;
7776
import com.oracle.graal.python.nodes.control.GetIteratorNode;
7877
import com.oracle.graal.python.nodes.control.GetNextNode;
7978
import com.oracle.graal.python.nodes.expression.BinaryComparisonNode;
@@ -243,17 +242,21 @@ protected PNone doPListSlice(PList self, PSlice slice) {
243242
return PNone.NONE;
244243
}
245244

246-
@SuppressWarnings("unused")
247-
@Fallback
248-
protected Object doGeneric(Object self, Object idx) {
249-
if (!isValidIndexType(idx)) {
250-
throw raise(TypeError, "list indices must be integers or slices, not %p", idx);
251-
}
252-
throw raise(TypeError, "descriptor '__delitem__' requires a 'list' object but received a '%p'", idx);
245+
protected static DelItemNode create() {
246+
return ListBuiltinsFactory.DelItemNodeFactory.create(new PNode[0]);
253247
}
254248

255-
protected boolean isValidIndexType(Object idx) {
256-
return PGuards.isInteger(idx) || idx instanceof PSlice;
249+
@Specialization
250+
protected Object doObjectIndex(PList self, Object objectIdx,
251+
@Cached("create()") IndexNode getIndexNode,
252+
@Cached("create()") DelItemNode getRecursiveNode) {
253+
return getRecursiveNode.execute(self, getIndexNode.execute(objectIdx));
254+
}
255+
256+
@SuppressWarnings("unused")
257+
@Fallback
258+
protected Object doGeneric(Object self, Object objectIdx) {
259+
throw raise(TypeError, "descriptor '__delitem__' requires a 'list' object but received a '%p'", self);
257260
}
258261
}
259262

@@ -329,34 +332,22 @@ protected Object doPListSlice(PList self, PSlice slice) {
329332
return self.getSlice(factory(), slice);
330333
}
331334

332-
protected static final Supplier<NoAttributeHandler> NO_INDEX = () -> new NoAttributeHandler() {
333-
@Override
334-
public Object execute(Object receiver) {
335-
throw raise(TypeError, "list indices must be integers or slices, not %p", receiver);
336-
}
337-
};
335+
protected static GetItemNode create() {
336+
return ListBuiltinsFactory.GetItemNodeFactory.create(new PNode[0]);
337+
}
338338

339339
@Specialization
340340
protected Object doObjectIndex(PList self, Object objectIdx,
341-
@Cached("create(__INDEX__, NO_INDEX)") LookupAndCallUnaryNode getIndexNode,
342-
@Cached("create(__GETITEM__)") LookupAndCallBinaryNode getRecursiveNode) {
343-
Object idx = getIndexNode.executeObject(objectIdx);
344-
if (isValidIndexType(idx)) {
345-
return getRecursiveNode.executeObject(self, idx);
346-
} else {
347-
throw raise(TypeError, "list indices must be integers or slices, not %p", idx);
348-
}
341+
@Cached("create()") IndexNode getIndexNode,
342+
@Cached("create()") GetItemNode getRecursiveNode) {
343+
return getRecursiveNode.execute(self, getIndexNode.execute(objectIdx));
349344
}
350345

351346
@SuppressWarnings("unused")
352347
@Fallback
353348
protected Object doGeneric(Object self, Object objectIdx) {
354349
throw raise(TypeError, "descriptor '__getitem__' requires a 'list' object but received a '%p'", self);
355350
}
356-
357-
protected boolean isValidIndexType(Object idx) {
358-
return PGuards.isInteger(idx) || idx instanceof PSlice || idx instanceof PInt;
359-
}
360351
}
361352

362353
@Builtin(name = __SETITEM__, fixedNumOfArguments = 3)
@@ -427,17 +418,21 @@ public Object doPList(PList list, long idx, Object value,
427418
return PNone.NONE;
428419
}
429420

430-
@SuppressWarnings("unused")
431-
@Fallback
432-
protected Object doGeneric(Object self, Object idx, Object value) {
433-
if (!isValidIndexType(idx)) {
434-
throw raise(TypeError, "list indices must be integers or slices, not %p", idx);
435-
}
436-
throw raise(TypeError, "descriptor '__setitem__' requires a 'list' object but received a '%p'", idx);
421+
protected static SetItemNode create() {
422+
return ListBuiltinsFactory.SetItemNodeFactory.create(new PNode[0]);
437423
}
438424

439-
protected boolean isValidIndexType(Object idx) {
440-
return PGuards.isInteger(idx) || idx instanceof PSlice;
425+
@Specialization
426+
protected Object doObjectIndex(PList self, Object objectIdx, Object value,
427+
@Cached("create()") IndexNode getIndexNode,
428+
@Cached("create()") SetItemNode getRecursiveNode) {
429+
return getRecursiveNode.execute(self, getIndexNode.execute(objectIdx), value);
430+
}
431+
432+
@SuppressWarnings("unused")
433+
@Fallback
434+
protected Object doGeneric(Object self, Object objectIdx, Object value) {
435+
throw raise(TypeError, "descriptor '__setitem__' requires a 'list' object but received a '%p'", self);
441436
}
442437
}
443438

@@ -568,6 +563,7 @@ protected boolean isPSequenceWithStorage(Object source) {
568563
@Builtin(name = "insert", fixedNumOfArguments = 3)
569564
@GenerateNodeFactory
570565
public abstract static class ListInsertNode extends PythonBuiltinNode {
566+
protected static final String ERROR_MSG = "'%p' object cannot be interpreted as an integer";
571567

572568
public abstract PNone execute(PList list, Object index, Object value);
573569

@@ -623,12 +619,9 @@ public PNone insertPIntIndex(PList list, PInt index, Object value,
623619

624620
@Specialization(guards = {"!isIntegerOrPInt(i)"})
625621
public PNone insert(PList list, Object i, Object value,
626-
@Cached("create(__INDEX__)") LookupAndCallUnaryNode indexNode,
622+
@Cached("createInteger(ERROR_MSG)") IndexNode indexNode,
627623
@Cached("createListInsertNode()") ListInsertNode insertNode) {
628-
Object indexValue = indexNode.executeObject(i);
629-
if (PNone.NO_VALUE == indexValue) {
630-
throw raise(TypeError, "'%p' object cannot be interpreted as an integer", i);
631-
}
624+
Object indexValue = indexNode.execute(i);
632625
return insertNode.execute(list, indexValue, value);
633626
}
634627

@@ -845,7 +838,7 @@ private Object popOnIndex(PList list, int index, ConditionProfile cp) {
845838
@ImportStatic(MathGuards.class)
846839
@GenerateNodeFactory
847840
public abstract static class ListIndexNode extends PythonBuiltinNode {
848-
private final static String ERROR_TYPE_MESSAGE = "slice indices must be integers or have an __index__ method";
841+
protected final static String ERROR_TYPE_MESSAGE = "slice indices must be integers or have an __index__ method";
849842

850843
public abstract int execute(Object arg1, Object arg2, Object arg3, Object arg4);
851844

@@ -939,42 +932,27 @@ int indexOD(PTuple self, Object value, Object start, double end) {
939932

940933
@Specialization(guards = "!isNumber(start)")
941934
int indexO(PTuple self, Object value, Object start, PNone end,
942-
@Cached("create(__INDEX__)") LookupAndCallUnaryNode startNode,
935+
@Cached("createNumber(ERROR_TYPE_MESSAGE)") IndexNode startNode,
943936
@Cached("createIndexNode()") ListIndexNode indexNode) {
944-
945-
Object startValue = startNode.executeObject(start);
946-
if (PNone.NO_VALUE == startValue || !MathGuards.isNumber(startValue)) {
947-
throw raise(TypeError, ERROR_TYPE_MESSAGE);
948-
}
937+
Object startValue = startNode.execute(start);
949938
return indexNode.execute(self, value, startValue, end);
950939
}
951940

952941
@Specialization(guards = {"!isNumber(end)",})
953942
int indexLO(PTuple self, Object value, long start, Object end,
954-
@Cached("create(__INDEX__)") LookupAndCallUnaryNode endNode,
943+
@Cached("createNumber(ERROR_TYPE_MESSAGE)") IndexNode endNode,
955944
@Cached("createIndexNode()") ListIndexNode indexNode) {
956-
957-
Object endValue = endNode.executeObject(end);
958-
if (PNone.NO_VALUE == endValue || !MathGuards.isNumber(endValue)) {
959-
throw raise(TypeError, ERROR_TYPE_MESSAGE);
960-
}
945+
Object endValue = endNode.execute(end);
961946
return indexNode.execute(self, value, start, endValue);
962947
}
963948

964949
@Specialization(guards = {"!isNumber(start) || !isNumber(end)",})
965950
int indexOO(PTuple self, Object value, Object start, Object end,
966-
@Cached("create(__INDEX__)") LookupAndCallUnaryNode startNode,
967-
@Cached("create(__INDEX__)") LookupAndCallUnaryNode endNode,
951+
@Cached("createNumber(ERROR_TYPE_MESSAGE)") IndexNode startNode,
952+
@Cached("createNumber(ERROR_TYPE_MESSAGE)") IndexNode endNode,
968953
@Cached("createIndexNode()") ListIndexNode indexNode) {
969-
970-
Object startValue = startNode.executeObject(start);
971-
if (PNone.NO_VALUE == startValue || !MathGuards.isNumber(startValue)) {
972-
throw raise(TypeError, ERROR_TYPE_MESSAGE);
973-
}
974-
Object endValue = endNode.executeObject(end);
975-
if (PNone.NO_VALUE == endValue || !MathGuards.isNumber(endValue)) {
976-
throw raise(TypeError, ERROR_TYPE_MESSAGE);
977-
}
954+
Object startValue = startNode.execute(start);
955+
Object endValue = endNode.execute(end);
978956
return indexNode.execute(self, value, startValue, endValue);
979957
}
980958

@@ -1232,6 +1210,7 @@ Object doGeneric(Object left, Object right) {
12321210
@Builtin(name = __IMUL__, fixedNumOfArguments = 2)
12331211
@GenerateNodeFactory
12341212
abstract static class IMulNode extends PythonBuiltinNode {
1213+
protected static final String ERROR_MSG = "can't multiply sequence by non-int of type '%p'";
12351214

12361215
public abstract PList execute(PList list, Object value);
12371216

@@ -1439,36 +1418,29 @@ PList doObjectPInt(PList list, PInt right) {
14391418

14401419
@Specialization(guards = {"!isInt(right)"})
14411420
Object doGeneric(PList list, Object right,
1442-
@Cached("create(__INDEX__)") LookupAndCallUnaryNode dispatchIndex,
1421+
@Cached("createInteger(ERROR_MSG)") IndexNode dispatchIndex,
14431422
@Cached("createIMulNode()") IMulNode imulNode) {
1444-
Object index = dispatchIndex.executeObject(right);
1445-
if (index != PNone.NO_VALUE) {
1446-
int iIndex;
1447-
try {
1448-
iIndex = convertToInt(index);
1449-
} catch (ArithmeticException e) {
1450-
throw raise(OverflowError, "cannot fit '%p' into an index-sized integer", index);
1451-
}
1452-
1453-
return imulNode.execute(list, iIndex);
1423+
Object index = dispatchIndex.execute(right);
1424+
int iIndex;
1425+
try {
1426+
iIndex = convertToInt(index);
1427+
} catch (ArithmeticException e) {
1428+
throw raise(OverflowError, "cannot fit '%p' into an index-sized integer", index);
14541429
}
1455-
throw raise(TypeError, "can't multiply sequence by non-int of type '%p'", right);
1430+
return imulNode.execute(list, iIndex);
14561431
}
14571432

1458-
private int convertToInt(Object value) throws ArithmeticException {
1433+
private static int convertToInt(Object value) throws ArithmeticException {
14591434
if (value instanceof Integer) {
14601435
return (Integer) value;
1461-
}
1462-
if (value instanceof Boolean) {
1436+
} else if (value instanceof Boolean) {
14631437
return (Boolean) value ? 0 : 1;
1464-
}
1465-
if (value instanceof Long) {
1438+
} else if (value instanceof Long) {
14661439
return PInt.intValueExact((Long) value);
1467-
}
1468-
if (value instanceof PInt) {
1440+
} else {
1441+
assert value instanceof PInt;
14691442
return ((PInt) value).intValueExact();
14701443
}
1471-
throw raise(TypeError, "can't multiply sequence by non-int of type '%p'", value);
14721444
}
14731445

14741446
protected IMulNode createIMulNode() {

0 commit comments

Comments
 (0)