Skip to content

Commit 86d13f4

Browse files
committed
avoid sharing BytesIO storage and make memoryio pass
1 parent 03bbe91 commit 86d13f4

File tree

5 files changed

+97
-26
lines changed

5 files changed

+97
-26
lines changed

graalpython/com.oracle.graal.python.test/src/tests/unittest_tags/test_memoryio.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
*graalpython.lib-python.3.test.test_memoryio.CBytesIOTest.test_detach
1212
*graalpython.lib-python.3.test.test_memoryio.CBytesIOTest.test_flags
1313
*graalpython.lib-python.3.test.test_memoryio.CBytesIOTest.test_flush
14+
*graalpython.lib-python.3.test.test_memoryio.CBytesIOTest.test_getbuffer
1415
*graalpython.lib-python.3.test.test_memoryio.CBytesIOTest.test_getstate
1516
*graalpython.lib-python.3.test.test_memoryio.CBytesIOTest.test_getvalue
1617
*graalpython.lib-python.3.test.test_memoryio.CBytesIOTest.test_init
@@ -26,6 +27,7 @@
2627
*graalpython.lib-python.3.test.test_memoryio.CBytesIOTest.test_readlines
2728
*graalpython.lib-python.3.test.test_memoryio.CBytesIOTest.test_relative_seek
2829
*graalpython.lib-python.3.test.test_memoryio.CBytesIOTest.test_seek
30+
*graalpython.lib-python.3.test.test_memoryio.CBytesIOTest.test_setstate
2931
*graalpython.lib-python.3.test.test_memoryio.CBytesIOTest.test_sizeof
3032
*graalpython.lib-python.3.test.test_memoryio.CBytesIOTest.test_subclassing
3133
*graalpython.lib-python.3.test.test_memoryio.CBytesIOTest.test_tell
@@ -93,6 +95,7 @@
9395
*graalpython.lib-python.3.test.test_memoryio.PyBytesIOTest.test_detach
9496
*graalpython.lib-python.3.test.test_memoryio.PyBytesIOTest.test_flags
9597
*graalpython.lib-python.3.test.test_memoryio.PyBytesIOTest.test_flush
98+
*graalpython.lib-python.3.test.test_memoryio.PyBytesIOTest.test_getbuffer
9699
*graalpython.lib-python.3.test.test_memoryio.PyBytesIOTest.test_getvalue
97100
*graalpython.lib-python.3.test.test_memoryio.PyBytesIOTest.test_init
98101
*graalpython.lib-python.3.test.test_memoryio.PyBytesIOTest.test_instance_dict_leak

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/io/BytesIOBuiltins.java

Lines changed: 75 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@
121121
import com.oracle.graal.python.util.PythonUtils;
122122
import com.oracle.truffle.api.CompilerDirectives;
123123
import com.oracle.truffle.api.dsl.Cached;
124+
import com.oracle.truffle.api.dsl.Cached.Shared;
124125
import com.oracle.truffle.api.dsl.GenerateNodeFactory;
125126
import com.oracle.truffle.api.dsl.NodeFactory;
126127
import com.oracle.truffle.api.dsl.Specialization;
@@ -389,15 +390,18 @@ protected ArgumentClinicProvider getArgumentClinic() {
389390
}
390391

391392
@Specialization(guards = {"self.hasBuf()", "checkExports(self)"})
392-
static Object truncate(PBytesIO self, @SuppressWarnings("unused") PNone size,
393-
@Cached SequenceStorageNodes.SetLenNode setLenNode) {
394-
return truncate(self, self.getPos(), setLenNode);
393+
Object truncate(PBytesIO self, @SuppressWarnings("unused") PNone size,
394+
@Shared("i") @Cached SequenceStorageNodes.GetInternalArrayNode internalArray,
395+
@Shared("l") @Cached SequenceStorageNodes.SetLenNode setLenNode) {
396+
return truncate(self, self.getPos(), internalArray, setLenNode);
395397
}
396398

397399
@Specialization(guards = {"self.hasBuf()", "checkExports(self)", "size >= 0", "size < self.getStringSize()"})
398-
static Object truncate(PBytesIO self, int size,
399-
@Cached SequenceStorageNodes.SetLenNode setLenNode) {
400+
Object truncate(PBytesIO self, int size,
401+
@Shared("i") @Cached SequenceStorageNodes.GetInternalArrayNode internalArray,
402+
@Shared("l") @Cached SequenceStorageNodes.SetLenNode setLenNode) {
400403
self.setStringSize(size);
404+
resizeBuffer(self, size, internalArray, factory());
401405
setLenNode.execute(self.getBuf().getSequenceStorage(), size);
402406
return size;
403407
}
@@ -411,11 +415,12 @@ static Object same(@SuppressWarnings("unused") PBytesIO self, int size) {
411415
Object obj(VirtualFrame frame, PBytesIO self, Object arg,
412416
@Cached PyNumberAsSizeNode asSizeNode,
413417
@Cached PyNumberIndexNode indexNode,
414-
@Cached SequenceStorageNodes.SetLenNode setLenNode) {
418+
@Shared("i") @Cached SequenceStorageNodes.GetInternalArrayNode internalArray,
419+
@Shared("l") @Cached SequenceStorageNodes.SetLenNode setLenNode) {
415420
int size = asSizeNode.executeExact(frame, indexNode.execute(frame, arg), OverflowError);
416421
if (size >= 0) {
417422
if (size < self.getStringSize()) {
418-
return truncate(self, size, setLenNode);
423+
return truncate(self, size, internalArray, setLenNode);
419424
}
420425
return size;
421426
}
@@ -437,13 +442,45 @@ Object exportsError(@SuppressWarnings("unused") PBytesIO self, @SuppressWarnings
437442
}
438443
}
439444

445+
protected static void unshareBuffer(PBytesIO self, int size, byte[] buf,
446+
PythonObjectFactory factory) {
447+
/*- (mq) This method is only used when `self.buf.refcnt > 1`.
448+
`refcnt` is not available in our managed storage.
449+
Therefore, we always create a new storage in this case.
450+
*/
451+
byte[] newBuf = new byte[size];
452+
PythonUtils.arraycopy(buf, 0, newBuf, 0, self.getStringSize());
453+
self.setBuf(factory.createBytes(newBuf));
454+
}
455+
456+
protected static void unshareBuffer(PBytesIO self, int size,
457+
SequenceStorageNodes.GetInternalArrayNode internalArray,
458+
PythonObjectFactory factory) {
459+
byte[] buf = (byte[]) internalArray.execute(self.getBuf().getSequenceStorage());
460+
unshareBuffer(self, size, buf, factory);
461+
}
462+
463+
protected static void resizeBuffer(PBytesIO self, int size,
464+
SequenceStorageNodes.GetInternalArrayNode internalArray,
465+
PythonObjectFactory factory) {
466+
int alloc = self.getStringSize();
467+
if (size < alloc) {
468+
/* Within allocated size; quick exit */
469+
return;
470+
}
471+
// if (SHARED_BUF(self))
472+
unshareBuffer(self, size, internalArray, factory);
473+
// else resize self.buf
474+
}
475+
440476
@Builtin(name = WRITE, minNumOfPositionalArgs = 2)
441477
@GenerateNodeFactory
442478
abstract static class WriteNode extends ClosedCheckPythonBinaryBuiltinNode {
443479

444480
@Specialization(guards = {"self.hasBuf()", "checkExports(self)"})
445-
static Object doWrite(VirtualFrame frame, PBytesIO self, Object b,
481+
Object doWrite(VirtualFrame frame, PBytesIO self, Object b,
446482
@Cached BytesNodes.GetBuffer getBuffer,
483+
@Cached SequenceStorageNodes.GetInternalArrayNode internalArray,
447484
@Cached SequenceStorageNodes.EnsureCapacityNode ensureCapacityNode,
448485
@Cached SequenceStorageNodes.BytesMemcpyNode memcpyNode,
449486
@Cached SequenceStorageNodes.SetLenNode setLenNode) {
@@ -452,22 +489,25 @@ static Object doWrite(VirtualFrame frame, PBytesIO self, Object b,
452489
if (len == 0) {
453490
return 0;
454491
}
455-
write(frame, self, buf, ensureCapacityNode, memcpyNode, setLenNode);
492+
write(frame, self, buf, internalArray, ensureCapacityNode, memcpyNode, setLenNode, factory());
456493
return len;
457494
}
458495

459496
static void write(VirtualFrame frame, PBytesIO self, byte[] buf,
497+
SequenceStorageNodes.GetInternalArrayNode internalArray,
460498
SequenceStorageNodes.EnsureCapacityNode ensureCapacityNode,
461499
SequenceStorageNodes.BytesMemcpyNode memcpyNode,
462-
SequenceStorageNodes.SetLenNode setLenNode) {
500+
SequenceStorageNodes.SetLenNode setLenNode,
501+
PythonObjectFactory factory) {
463502
int len = buf.length;
464503
int pos = self.getPos();
465504
int size = self.getStringSize();
466505
int endpos = self.getPos() + len;
467506
ensureCapacityNode.execute(self.getBuf().getSequenceStorage(), endpos);
468507
if (pos > size) {
469-
byte[] nil = new byte[pos - size];
470-
memcpyNode.execute(frame, self.getBuf(), size, nil, 0, nil.length);
508+
resizeBuffer(self, endpos, internalArray, factory);
509+
} else { // if (SHARED_BUF(self))
510+
unshareBuffer(self, Math.max(endpos, size), internalArray, factory);
471511
}
472512
memcpyNode.execute(frame, self.getBuf(), pos, buf, 0, len);
473513
self.setPos(endpos);
@@ -621,11 +661,16 @@ Object closedError(PBytesIO self, int pos, int whence) {
621661
@GenerateNodeFactory
622662
abstract static class GetBufferNode extends ClosedCheckPythonUnaryBuiltinNode {
623663
@Specialization(guards = "self.hasBuf()")
624-
Object doit(PBytesIO self) {
664+
Object doit(PBytesIO self,
665+
@Cached SequenceStorageNodes.GetInternalArrayNode internalArray) {
666+
// if (SHARED_BUF(b))
667+
unshareBuffer(self, self.getStringSize(), internalArray, factory());
668+
// else do nothing to self.buf
669+
625670
PBytesIOBuffer buf = factory().createBytesIOBuf(PBytesIOBuf, self);
626671
int length = self.getStringSize();
627672
return factory().createMemoryView(getContext(), self.getManagedBuffer(), buf,
628-
length, true, 1, "B",
673+
length, false, 1, "B",
629674
1, null, 0, new int[]{length}, new int[]{1},
630675
null, PMemoryView.FLAG_C | PMemoryView.FLAG_FORTRAN);
631676
}
@@ -639,17 +684,28 @@ protected static boolean shouldCopy(PBytesIO self) {
639684
return self.getStringSize() <= 1 || self.getExports() > 0;
640685
}
641686

687+
protected static boolean shouldUnshare(PBytesIO self) {
688+
return self.getStringSize() != self.getBufCapacity();
689+
}
690+
642691
@Specialization(guards = {"self.hasBuf()", "shouldCopy(self)"})
643692
Object copy(PBytesIO self,
644693
@Cached SequenceStorageNodes.GetInternalByteArrayNode getBytes) {
645694
byte[] buf = getBytes.execute(self.getBuf().getSequenceStorage());
646695
return factory().createBytes(PythonUtils.arrayCopyOf(buf, self.getStringSize()));
647696
}
648697

649-
@Specialization(guards = {"self.hasBuf()", "!shouldCopy(self)"})
650-
static Object doit(PBytesIO self,
651-
@Cached SequenceStorageNodes.SetLenNode setLenNode) {
652-
setLenNode.execute(self.getBuf().getSequenceStorage(), self.getStringSize());
698+
@Specialization(guards = {"self.hasBuf()", "!shouldCopy(self)", "!shouldUnshare(self)"})
699+
static Object doit(PBytesIO self) {
700+
return self.getBuf();
701+
}
702+
703+
@Specialization(guards = {"self.hasBuf()", "!shouldCopy(self)", "shouldUnshare(self)"})
704+
Object unshare(PBytesIO self,
705+
@Cached SequenceStorageNodes.GetInternalArrayNode internalArray) {
706+
// if (SHARED_BUF(self))
707+
unshareBuffer(self, self.getStringSize(), internalArray, factory());
708+
// else resize self.buf
653709
return self.getBuf();
654710
}
655711
}
@@ -668,7 +724,7 @@ Object doit(VirtualFrame frame, PBytesIO self,
668724
}
669725
}
670726

671-
@Builtin(name = __SETSTATE__, minNumOfPositionalArgs = 1)
727+
@Builtin(name = __SETSTATE__, minNumOfPositionalArgs = 2)
672728
@GenerateNodeFactory
673729
abstract static class SetStateNode extends PythonBinaryBuiltinNode {
674730
@Specialization(guards = "checkExports(self)")

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/io/PBytesIO.java

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,13 @@
4343
import com.oracle.graal.python.builtins.objects.bytes.PBytes;
4444
import com.oracle.graal.python.builtins.objects.memoryview.ManagedBuffer;
4545
import com.oracle.graal.python.builtins.objects.object.PythonBuiltinObject;
46+
import com.oracle.graal.python.runtime.sequence.storage.BasicSequenceStorage;
4647
import com.oracle.truffle.api.object.Shape;
4748

4849
public class PBytesIO extends PythonBuiltinObject {
4950
private PBytes buf;
5051
private int pos;
51-
private int string_size;
52+
private int stringSize;
5253
private final ManagedBuffer exports;
5354

5455
public PBytesIO(Object cls, Shape instanceShape) {
@@ -81,11 +82,16 @@ public void incPos(int n) {
8182
}
8283

8384
public int getStringSize() {
84-
return string_size;
85+
return stringSize;
8586
}
8687

8788
public void setStringSize(int size) {
88-
this.string_size = size;
89+
this.stringSize = size;
90+
}
91+
92+
public int getBufCapacity() {
93+
// Casting is safe as we only create/replace buf internally.
94+
return ((BasicSequenceStorage) buf.getSequenceStorage()).capacity();
8995
}
9096

9197
public ManagedBuffer getManagedBuffer() {

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/common/BufferStorageNodes.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -373,9 +373,14 @@ public abstract static class CopyBytesToBuffer extends Node {
373373

374374
@Specialization
375375
static void doBytes(byte[] src, int srcPos, PBytesLike dest, int destPos, int length,
376-
@Cached SequenceNodes.GetSequenceStorageNode getSequenceStorageNode,
377-
@Cached SequenceStorageNodes.CopyBytesToByteStorage copyTo) {
378-
copyTo.execute(src, srcPos, getSequenceStorageNode.execute(dest), destPos, length);
376+
@Shared("c") @Cached SequenceStorageNodes.CopyBytesToByteStorage copyTo) {
377+
copyTo.execute(src, srcPos, dest.getSequenceStorage(), destPos, length);
378+
}
379+
380+
@Specialization
381+
static void doBytesIOBuffer(byte[] src, int srcPos, PBytesIOBuffer dest, int destPos, int length,
382+
@Shared("c") @Cached SequenceStorageNodes.CopyBytesToByteStorage copyTo) {
383+
copyTo.execute(src, srcPos, dest.getSource().getBuf().getSequenceStorage(), destPos, length);
379384
}
380385

381386
@Specialization

graalpython/lib-python/3/test/test_memoryio.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,7 @@ def buftype(s):
438438
ioclass = pyio.BytesIO
439439
EOF = b""
440440

441+
@support.impl_detail("finalization", graalvm=False)
441442
def test_getbuffer(self):
442443
memio = self.ioclass(b"1234567890")
443444
buf = memio.getbuffer()
@@ -458,7 +459,7 @@ def test_getbuffer(self):
458459
# After the buffer gets released, we can resize and close the BytesIO
459460
# again
460461
del buf
461-
support.gc_collect()
462+
support.gc_collect() # Truffle: after collecting buf, the export count should be 0.
462463
memio.truncate()
463464
memio.close()
464465
self.assertRaises(ValueError, memio.getbuffer)

0 commit comments

Comments
 (0)