|
68 | 68 | import com.oracle.graal.python.builtins.PythonBuiltinClassType;
|
69 | 69 | import com.oracle.graal.python.builtins.PythonBuiltins;
|
70 | 70 | import com.oracle.graal.python.builtins.objects.PNone;
|
| 71 | +import com.oracle.graal.python.builtins.objects.bytes.BytesNodes; |
71 | 72 | import com.oracle.graal.python.builtins.objects.bytes.PBytes;
|
72 | 73 | import com.oracle.graal.python.builtins.objects.bytes.PIBytesLike;
|
73 | 74 | import com.oracle.graal.python.builtins.objects.common.SequenceNodes;
|
|
76 | 77 | import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes.NormalizeIndexNode;
|
77 | 78 | import com.oracle.graal.python.builtins.objects.exception.OSErrorEnum;
|
78 | 79 | import com.oracle.graal.python.builtins.objects.ints.PInt;
|
| 80 | +import com.oracle.graal.python.builtins.objects.memoryview.PMemoryView; |
79 | 81 | import com.oracle.graal.python.builtins.objects.mmap.MMapBuiltinsFactory.InternalLenNodeGen;
|
80 | 82 | import com.oracle.graal.python.builtins.objects.slice.PSlice;
|
81 | 83 | import com.oracle.graal.python.builtins.objects.slice.PSlice.SliceInfo;
|
@@ -456,7 +458,7 @@ int readByte(PMMap self,
|
456 | 458 | @TypeSystemReference(PythonArithmeticTypes.class)
|
457 | 459 | abstract static class ReadNode extends PythonBuiltinNode {
|
458 | 460 |
|
459 |
| - @Specialization(guards = "!isNoValue(n)") |
| 461 | + @Specialization(guards = "isNoValue(n)") |
460 | 462 | PBytes read(PMMap self, @SuppressWarnings("unused") PNone n,
|
461 | 463 | @Cached("create()") ReadFromChannelNode readChannelNode) {
|
462 | 464 | ByteSequenceStorage res = readChannelNode.execute(self.getChannel(), ReadFromChannelNode.MAX_READ);
|
@@ -517,50 +519,94 @@ protected static SequenceStorageNodes.AppendNode createAppend() {
|
517 | 519 | }
|
518 | 520 | }
|
519 | 521 |
|
| 522 | + @Builtin(name = "write", fixedNumOfPositionalArgs = 2) |
| 523 | + @GenerateNodeFactory |
| 524 | + abstract static class WriteNode extends PythonBinaryBuiltinNode { |
| 525 | + |
| 526 | + @Specialization |
| 527 | + int writeBytesLike(PMMap self, PIBytesLike bytesLike, |
| 528 | + @Cached("create()") WriteToChannelNode writeNode, |
| 529 | + @Cached("create()") SequenceNodes.GetSequenceStorageNode getStorageNode) { |
| 530 | + SeekableByteChannel channel = self.getChannel(); |
| 531 | + return writeNode.execute(channel, getStorageNode.execute(bytesLike), Integer.MAX_VALUE); |
| 532 | + } |
| 533 | + |
| 534 | + @Specialization |
| 535 | + int writeMemoryview(PMMap self, PMemoryView memoryView, |
| 536 | + @Cached("create()") WriteToChannelNode writeNode, |
| 537 | + @Cached("create()") BytesNodes.ToBytesNode toBytesNode) { |
| 538 | + byte[] data = toBytesNode.execute(memoryView); |
| 539 | + return writeNode.execute(self.getChannel(), new ByteSequenceStorage(data), Integer.MAX_VALUE); |
| 540 | + } |
| 541 | + } |
| 542 | + |
520 | 543 | @Builtin(name = "seek", minNumOfPositionalArgs = 2, maxNumOfPositionalArgs = 3)
|
521 | 544 | @GenerateNodeFactory
|
522 | 545 | @TypeSystemReference(PythonArithmeticTypes.class)
|
523 | 546 | abstract static class SeekNode extends PythonBuiltinNode {
|
| 547 | + @Child private CastToIndexNode castToLongNode; |
| 548 | + |
| 549 | + private final BranchProfile errorProfile = BranchProfile.create(); |
524 | 550 |
|
525 | 551 | @Specialization(guards = "isNoValue(how)")
|
526 | 552 | Object seek(VirtualFrame frame, PMMap self, long dist, @SuppressWarnings("unused") PNone how) {
|
527 | 553 | return seek(frame, self, dist, 0);
|
528 | 554 | }
|
529 | 555 |
|
530 | 556 | @Specialization
|
531 |
| - Object seek(VirtualFrame frame, PMMap self, long dist, int how) { |
| 557 | + Object seek(VirtualFrame frame, PMMap self, long dist, Object how) { |
532 | 558 | try {
|
533 |
| - return doSeek(self, dist, how); |
| 559 | + SeekableByteChannel channel = self.getChannel(); |
| 560 | + long size; |
| 561 | + if (self.getLength() == 0) { |
| 562 | + size = channel.size() - self.getOffset(); |
| 563 | + } else { |
| 564 | + size = self.getLength(); |
| 565 | + } |
| 566 | + long where; |
| 567 | + int ihow = castToInt(how); |
| 568 | + switch (ihow) { |
| 569 | + case 0: /* relative to start */ |
| 570 | + where = dist; |
| 571 | + break; |
| 572 | + case 1: /* relative to current position */ |
| 573 | + where = position(channel) + dist; |
| 574 | + break; |
| 575 | + case 2: /* relative to end */ |
| 576 | + where = size + dist; |
| 577 | + break; |
| 578 | + default: |
| 579 | + errorProfile.enter(); |
| 580 | + throw raise(PythonBuiltinClassType.ValueError, "unknown seek type"); |
| 581 | + } |
| 582 | + if (where > size || where < 0) { |
| 583 | + errorProfile.enter(); |
| 584 | + throw raise(PythonBuiltinClassType.ValueError, "seek out of range"); |
| 585 | + } |
| 586 | + doSeek(channel, where); |
| 587 | + return PNone.NONE; |
534 | 588 | } catch (IOException e) {
|
| 589 | + errorProfile.enter(); |
535 | 590 | throw raiseOSError(frame, OSErrorEnum.EIO, e.getMessage());
|
536 | 591 | }
|
537 | 592 | }
|
538 | 593 |
|
539 |
| - @TruffleBoundary |
540 |
| - private Object doSeek(PMMap self, long dist, int how) throws IOException { |
541 |
| - SeekableByteChannel channel = self.getChannel(); |
542 |
| - long where; |
543 |
| - switch (how) { |
544 |
| - case 0: /* relative to start */ |
545 |
| - where = dist; |
546 |
| - break; |
547 |
| - case 1: /* relative to current position */ |
548 |
| - where = channel.position() + dist; |
549 |
| - break; |
550 |
| - case 2: /* relative to end */ |
551 |
| - long size; |
552 |
| - if (self.getLength() == 0) { |
553 |
| - size = channel.size() - self.getOffset(); |
554 |
| - } else { |
555 |
| - size = self.getLength(); |
556 |
| - } |
557 |
| - where = size + dist; |
558 |
| - break; |
559 |
| - default: |
560 |
| - throw raise(PythonBuiltinClassType.ValueError, "unknown seek type"); |
561 |
| - } |
| 594 | + @TruffleBoundary(allowInlining = true) |
| 595 | + private static long position(SeekableByteChannel channel) throws IOException { |
| 596 | + return channel.position(); |
| 597 | + } |
| 598 | + |
| 599 | + @TruffleBoundary(allowInlining = true) |
| 600 | + private static void doSeek(SeekableByteChannel channel, long where) throws IOException { |
562 | 601 | channel.position(where);
|
563 |
| - return PNone.NONE; |
| 602 | + } |
| 603 | + |
| 604 | + private int castToInt(Object val) { |
| 605 | + if (castToLongNode == null) { |
| 606 | + CompilerDirectives.transferToInterpreterAndInvalidate(); |
| 607 | + castToLongNode = insert(CastToIndexNode.create()); |
| 608 | + } |
| 609 | + return castToLongNode.execute(val); |
564 | 610 | }
|
565 | 611 | }
|
566 | 612 |
|
|
0 commit comments