Skip to content

Commit d53d780

Browse files
committed
Implement 'mmap.write'.
1 parent a04afa8 commit d53d780

File tree

1 file changed

+73
-27
lines changed
  • graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/mmap

1 file changed

+73
-27
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/mmap/MMapBuiltins.java

Lines changed: 73 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
import com.oracle.graal.python.builtins.PythonBuiltinClassType;
6969
import com.oracle.graal.python.builtins.PythonBuiltins;
7070
import com.oracle.graal.python.builtins.objects.PNone;
71+
import com.oracle.graal.python.builtins.objects.bytes.BytesNodes;
7172
import com.oracle.graal.python.builtins.objects.bytes.PBytes;
7273
import com.oracle.graal.python.builtins.objects.bytes.PIBytesLike;
7374
import com.oracle.graal.python.builtins.objects.common.SequenceNodes;
@@ -76,6 +77,7 @@
7677
import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes.NormalizeIndexNode;
7778
import com.oracle.graal.python.builtins.objects.exception.OSErrorEnum;
7879
import com.oracle.graal.python.builtins.objects.ints.PInt;
80+
import com.oracle.graal.python.builtins.objects.memoryview.PMemoryView;
7981
import com.oracle.graal.python.builtins.objects.mmap.MMapBuiltinsFactory.InternalLenNodeGen;
8082
import com.oracle.graal.python.builtins.objects.slice.PSlice;
8183
import com.oracle.graal.python.builtins.objects.slice.PSlice.SliceInfo;
@@ -456,7 +458,7 @@ int readByte(PMMap self,
456458
@TypeSystemReference(PythonArithmeticTypes.class)
457459
abstract static class ReadNode extends PythonBuiltinNode {
458460

459-
@Specialization(guards = "!isNoValue(n)")
461+
@Specialization(guards = "isNoValue(n)")
460462
PBytes read(PMMap self, @SuppressWarnings("unused") PNone n,
461463
@Cached("create()") ReadFromChannelNode readChannelNode) {
462464
ByteSequenceStorage res = readChannelNode.execute(self.getChannel(), ReadFromChannelNode.MAX_READ);
@@ -517,50 +519,94 @@ protected static SequenceStorageNodes.AppendNode createAppend() {
517519
}
518520
}
519521

522+
@Builtin(name = "write", fixedNumOfPositionalArgs = 2)
523+
@GenerateNodeFactory
524+
abstract static class WriteNode extends PythonBinaryBuiltinNode {
525+
526+
@Specialization
527+
int writeBytesLike(PMMap self, PIBytesLike bytesLike,
528+
@Cached("create()") WriteToChannelNode writeNode,
529+
@Cached("create()") SequenceNodes.GetSequenceStorageNode getStorageNode) {
530+
SeekableByteChannel channel = self.getChannel();
531+
return writeNode.execute(channel, getStorageNode.execute(bytesLike), Integer.MAX_VALUE);
532+
}
533+
534+
@Specialization
535+
int writeMemoryview(PMMap self, PMemoryView memoryView,
536+
@Cached("create()") WriteToChannelNode writeNode,
537+
@Cached("create()") BytesNodes.ToBytesNode toBytesNode) {
538+
byte[] data = toBytesNode.execute(memoryView);
539+
return writeNode.execute(self.getChannel(), new ByteSequenceStorage(data), Integer.MAX_VALUE);
540+
}
541+
}
542+
520543
@Builtin(name = "seek", minNumOfPositionalArgs = 2, maxNumOfPositionalArgs = 3)
521544
@GenerateNodeFactory
522545
@TypeSystemReference(PythonArithmeticTypes.class)
523546
abstract static class SeekNode extends PythonBuiltinNode {
547+
@Child private CastToIndexNode castToLongNode;
548+
549+
private final BranchProfile errorProfile = BranchProfile.create();
524550

525551
@Specialization(guards = "isNoValue(how)")
526552
Object seek(VirtualFrame frame, PMMap self, long dist, @SuppressWarnings("unused") PNone how) {
527553
return seek(frame, self, dist, 0);
528554
}
529555

530556
@Specialization
531-
Object seek(VirtualFrame frame, PMMap self, long dist, int how) {
557+
Object seek(VirtualFrame frame, PMMap self, long dist, Object how) {
532558
try {
533-
return doSeek(self, dist, how);
559+
SeekableByteChannel channel = self.getChannel();
560+
long size;
561+
if (self.getLength() == 0) {
562+
size = channel.size() - self.getOffset();
563+
} else {
564+
size = self.getLength();
565+
}
566+
long where;
567+
int ihow = castToInt(how);
568+
switch (ihow) {
569+
case 0: /* relative to start */
570+
where = dist;
571+
break;
572+
case 1: /* relative to current position */
573+
where = position(channel) + dist;
574+
break;
575+
case 2: /* relative to end */
576+
where = size + dist;
577+
break;
578+
default:
579+
errorProfile.enter();
580+
throw raise(PythonBuiltinClassType.ValueError, "unknown seek type");
581+
}
582+
if (where > size || where < 0) {
583+
errorProfile.enter();
584+
throw raise(PythonBuiltinClassType.ValueError, "seek out of range");
585+
}
586+
doSeek(channel, where);
587+
return PNone.NONE;
534588
} catch (IOException e) {
589+
errorProfile.enter();
535590
throw raiseOSError(frame, OSErrorEnum.EIO, e.getMessage());
536591
}
537592
}
538593

539-
@TruffleBoundary
540-
private Object doSeek(PMMap self, long dist, int how) throws IOException {
541-
SeekableByteChannel channel = self.getChannel();
542-
long where;
543-
switch (how) {
544-
case 0: /* relative to start */
545-
where = dist;
546-
break;
547-
case 1: /* relative to current position */
548-
where = channel.position() + dist;
549-
break;
550-
case 2: /* relative to end */
551-
long size;
552-
if (self.getLength() == 0) {
553-
size = channel.size() - self.getOffset();
554-
} else {
555-
size = self.getLength();
556-
}
557-
where = size + dist;
558-
break;
559-
default:
560-
throw raise(PythonBuiltinClassType.ValueError, "unknown seek type");
561-
}
594+
@TruffleBoundary(allowInlining = true)
595+
private static long position(SeekableByteChannel channel) throws IOException {
596+
return channel.position();
597+
}
598+
599+
@TruffleBoundary(allowInlining = true)
600+
private static void doSeek(SeekableByteChannel channel, long where) throws IOException {
562601
channel.position(where);
563-
return PNone.NONE;
602+
}
603+
604+
private int castToInt(Object val) {
605+
if (castToLongNode == null) {
606+
CompilerDirectives.transferToInterpreterAndInvalidate();
607+
castToLongNode = insert(CastToIndexNode.create());
608+
}
609+
return castToLongNode.execute(val);
564610
}
565611
}
566612

0 commit comments

Comments
 (0)