121
121
import com .oracle .graal .python .util .PythonUtils ;
122
122
import com .oracle .truffle .api .CompilerDirectives ;
123
123
import com .oracle .truffle .api .dsl .Cached ;
124
+ import com .oracle .truffle .api .dsl .Cached .Shared ;
124
125
import com .oracle .truffle .api .dsl .GenerateNodeFactory ;
125
126
import com .oracle .truffle .api .dsl .NodeFactory ;
126
127
import com .oracle .truffle .api .dsl .Specialization ;
@@ -389,15 +390,18 @@ protected ArgumentClinicProvider getArgumentClinic() {
389
390
}
390
391
391
392
@ Specialization (guards = {"self.hasBuf()" , "checkExports(self)" })
392
- static Object truncate (PBytesIO self , @ SuppressWarnings ("unused" ) PNone size ,
393
- @ Cached SequenceStorageNodes .SetLenNode setLenNode ) {
394
- return truncate (self , self .getPos (), setLenNode );
393
+ Object truncate (PBytesIO self , @ SuppressWarnings ("unused" ) PNone size ,
394
+ @ Shared ("i" ) @ Cached SequenceStorageNodes .GetInternalArrayNode internalArray ,
395
+ @ Shared ("l" ) @ Cached SequenceStorageNodes .SetLenNode setLenNode ) {
396
+ return truncate (self , self .getPos (), internalArray , setLenNode );
395
397
}
396
398
397
399
@ Specialization (guards = {"self.hasBuf()" , "checkExports(self)" , "size >= 0" , "size < self.getStringSize()" })
398
- static Object truncate (PBytesIO self , int size ,
399
- @ Cached SequenceStorageNodes .SetLenNode setLenNode ) {
400
+ Object truncate (PBytesIO self , int size ,
401
+ @ Shared ("i" ) @ Cached SequenceStorageNodes .GetInternalArrayNode internalArray ,
402
+ @ Shared ("l" ) @ Cached SequenceStorageNodes .SetLenNode setLenNode ) {
400
403
self .setStringSize (size );
404
+ resizeBuffer (self , size , internalArray , factory ());
401
405
setLenNode .execute (self .getBuf ().getSequenceStorage (), size );
402
406
return size ;
403
407
}
@@ -411,11 +415,12 @@ static Object same(@SuppressWarnings("unused") PBytesIO self, int size) {
411
415
Object obj (VirtualFrame frame , PBytesIO self , Object arg ,
412
416
@ Cached PyNumberAsSizeNode asSizeNode ,
413
417
@ Cached PyNumberIndexNode indexNode ,
414
- @ Cached SequenceStorageNodes .SetLenNode setLenNode ) {
418
+ @ Shared ("i" ) @ Cached SequenceStorageNodes .GetInternalArrayNode internalArray ,
419
+ @ Shared ("l" ) @ Cached SequenceStorageNodes .SetLenNode setLenNode ) {
415
420
int size = asSizeNode .executeExact (frame , indexNode .execute (frame , arg ), OverflowError );
416
421
if (size >= 0 ) {
417
422
if (size < self .getStringSize ()) {
418
- return truncate (self , size , setLenNode );
423
+ return truncate (self , size , internalArray , setLenNode );
419
424
}
420
425
return size ;
421
426
}
@@ -437,13 +442,45 @@ Object exportsError(@SuppressWarnings("unused") PBytesIO self, @SuppressWarnings
437
442
}
438
443
}
439
444
445
+ protected static void unshareBuffer (PBytesIO self , int size , byte [] buf ,
446
+ PythonObjectFactory factory ) {
447
+ /*- (mq) This method is only used when `self.buf.refcnt > 1`.
448
+ `refcnt` is not available in our managed storage.
449
+ Therefore, we always create a new storage in this case.
450
+ */
451
+ byte [] newBuf = new byte [size ];
452
+ PythonUtils .arraycopy (buf , 0 , newBuf , 0 , self .getStringSize ());
453
+ self .setBuf (factory .createBytes (newBuf ));
454
+ }
455
+
456
+ protected static void unshareBuffer (PBytesIO self , int size ,
457
+ SequenceStorageNodes .GetInternalArrayNode internalArray ,
458
+ PythonObjectFactory factory ) {
459
+ byte [] buf = (byte []) internalArray .execute (self .getBuf ().getSequenceStorage ());
460
+ unshareBuffer (self , size , buf , factory );
461
+ }
462
+
463
+ protected static void resizeBuffer (PBytesIO self , int size ,
464
+ SequenceStorageNodes .GetInternalArrayNode internalArray ,
465
+ PythonObjectFactory factory ) {
466
+ int alloc = self .getStringSize ();
467
+ if (size < alloc ) {
468
+ /* Within allocated size; quick exit */
469
+ return ;
470
+ }
471
+ // if (SHARED_BUF(self))
472
+ unshareBuffer (self , size , internalArray , factory );
473
+ // else resize self.buf
474
+ }
475
+
440
476
@ Builtin (name = WRITE , minNumOfPositionalArgs = 2 )
441
477
@ GenerateNodeFactory
442
478
abstract static class WriteNode extends ClosedCheckPythonBinaryBuiltinNode {
443
479
444
480
@ Specialization (guards = {"self.hasBuf()" , "checkExports(self)" })
445
- static Object doWrite (VirtualFrame frame , PBytesIO self , Object b ,
481
+ Object doWrite (VirtualFrame frame , PBytesIO self , Object b ,
446
482
@ Cached BytesNodes .GetBuffer getBuffer ,
483
+ @ Cached SequenceStorageNodes .GetInternalArrayNode internalArray ,
447
484
@ Cached SequenceStorageNodes .EnsureCapacityNode ensureCapacityNode ,
448
485
@ Cached SequenceStorageNodes .BytesMemcpyNode memcpyNode ,
449
486
@ Cached SequenceStorageNodes .SetLenNode setLenNode ) {
@@ -452,22 +489,25 @@ static Object doWrite(VirtualFrame frame, PBytesIO self, Object b,
452
489
if (len == 0 ) {
453
490
return 0 ;
454
491
}
455
- write (frame , self , buf , ensureCapacityNode , memcpyNode , setLenNode );
492
+ write (frame , self , buf , internalArray , ensureCapacityNode , memcpyNode , setLenNode , factory () );
456
493
return len ;
457
494
}
458
495
459
496
static void write (VirtualFrame frame , PBytesIO self , byte [] buf ,
497
+ SequenceStorageNodes .GetInternalArrayNode internalArray ,
460
498
SequenceStorageNodes .EnsureCapacityNode ensureCapacityNode ,
461
499
SequenceStorageNodes .BytesMemcpyNode memcpyNode ,
462
- SequenceStorageNodes .SetLenNode setLenNode ) {
500
+ SequenceStorageNodes .SetLenNode setLenNode ,
501
+ PythonObjectFactory factory ) {
463
502
int len = buf .length ;
464
503
int pos = self .getPos ();
465
504
int size = self .getStringSize ();
466
505
int endpos = self .getPos () + len ;
467
506
ensureCapacityNode .execute (self .getBuf ().getSequenceStorage (), endpos );
468
507
if (pos > size ) {
469
- byte [] nil = new byte [pos - size ];
470
- memcpyNode .execute (frame , self .getBuf (), size , nil , 0 , nil .length );
508
+ resizeBuffer (self , endpos , internalArray , factory );
509
+ } else { // if (SHARED_BUF(self))
510
+ unshareBuffer (self , Math .max (endpos , size ), internalArray , factory );
471
511
}
472
512
memcpyNode .execute (frame , self .getBuf (), pos , buf , 0 , len );
473
513
self .setPos (endpos );
@@ -621,11 +661,16 @@ Object closedError(PBytesIO self, int pos, int whence) {
621
661
@ GenerateNodeFactory
622
662
abstract static class GetBufferNode extends ClosedCheckPythonUnaryBuiltinNode {
623
663
@ Specialization (guards = "self.hasBuf()" )
624
- Object doit (PBytesIO self ) {
664
+ Object doit (PBytesIO self ,
665
+ @ Cached SequenceStorageNodes .GetInternalArrayNode internalArray ) {
666
+ // if (SHARED_BUF(b))
667
+ unshareBuffer (self , self .getStringSize (), internalArray , factory ());
668
+ // else do nothing to self.buf
669
+
625
670
PBytesIOBuffer buf = factory ().createBytesIOBuf (PBytesIOBuf , self );
626
671
int length = self .getStringSize ();
627
672
return factory ().createMemoryView (getContext (), self .getManagedBuffer (), buf ,
628
- length , true , 1 , "B" ,
673
+ length , false , 1 , "B" ,
629
674
1 , null , 0 , new int []{length }, new int []{1 },
630
675
null , PMemoryView .FLAG_C | PMemoryView .FLAG_FORTRAN );
631
676
}
@@ -639,17 +684,28 @@ protected static boolean shouldCopy(PBytesIO self) {
639
684
return self .getStringSize () <= 1 || self .getExports () > 0 ;
640
685
}
641
686
687
+ protected static boolean shouldUnshare (PBytesIO self ) {
688
+ return self .getStringSize () != self .getBufCapacity ();
689
+ }
690
+
642
691
@ Specialization (guards = {"self.hasBuf()" , "shouldCopy(self)" })
643
692
Object copy (PBytesIO self ,
644
693
@ Cached SequenceStorageNodes .GetInternalByteArrayNode getBytes ) {
645
694
byte [] buf = getBytes .execute (self .getBuf ().getSequenceStorage ());
646
695
return factory ().createBytes (PythonUtils .arrayCopyOf (buf , self .getStringSize ()));
647
696
}
648
697
649
- @ Specialization (guards = {"self.hasBuf()" , "!shouldCopy(self)" })
650
- static Object doit (PBytesIO self ,
651
- @ Cached SequenceStorageNodes .SetLenNode setLenNode ) {
652
- setLenNode .execute (self .getBuf ().getSequenceStorage (), self .getStringSize ());
698
+ @ Specialization (guards = {"self.hasBuf()" , "!shouldCopy(self)" , "!shouldUnshare(self)" })
699
+ static Object doit (PBytesIO self ) {
700
+ return self .getBuf ();
701
+ }
702
+
703
+ @ Specialization (guards = {"self.hasBuf()" , "!shouldCopy(self)" , "shouldUnshare(self)" })
704
+ Object unshare (PBytesIO self ,
705
+ @ Cached SequenceStorageNodes .GetInternalArrayNode internalArray ) {
706
+ // if (SHARED_BUF(self))
707
+ unshareBuffer (self , self .getStringSize (), internalArray , factory ());
708
+ // else resize self.buf
653
709
return self .getBuf ();
654
710
}
655
711
}
@@ -668,7 +724,7 @@ Object doit(VirtualFrame frame, PBytesIO self,
668
724
}
669
725
}
670
726
671
- @ Builtin (name = __SETSTATE__ , minNumOfPositionalArgs = 1 )
727
+ @ Builtin (name = __SETSTATE__ , minNumOfPositionalArgs = 2 )
672
728
@ GenerateNodeFactory
673
729
abstract static class SetStateNode extends PythonBinaryBuiltinNode {
674
730
@ Specialization (guards = "checkExports(self)" )
0 commit comments