Skip to content

Commit 59c6c6c

Browse files
committed
[GR-13242] Bytearray += operation has to return the same object.
PullRequest: graalpython/348
2 parents 441e2e0 + b8d0c21 commit 59c6c6c

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_bytes.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,14 @@ class BytesSplitTest(BaseTestSplit, unittest.TestCase):
610610
class ByteArraySplitTest(BaseTestSplit, unittest.TestCase):
611611
type2test = bytearray
612612

613+
614+
def test_eq_add_bytearray():
615+
b1 = bytearray(b'')
616+
b2 = b1
617+
b1 += bytearray(b'Hello')
618+
assert b1 is b2
619+
620+
613621
def test_add_mv_to_bytes():
614622
b = b'hello '
615623
mv = memoryview(b'world')

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/bytes/ByteArrayBuiltins.java

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
import static com.oracle.graal.python.builtins.objects.slice.PSlice.MISSING_INDEX;
3030
import static com.oracle.graal.python.nodes.SpecialMethodNames.__ADD__;
31+
import static com.oracle.graal.python.nodes.SpecialMethodNames.__IADD__;
3132
import static com.oracle.graal.python.nodes.SpecialMethodNames.__BOOL__;
3233
import static com.oracle.graal.python.nodes.SpecialMethodNames.__DELITEM__;
3334
import static com.oracle.graal.python.nodes.SpecialMethodNames.__EQ__;
@@ -265,6 +266,45 @@ public Object add(Object self, Object other) {
265266
}
266267
}
267268

269+
@Builtin(name = __IADD__, fixedNumOfPositionalArgs = 2)
270+
@GenerateNodeFactory
271+
public abstract static class IAddNode extends PythonBinaryBuiltinNode {
272+
@Specialization
273+
public PByteArray add(PByteArray self, PIBytesLike other,
274+
@Cached("create()") SequenceStorageNodes.ConcatNode concatNode) {
275+
SequenceStorage res = concatNode.execute(self.getSequenceStorage(), other.getSequenceStorage());
276+
updateSequenceStorage(self, res);
277+
return self;
278+
}
279+
280+
@Specialization
281+
public PByteArray add(PByteArray self, PMemoryView other,
282+
@Cached("create(TOBYTES)") LookupAndCallUnaryNode toBytesNode,
283+
@Cached("createBinaryProfile()") ConditionProfile isBytesProfile,
284+
@Cached("create()") SequenceStorageNodes.ConcatNode concatNode) {
285+
286+
Object bytesObj = toBytesNode.executeObject(other);
287+
if (isBytesProfile.profile(bytesObj instanceof PBytes)) {
288+
SequenceStorage res = concatNode.execute(self.getSequenceStorage(), ((PBytes) bytesObj).getSequenceStorage());
289+
updateSequenceStorage(self, res);
290+
return self;
291+
}
292+
throw raise(SystemError, "could not get bytes of memoryview");
293+
}
294+
295+
@SuppressWarnings("unused")
296+
@Fallback
297+
public Object add(Object self, Object other) {
298+
throw raise(TypeError, "can't concat bytearray to %p", other);
299+
}
300+
301+
private static void updateSequenceStorage(PByteArray array, SequenceStorage s) {
302+
if (array.getSequenceStorage() != s) {
303+
array.setSequenceStorage(s);
304+
}
305+
}
306+
}
307+
268308
@Builtin(name = __MUL__, fixedNumOfPositionalArgs = 2)
269309
@GenerateNodeFactory
270310
public abstract static class MulNode extends PythonBuiltinNode {

0 commit comments

Comments
 (0)