Skip to content

Commit 0ece68f

Browse files
committed
Split 'SSN.ConcatNode' and reuse logic for 'extend' and 'append'.
1 parent afd0408 commit 0ece68f

File tree

9 files changed

+703
-398
lines changed

9 files changed

+703
-398
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_array.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,5 +58,22 @@ def test_import():
5858

5959
def test_create():
6060
from array import array
61-
a = array('b', b'x'*10)
61+
a = array('b', b'x' * 10)
6262
assert str(a) == "array('b', [120, 120, 120, 120, 120, 120, 120, 120, 120, 120])"
63+
64+
65+
def test_add():
66+
from array import array
67+
a0 = array("b", b"hello")
68+
a1 = array("b", b"world")
69+
assert a0 + a1 == array("b", b"helloworld")
70+
71+
a0 = array("b", b"hello")
72+
a1 = array("l", b"abcdabcd")
73+
try:
74+
res = a0 + a1
75+
except TypeError:
76+
assert True
77+
else:
78+
assert False
79+

graalpython/com.oracle.graal.python.test/src/tests/test_bytes.py

Lines changed: 46 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,52 @@ def test_setslice():
212212
# assign range
213213
b = bytearray(b"hellohellohello")
214214
b[5:10] = range(5)
215-
assert b == bytearray(b"hello\x04\x05\x06\x07\x08hello")
215+
assert b == bytearray(b'hello\x00\x01\x02\x03\x04hello')
216+
217+
b = bytearray(range(10))
218+
assert list(b) == list(range(10))
219+
220+
b[0:5] = bytearray([1, 1, 1, 1, 1])
221+
assert b == bytearray([1, 1, 1, 1, 1, 5, 6, 7, 8, 9])
222+
223+
# TODO: seq storage does not yet support deletion ...
224+
# del b[0:-5]
225+
# assert b == bytearray([5, 6, 7, 8, 9])
226+
b = bytearray([5, 6, 7, 8, 9])
227+
228+
# TODO: seq setSlice is broken ...
229+
# b[0:0] = bytearray([0, 1, 2, 3, 4])
230+
# assert b == bytearray(range(10))
231+
b = bytearray(range(10))
232+
233+
b[-7:-3] = bytearray([100, 101])
234+
assert b == bytearray([0, 1, 2, 100, 101, 7, 8, 9])
235+
236+
b[3:5] = [3, 4, 5, 6]
237+
assert b == bytearray(range(10))
238+
239+
b[3:0] = [42, 42, 42]
240+
assert b == bytearray([0, 1, 2, 42, 42, 42, 3, 4, 5, 6, 7, 8, 9])
241+
242+
b[3:] = b'foo'
243+
assert b == bytearray([0, 1, 2, 102, 111, 111])
244+
245+
b[:3] = memoryview(b'foo')
246+
assert b == bytearray([102, 111, 111, 102, 111, 111])
247+
248+
b[3:4] = []
249+
assert b == bytearray([102, 111, 111, 111, 111])
250+
251+
for elem in [5, -5, 0, int(10e20), 'str', 2.3,
252+
['a', 'b'], [b'a', b'b'], [[]]]:
253+
def assign():
254+
b[3:4] = elem
255+
assert_raises(TypeError, assign)
256+
257+
for elem in [[254, 255, 256], [-256, 9000]]:
258+
def assign():
259+
b[3:4] = elem
260+
assert_raises(ValueError, assign)
216261

217262

218263
def test_delitem():
@@ -295,53 +340,6 @@ def test_join():
295340
assert b"--".join([b"hello"]) == b"hello"
296341

297342

298-
def test_setslice():
299-
b = bytearray(range(10))
300-
assert list(b) == list(range(10))
301-
302-
b[0:5] = bytearray([1, 1, 1, 1, 1])
303-
assert b == bytearray([1, 1, 1, 1, 1, 5, 6, 7, 8, 9])
304-
305-
# TODO: seq storage does not yet support deletion ...
306-
# del b[0:-5]
307-
# assert b == bytearray([5, 6, 7, 8, 9])
308-
b = bytearray([5, 6, 7, 8, 9])
309-
310-
# TODO: seq setSlice is broken ...
311-
# b[0:0] = bytearray([0, 1, 2, 3, 4])
312-
# assert b == bytearray(range(10))
313-
b = bytearray(range(10))
314-
315-
b[-7:-3] = bytearray([100, 101])
316-
assert b == bytearray([0, 1, 2, 100, 101, 7, 8, 9])
317-
318-
b[3:5] = [3, 4, 5, 6]
319-
assert b == bytearray(range(10))
320-
321-
b[3:0] = [42, 42, 42]
322-
assert b == bytearray([0, 1, 2, 42, 42, 42, 3, 4, 5, 6, 7, 8, 9])
323-
324-
b[3:] = b'foo'
325-
assert b == bytearray([0, 1, 2, 102, 111, 111])
326-
327-
b[:3] = memoryview(b'foo')
328-
assert b == bytearray([102, 111, 111, 102, 111, 111])
329-
330-
b[3:4] = []
331-
assert b == bytearray([102, 111, 111, 111, 111])
332-
333-
for elem in [5, -5, 0, int(10e20), 'str', 2.3,
334-
['a', 'b'], [b'a', b'b'], [[]]]:
335-
def assign():
336-
b[3:4] = elem
337-
assert_raises(TypeError, assign)
338-
339-
for elem in [[254, 255, 256], [-256, 9000]]:
340-
def assign():
341-
b[3:4] = elem
342-
assert_raises(ValueError, assign)
343-
344-
345343
def test_concat():
346344
a = b'0'
347345
b = b'1'

graalpython/com.oracle.graal.python.test/src/tests/test_list.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ def test_basic(self):
5050
x.extend(-y for y in x)
5151
self.assertEqual(x, [])
5252

53+
l = [0x1FFFFFFFF, 1, 2, 3, 4]
54+
l[0] = "hello"
55+
self.assertEqual(l, ["hello", 1, 2, 3, 4])
56+
5357
def test_truth(self):
5458
super().test_truth()
5559
self.assertTrue(not [])
@@ -481,6 +485,11 @@ def test_StopIteration(self):
481485
l.append(3)
482486
self.assertRaises(StopIteration, i.__next__)
483487

488+
def test_add(self):
489+
l1 = [1, 2, 3]
490+
l2 = ["a", "b", "c"]
491+
self.assertEqual(l1 + l2, [1, 2, 3, "a", "b", "c"])
492+
484493
def test_iadd_special(self):
485494
a = [1]
486495
a += (2, 3)
@@ -572,6 +581,14 @@ def __index__(self):
572581
ob = My(10)
573582
self.assertRaises(TypeError, l.__imul__, ob)
574583

584+
def test_append(self):
585+
l = []
586+
l.append(1)
587+
l.append(0x1FF)
588+
l.append(0x1FFFFFFFF)
589+
l.append("hello")
590+
self.assertEqual(l, [1, 0x1FF, 0x1FFFFFFFF, "hello"])
591+
575592

576593
class ListCompareTest(CompareTest):
577594

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/ArrayModuleBuiltins.java

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
*/
2626
package com.oracle.graal.python.builtins.modules;
2727

28+
import static com.oracle.graal.python.runtime.exception.PythonErrorType.OverflowError;
2829
import static com.oracle.graal.python.runtime.exception.PythonErrorType.TypeError;
2930
import static com.oracle.graal.python.runtime.exception.PythonErrorType.ValueError;
3031

@@ -36,13 +37,16 @@
3637
import com.oracle.graal.python.builtins.PythonBuiltins;
3738
import com.oracle.graal.python.builtins.objects.PNone;
3839
import com.oracle.graal.python.builtins.objects.array.PArray;
40+
import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes.CastToByteNode;
3941
import com.oracle.graal.python.builtins.objects.range.PRange;
4042
import com.oracle.graal.python.builtins.objects.type.PythonClass;
4143
import com.oracle.graal.python.nodes.control.GetIteratorNode;
4244
import com.oracle.graal.python.nodes.control.GetNextNode;
45+
import com.oracle.graal.python.nodes.expression.CastToBooleanNode;
4346
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
4447
import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
4548
import com.oracle.graal.python.runtime.exception.PException;
49+
import com.oracle.graal.python.runtime.exception.PythonErrorType;
4650
import com.oracle.graal.python.runtime.sequence.PSequence;
4751
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
4852
import com.oracle.truffle.api.dsl.Cached;
@@ -115,6 +119,7 @@ protected boolean isDoubleArray(String typeCode) {
115119

116120
@Specialization(guards = "isByteArray(typeCode)")
117121
PArray arrayByteInitializer(PythonClass cls, @SuppressWarnings("unused") String typeCode, PSequence initializer,
122+
@Cached("createCast()") CastToByteNode castToByteNode,
118123
@Cached("create()") GetIteratorNode getIterator,
119124
@Cached("create()") GetNextNode next,
120125
@Cached("createBinaryProfile()") ConditionProfile errorProfile) {
@@ -130,20 +135,7 @@ PArray arrayByteInitializer(PythonClass cls, @SuppressWarnings("unused") String
130135
e.expectStopIteration(getCore(), errorProfile);
131136
break;
132137
}
133-
134-
if (nextValue instanceof Byte) {
135-
byteArray[i++] = (byte) nextValue;
136-
}
137-
if (nextValue instanceof Integer) {
138-
int intValue = (int) nextValue;
139-
if (0 <= intValue && intValue <= 255) {
140-
byteArray[i++] = (byte) intValue;
141-
} else {
142-
throw raise(ValueError, "signed char is greater than maximum");
143-
}
144-
} else {
145-
throw raise(ValueError, "integer argument expected, got %p", nextValue);
146-
}
138+
byteArray[i++] = castToByteNode.execute(nextValue);
147139
}
148140

149141
return factory().createArray(cls, byteArray);
@@ -228,6 +220,13 @@ private PArray makeEmptyArray(PythonClass cls, char type) {
228220
}
229221
}
230222

223+
protected CastToByteNode createCast() {
224+
return CastToByteNode.create(val -> {
225+
throw raise(OverflowError, "signed char is greater than maximum");
226+
}, null);
227+
228+
}
229+
231230
@TruffleBoundary
232231
private void typeError(String typeCode, Object initializer) {
233232
throw raise(TypeError, "cannot use a %p to initialize an array with typecode '%s'", initializer, typeCode);

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/bytes/ByteArrayBuiltins.java

Lines changed: 12 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,13 @@
3131
import static com.oracle.graal.python.nodes.SpecialMethodNames.__DELITEM__;
3232
import static com.oracle.graal.python.nodes.SpecialMethodNames.__EQ__;
3333
import static com.oracle.graal.python.nodes.SpecialMethodNames.__GETITEM__;
34+
import static com.oracle.graal.python.nodes.SpecialMethodNames.__GE__;
35+
import static com.oracle.graal.python.nodes.SpecialMethodNames.__GT__;
3436
import static com.oracle.graal.python.nodes.SpecialMethodNames.__INIT__;
3537
import static com.oracle.graal.python.nodes.SpecialMethodNames.__ITER__;
3638
import static com.oracle.graal.python.nodes.SpecialMethodNames.__LEN__;
37-
import static com.oracle.graal.python.nodes.SpecialMethodNames.__LT__;
3839
import static com.oracle.graal.python.nodes.SpecialMethodNames.__LE__;
39-
import static com.oracle.graal.python.nodes.SpecialMethodNames.__GT__;
40-
import static com.oracle.graal.python.nodes.SpecialMethodNames.__GE__;
40+
import static com.oracle.graal.python.nodes.SpecialMethodNames.__LT__;
4141
import static com.oracle.graal.python.nodes.SpecialMethodNames.__MUL__;
4242
import static com.oracle.graal.python.nodes.SpecialMethodNames.__REPR__;
4343
import static com.oracle.graal.python.nodes.SpecialMethodNames.__RMUL__;
@@ -57,6 +57,7 @@
5757
import com.oracle.graal.python.builtins.objects.PNotImplemented;
5858
import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes;
5959
import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes.GetItemNode;
60+
import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes.NoGeneralizationNode;
6061
import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes.NormalizeIndexNode;
6162
import com.oracle.graal.python.builtins.objects.ints.PInt;
6263
import com.oracle.graal.python.builtins.objects.memoryview.PMemoryView;
@@ -66,16 +67,13 @@
6667
import com.oracle.graal.python.nodes.PGuards;
6768
import com.oracle.graal.python.nodes.SpecialMethodNames;
6869
import com.oracle.graal.python.nodes.call.special.LookupAndCallUnaryNode;
69-
import com.oracle.graal.python.nodes.control.GetIteratorNode;
70-
import com.oracle.graal.python.nodes.control.GetNextNode;
7170
import com.oracle.graal.python.nodes.expression.BinaryComparisonNode;
7271
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
7372
import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
7473
import com.oracle.graal.python.nodes.function.builtins.PythonBinaryBuiltinNode;
7574
import com.oracle.graal.python.nodes.function.builtins.PythonTernaryBuiltinNode;
7675
import com.oracle.graal.python.nodes.function.builtins.PythonUnaryBuiltinNode;
7776
import com.oracle.graal.python.nodes.truffle.PythonArithmeticTypes;
78-
import com.oracle.graal.python.runtime.exception.PException;
7977
import com.oracle.graal.python.runtime.sequence.PSequence;
8078
import com.oracle.graal.python.runtime.sequence.storage.ByteSequenceStorage;
8179
import com.oracle.graal.python.runtime.sequence.storage.IntSequenceStorage;
@@ -380,51 +378,18 @@ public PByteArray appendInt(PByteArray byteArray, byte arg) {
380378
// bytearray.extend(L)
381379
@Builtin(name = "extend", fixedNumOfPositionalArgs = 2)
382380
@GenerateNodeFactory
383-
public abstract static class ByteArrayExtendNode extends PythonBuiltinNode {
384-
385-
@Specialization(guards = {"isPSequenceWithStorage(source)"}, rewriteOn = {SequenceStoreException.class})
386-
public PNone extendSequenceStore(PByteArray byteArray, Object source) throws SequenceStoreException {
387-
SequenceStorage target = byteArray.getSequenceStorage();
388-
target.extend(((PSequence) source).getSequenceStorage());
389-
return PNone.NONE;
390-
}
381+
public abstract static class ByteArrayExtendNode extends PythonBinaryBuiltinNode {
391382

392-
@Specialization(guards = {"isPSequenceWithStorage(source)"})
393-
public PNone extendSequence(PByteArray byteArray, Object source) {
394-
SequenceStorage eSource = ((PSequence) source).getSequenceStorage();
395-
if (eSource.length() > 0) {
396-
SequenceStorage target = byteArray.getSequenceStorage();
397-
try {
398-
target.extend(eSource);
399-
} catch (SequenceStoreException e) {
400-
throw raise(ValueError, "byte must be in range(0, 256)");
401-
}
402-
}
383+
@Specialization
384+
public PNone doGeneric(PByteArray byteArray, Object source,
385+
@Cached("createExtend()") SequenceStorageNodes.ExtendNode extendNode) {
386+
SequenceStorage execute = extendNode.execute(byteArray.getSequenceStorage(), source);
387+
assert byteArray.getSequenceStorage() == execute;
403388
return PNone.NONE;
404389
}
405390

406-
@Specialization(guards = "!isPSequenceWithStorage(source)")
407-
public PNone extend(PByteArray byteArray, Object source,
408-
@Cached("create()") GetIteratorNode getIterator,
409-
@Cached("create()") GetNextNode next,
410-
@Cached("createBinaryProfile()") ConditionProfile errorProfile) {
411-
Object workSource = byteArray != source ? source : factory().createByteArray(((PSequence) source).getSequenceStorage().copy());
412-
Object iterator = getIterator.executeWith(workSource);
413-
while (true) {
414-
Object value;
415-
try {
416-
value = next.execute(iterator);
417-
} catch (PException e) {
418-
e.expectStopIteration(getCore(), errorProfile);
419-
return PNone.NONE;
420-
}
421-
422-
try {
423-
byteArray.append(value);
424-
} catch (SequenceStoreException e) {
425-
throw raise(ValueError, "byte must be in range(0, 256)");
426-
}
427-
}
391+
protected static SequenceStorageNodes.ExtendNode createExtend() {
392+
return SequenceStorageNodes.ExtendNode.create(() -> NoGeneralizationNode.create("byte must be in range(0, 256)"));
428393
}
429394

430395
protected boolean isPSequenceWithStorage(Object source) {

0 commit comments

Comments
 (0)