Skip to content

Commit c7cdac5

Browse files
committed
Replace BinaryArithmetic.AddNode with PyNumberAddNode & introduce PySequenceConcatNode
1 parent 9ad077c commit c7cdac5

File tree

11 files changed

+315
-166
lines changed

11 files changed

+315
-166
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_sum.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2018, 2021, Oracle and/or its affiliates. All rights reserved.
1+
# Copyright (c) 2018, 2024, Oracle and/or its affiliates. All rights reserved.
22
# DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
33
#
44
# The Universal Permissive License (UPL), Version 1.0
@@ -37,6 +37,8 @@
3737
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
3838
# SOFTWARE.
3939

40+
from .util import assert_raises
41+
4042
def test_type_int():
4143
# does not fit into primitive 'long'
4244
i = 0xfffffffffffffffffffffffffffffffffff
@@ -70,3 +72,7 @@ def __next__(self):
7072

7173
def test_iterator():
7274
assert sum(SumTestClass()) == 45
75+
76+
def test_basics():
77+
assert sum([[1, 2], [3, 4]], []) == [1, 2, 3, 4]
78+
assert_raises(TypeError, sum, [1,2,3], None)

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/BuiltinFunctions.java

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@
166166
import com.oracle.graal.python.lib.PyEvalGetGlobals;
167167
import com.oracle.graal.python.lib.PyEvalGetLocals;
168168
import com.oracle.graal.python.lib.PyMappingCheckNode;
169+
import com.oracle.graal.python.lib.PyNumberAddNode;
169170
import com.oracle.graal.python.lib.PyNumberAsSizeNode;
170171
import com.oracle.graal.python.lib.PyNumberIndexNode;
171172
import com.oracle.graal.python.lib.PyObjectAsciiNode;
@@ -213,7 +214,6 @@
213214
import com.oracle.graal.python.nodes.call.special.LookupSpecialMethodSlotNode;
214215
import com.oracle.graal.python.nodes.classes.IsSubtypeNode;
215216
import com.oracle.graal.python.nodes.expression.BinaryArithmetic;
216-
import com.oracle.graal.python.nodes.expression.BinaryArithmetic.AddNode;
217217
import com.oracle.graal.python.nodes.expression.BinaryComparisonNode;
218218
import com.oracle.graal.python.nodes.expression.BinaryOpNode;
219219
import com.oracle.graal.python.nodes.expression.CoerceToBooleanNode;
@@ -2183,25 +2183,26 @@ Object ternary(VirtualFrame frame, Object x, Object y, Object z,
21832183
public abstract static class SumFunctionNode extends PythonBuiltinNode {
21842184

21852185
@Child private LookupAndCallUnaryNode next = LookupAndCallUnaryNode.create(SpecialMethodSlot.Next);
2186-
@Child private AddNode add = AddNode.create();
21872186

2188-
@Specialization(rewriteOn = UnexpectedResultException.class)
2187+
@Specialization(guards = "isNoValue(start)", rewriteOn = UnexpectedResultException.class)
21892188
int sumIntNone(VirtualFrame frame, Object arg1, @SuppressWarnings("unused") PNone start,
21902189
@Bind("this") Node inliningTarget,
2190+
@Shared @Cached PyNumberAddNode addNode,
21912191
@Shared @Cached IsBuiltinObjectProfile errorProfile,
21922192
@Shared("getIter") @Cached PyObjectGetIter getIter) throws UnexpectedResultException {
2193-
return sumIntInternal(frame, inliningTarget, arg1, 0, getIter, errorProfile);
2193+
return sumIntInternal(frame, inliningTarget, arg1, 0, addNode, getIter, errorProfile);
21942194
}
21952195

21962196
@Specialization(rewriteOn = UnexpectedResultException.class)
21972197
int sumIntInt(VirtualFrame frame, Object arg1, int start,
21982198
@Bind("this") Node inliningTarget,
2199+
@Shared @Cached PyNumberAddNode addNode,
21992200
@Shared @Cached IsBuiltinObjectProfile errorProfile,
22002201
@Shared("getIter") @Cached PyObjectGetIter getIter) throws UnexpectedResultException {
2201-
return sumIntInternal(frame, inliningTarget, arg1, start, getIter, errorProfile);
2202+
return sumIntInternal(frame, inliningTarget, arg1, start, addNode, getIter, errorProfile);
22022203
}
22032204

2204-
private int sumIntInternal(VirtualFrame frame, Node inliningTarget, Object arg1, int start, PyObjectGetIter getIter,
2205+
private int sumIntInternal(VirtualFrame frame, Node inliningTarget, Object arg1, int start, PyNumberAddNode add, PyObjectGetIter getIter,
22052206
IsBuiltinObjectProfile errorProfile) throws UnexpectedResultException {
22062207
Object iterator = getIter.execute(frame, inliningTarget, arg1);
22072208
int value = start;
@@ -2213,30 +2214,31 @@ private int sumIntInternal(VirtualFrame frame, Node inliningTarget, Object arg1,
22132214
e.expectStopIteration(inliningTarget, errorProfile);
22142215
return value;
22152216
} catch (UnexpectedResultException e) {
2216-
Object newValue = add.executeObject(frame, value, e.getResult());
2217-
throw new UnexpectedResultException(iterateGeneric(frame, inliningTarget, iterator, newValue, errorProfile));
2217+
Object newValue = add.execute(frame, inliningTarget, value, e.getResult());
2218+
throw new UnexpectedResultException(iterateGeneric(frame, inliningTarget, iterator, newValue, add, errorProfile));
22182219
}
22192220
try {
2220-
value = add.executeInt(frame, value, nextValue);
2221+
value = add.executeInt(frame, inliningTarget, value, nextValue);
22212222
} catch (UnexpectedResultException e) {
2222-
throw new UnexpectedResultException(iterateGeneric(frame, inliningTarget, iterator, e.getResult(), errorProfile));
2223+
throw new UnexpectedResultException(iterateGeneric(frame, inliningTarget, iterator, e.getResult(), add, errorProfile));
22232224
}
22242225
}
22252226
}
22262227

22272228
@Specialization(rewriteOn = UnexpectedResultException.class)
22282229
double sumDoubleDouble(VirtualFrame frame, Object arg1, double start,
22292230
@Bind("this") Node inliningTarget,
2231+
@Shared @Cached PyNumberAddNode addNode,
22302232
@Shared @Cached IsBuiltinObjectProfile errorProfile,
22312233
// dummy inline profile, so it can be @Shared, to optimize generated code:
22322234
@SuppressWarnings("unused") @Shared @Cached InlinedConditionProfile hasStart,
22332235
@Shared("getIter") @Cached PyObjectGetIter getIter,
22342236
// dummy raiseNode, so it can be @Shared, to optimize generated code:
22352237
@SuppressWarnings("unused") @Shared @Cached PRaiseNode.Lazy raiseNode) throws UnexpectedResultException {
2236-
return sumDoubleInternal(frame, inliningTarget, arg1, start, getIter, errorProfile);
2238+
return sumDoubleInternal(frame, inliningTarget, arg1, start, addNode, getIter, errorProfile);
22372239
}
22382240

2239-
private double sumDoubleInternal(VirtualFrame frame, Node inliningTarget, Object arg1, double start, PyObjectGetIter getIter,
2241+
private double sumDoubleInternal(VirtualFrame frame, Node inliningTarget, Object arg1, double start, PyNumberAddNode add, PyObjectGetIter getIter,
22402242
IsBuiltinObjectProfile errorProfile) throws UnexpectedResultException {
22412243
Object iterator = getIter.execute(frame, inliningTarget, arg1);
22422244
double value = start;
@@ -2248,20 +2250,21 @@ private double sumDoubleInternal(VirtualFrame frame, Node inliningTarget, Object
22482250
e.expectStopIteration(inliningTarget, errorProfile);
22492251
return value;
22502252
} catch (UnexpectedResultException e) {
2251-
Object newValue = add.executeObject(frame, value, e.getResult());
2252-
throw new UnexpectedResultException(iterateGeneric(frame, inliningTarget, iterator, newValue, errorProfile));
2253+
Object newValue = add.execute(frame, inliningTarget, value, e.getResult());
2254+
throw new UnexpectedResultException(iterateGeneric(frame, inliningTarget, iterator, newValue, add, errorProfile));
22532255
}
22542256
try {
2255-
value = add.executeDouble(frame, value, nextValue);
2257+
value = add.executeDouble(frame, inliningTarget, value, nextValue);
22562258
} catch (UnexpectedResultException e) {
2257-
throw new UnexpectedResultException(iterateGeneric(frame, inliningTarget, iterator, e.getResult(), errorProfile));
2259+
throw new UnexpectedResultException(iterateGeneric(frame, inliningTarget, iterator, e.getResult(), add, errorProfile));
22582260
}
22592261
}
22602262
}
22612263

22622264
@Specialization(replaces = {"sumIntNone", "sumIntInt", "sumDoubleDouble"})
22632265
Object sum(VirtualFrame frame, Object arg1, Object start,
22642266
@Bind("this") Node inliningTarget,
2267+
@Shared @Cached PyNumberAddNode addNode,
22652268
@Shared @Cached IsBuiltinObjectProfile errorProfile,
22662269
@Shared("getIter") @Cached PyObjectGetIter getIter,
22672270
@Shared @Cached InlinedConditionProfile hasStart,
@@ -2274,10 +2277,10 @@ Object sum(VirtualFrame frame, Object arg1, Object start,
22742277
throw raiseNode.get(inliningTarget).raise(TypeError, ErrorMessages.CANT_SUM_BYTEARRAY);
22752278
}
22762279
Object iterator = getIter.execute(frame, inliningTarget, arg1);
2277-
return iterateGeneric(frame, inliningTarget, iterator, hasStart.profile(inliningTarget, start != NO_VALUE) ? start : 0, errorProfile);
2280+
return iterateGeneric(frame, inliningTarget, iterator, hasStart.profile(inliningTarget, start != NO_VALUE) ? start : 0, addNode, errorProfile);
22782281
}
22792282

2280-
private Object iterateGeneric(VirtualFrame frame, Node inliningTarget, Object iterator, Object start, IsBuiltinObjectProfile errorProfile) {
2283+
private Object iterateGeneric(VirtualFrame frame, Node inliningTarget, Object iterator, Object start, PyNumberAddNode add, IsBuiltinObjectProfile errorProfile) {
22812284
Object value = start;
22822285
while (true) {
22832286
Object nextValue;
@@ -2287,7 +2290,7 @@ private Object iterateGeneric(VirtualFrame frame, Node inliningTarget, Object it
22872290
e.expectStopIteration(inliningTarget, errorProfile);
22882291
return value;
22892292
}
2290-
value = add.executeObject(frame, value, nextValue);
2293+
value = add.execute(frame, inliningTarget, value, nextValue);
22912294
}
22922295
}
22932296
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/cext/PythonCextAbstractBuiltins.java

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@
100100
import com.oracle.graal.python.lib.GetNextNode;
101101
import com.oracle.graal.python.lib.PyIndexCheckNode;
102102
import com.oracle.graal.python.lib.PyIterCheckNode;
103+
import com.oracle.graal.python.lib.PyNumberAddNode;
103104
import com.oracle.graal.python.lib.PyNumberCheckNode;
104105
import com.oracle.graal.python.lib.PyNumberFloatNode;
105106
import com.oracle.graal.python.lib.PyNumberIndexNode;
@@ -108,6 +109,7 @@
108109
import com.oracle.graal.python.lib.PyObjectGetItem;
109110
import com.oracle.graal.python.lib.PyObjectLookupAttr;
110111
import com.oracle.graal.python.lib.PySequenceCheckNode;
112+
import com.oracle.graal.python.lib.PySequenceConcat;
111113
import com.oracle.graal.python.lib.PySequenceContainsNode;
112114
import com.oracle.graal.python.lib.PySequenceDelItemNode;
113115
import com.oracle.graal.python.lib.PySequenceGetItemNode;
@@ -601,25 +603,11 @@ protected MulNode createMul() {
601603

602604
@CApiBuiltin(ret = PyObjectTransfer, args = {PyObject, PyObject}, call = Direct)
603605
abstract static class PySequence_Concat extends CApiBinaryBuiltinNode {
604-
@Specialization(guards = {"checkNode.execute(inliningTarget, s1)", "checkNode.execute(inliningTarget, s1)"})
605-
Object concat(Object s1, Object s2,
606-
@SuppressWarnings("unused") @Bind("this") Node inliningTarget,
607-
@SuppressWarnings("unused") @Shared("check") @Cached PySequenceCheckNode checkNode,
608-
@Cached("createAdd()") BinaryArithmetic.AddNode addNode) {
609-
return addNode.executeObject(null, s1, s2);
610-
}
611-
612-
@Specialization(guards = {"!checkNode.execute(inliningTarget, s1) || checkNode.execute(inliningTarget, s2)"})
613-
static Object cantConcat(Object s1, @SuppressWarnings("unused") Object s2,
614-
@SuppressWarnings("unused") @Bind("this") Node inliningTarget,
615-
@SuppressWarnings("unused") @Shared("check") @Cached PySequenceCheckNode checkNode,
616-
@Cached PRaiseNode raiseNode) {
617-
throw raiseNode.raise(TypeError, ErrorMessages.OBJ_CANT_BE_CONCATENATED, s1);
618-
}
619-
620-
@NeverDefault
621-
protected BinaryArithmetic.AddNode createAdd() {
622-
return (BinaryArithmetic.AddNode) BinaryArithmetic.Add.create();
606+
@Specialization
607+
Object doIt(Object s1, Object s2,
608+
@Bind("this") Node inliningTarget,
609+
@Cached PySequenceConcat pySeqConcat) {
610+
return pySeqConcat.execute(null, inliningTarget, s1, s2);
623611
}
624612
}
625613

@@ -631,13 +619,13 @@ static Object concat(Object s1, Object s2,
631619
@Bind("this") Node inliningTarget,
632620
@Cached PyObjectLookupAttr lookupNode,
633621
@Cached CallNode callNode,
634-
@Cached("createAdd()") BinaryArithmetic.AddNode addNode,
622+
@Cached PyNumberAddNode addNode,
635623
@SuppressWarnings("unused") @Exclusive @Cached PySequenceCheckNode checkNode) {
636624
Object iaddCallable = lookupNode.execute(null, inliningTarget, s1, T___IADD__);
637625
if (iaddCallable != PNone.NO_VALUE) {
638626
return callNode.executeWithoutFrame(iaddCallable, s2);
639627
}
640-
return addNode.executeObject(null, s1, s2);
628+
return addNode.execute(null, inliningTarget, s1, s2);
641629
}
642630

643631
@Specialization(guards = "!checkNode.execute(inliningTarget, s1)", limit = "1")
@@ -647,10 +635,6 @@ static Object concat(Object s1, @SuppressWarnings("unused") Object s2,
647635
@Cached PRaiseNode raiseNode) {
648636
throw raiseNode.raise(TypeError, ErrorMessages.OBJ_CANT_BE_CONCATENATED, s1);
649637
}
650-
651-
protected BinaryArithmetic.AddNode createAdd() {
652-
return (BinaryArithmetic.AddNode) BinaryArithmetic.Add.create();
653-
}
654638
}
655639

656640
@CApiBuiltin(ret = Int, args = {PyObject, Py_ssize_t}, call = Ignored)

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/cext/PythonCextBytesBuiltins.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ static long fallback(Object obj,
149149
abstract static class PyTruffleBytes_Concat extends CApiBinaryBuiltinNode {
150150
@Specialization
151151
static Object concat(Object original, Object newPart,
152-
@Cached BytesCommonBuiltins.AddNode addNode) {
152+
@Cached BytesCommonBuiltins.ConcatNode addNode) {
153153
return addNode.execute(null, original, newPart);
154154
}
155155
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/bytes/BytesCommonBuiltins.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ static PBytesLike join(VirtualFrame frame, Object self, Object iterable,
288288

289289
@Slot(value = SlotKind.sq_concat, isComplex = true)
290290
@GenerateNodeFactory
291-
public abstract static class AddNode extends SqConcatBuiltinNode {
291+
public abstract static class ConcatNode extends SqConcatBuiltinNode {
292292

293293
@Specialization
294294
static PBytesLike add(PBytesLike self, PBytesLike other,

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/itertools/AccumulateBuiltins.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@
5555
import com.oracle.graal.python.builtins.objects.PNone;
5656
import com.oracle.graal.python.builtins.objects.list.PList;
5757
import com.oracle.graal.python.builtins.objects.tuple.PTuple;
58+
import com.oracle.graal.python.lib.PyNumberAddNode;
5859
import com.oracle.graal.python.lib.PyObjectGetIter;
5960
import com.oracle.graal.python.nodes.call.CallNode;
60-
import com.oracle.graal.python.nodes.expression.BinaryArithmetic;
6161
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
6262
import com.oracle.graal.python.nodes.function.builtins.PythonBinaryBuiltinNode;
6363
import com.oracle.graal.python.nodes.function.builtins.PythonUnaryBuiltinNode;
@@ -97,7 +97,7 @@ public abstract static class NextNode extends PythonUnaryBuiltinNode {
9797
static Object next(VirtualFrame frame, PAccumulate self,
9898
@Bind("this") Node inliningTarget,
9999
@Cached BuiltinFunctions.NextNode nextNode,
100-
@Cached BinaryArithmetic.AddNode addNode,
100+
@Cached PyNumberAddNode addNode,
101101
@Cached CallNode callNode,
102102
@Cached InlinedBranchProfile hasInitialProfile,
103103
@Cached InlinedBranchProfile markerProfile,
@@ -115,7 +115,7 @@ static Object next(VirtualFrame frame, PAccumulate self,
115115
return value;
116116
}
117117
if (hasFuncProfile.profile(inliningTarget, self.getFunc() == null)) {
118-
self.setTotal(addNode.executeObject(frame, self.getTotal(), value));
118+
self.setTotal(addNode.execute(frame, inliningTarget, self.getTotal(), value));
119119
} else {
120120
self.setTotal(callNode.execute(frame, self.getFunc(), self.getTotal(), value));
121121
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/itertools/CountBuiltins.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved.
33
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44
*
55
* The Universal Permissive License (UPL), Version 1.0
@@ -54,10 +54,10 @@
5454
import com.oracle.graal.python.builtins.PythonBuiltins;
5555
import com.oracle.graal.python.builtins.objects.str.StringUtils.SimpleTruffleStringFormatNode;
5656
import com.oracle.graal.python.builtins.objects.tuple.PTuple;
57+
import com.oracle.graal.python.lib.PyNumberAddNode;
5758
import com.oracle.graal.python.lib.PyObjectGetAttr;
5859
import com.oracle.graal.python.lib.PyObjectReprAsObjectNode;
5960
import com.oracle.graal.python.lib.PyObjectTypeCheck;
60-
import com.oracle.graal.python.nodes.expression.BinaryArithmetic;
6161
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
6262
import com.oracle.graal.python.nodes.function.builtins.PythonUnaryBuiltinNode;
6363
import com.oracle.graal.python.nodes.object.GetClassNode;
@@ -97,9 +97,10 @@ static Object iter(PCount self) {
9797
public abstract static class NextNode extends PythonUnaryBuiltinNode {
9898
@Specialization
9999
static Object next(VirtualFrame frame, PCount self,
100-
@Cached BinaryArithmetic.AddNode addNode) {
100+
@Bind("this") Node inliningTarget,
101+
@Cached PyNumberAddNode addNode) {
101102
Object cnt = self.getCnt();
102-
self.setCnt(addNode.executeObject(frame, self.getCnt(), self.getStep()));
103+
self.setCnt(addNode.execute(frame, inliningTarget, self.getCnt(), self.getStep()));
103104
return cnt;
104105
}
105106
}

0 commit comments

Comments
 (0)