|
64 | 64 | import com.oracle.graal.python.builtins.CoreFunctions;
|
65 | 65 | import com.oracle.graal.python.builtins.PythonBuiltins;
|
66 | 66 | import com.oracle.graal.python.builtins.modules.PosixModuleBuiltinsFactory.ConvertPathlikeObjectNodeGen;
|
| 67 | +import com.oracle.graal.python.builtins.modules.PosixModuleBuiltinsFactory.ReadFromChannelNodeGen; |
67 | 68 | import com.oracle.graal.python.builtins.modules.PosixModuleBuiltinsFactory.StatNodeFactory;
|
68 | 69 | import com.oracle.graal.python.builtins.objects.PNone;
|
69 | 70 | import com.oracle.graal.python.builtins.objects.bytes.OpaqueBytes;
|
@@ -836,77 +837,112 @@ public static WriteNode create() {
|
836 | 837 | }
|
837 | 838 | }
|
838 | 839 |
|
839 |
| - @Builtin(name = "read", fixedNumOfPositionalArgs = 2) |
840 |
| - @GenerateNodeFactory |
841 |
| - @TypeSystemReference(PythonArithmeticTypes.class) |
842 |
| - public abstract static class ReadNode extends PythonFileNode { |
843 |
| - private static final int MAX_READ = Integer.MAX_VALUE / 2; |
| 840 | + abstract static class ReadFromChannelNode extends PNodeWithContext { |
| 841 | + private final BranchProfile gotException = BranchProfile.create(); |
844 | 842 |
|
845 |
| - @Specialization(guards = "readOpaque(frame)") |
846 |
| - Object readOpaque(@SuppressWarnings("unused") VirtualFrame frame, int fd, @SuppressWarnings("unused") Object requestedSize, |
847 |
| - @Cached("createClassProfile()") ValueProfile channelClassProfile, |
848 |
| - @Cached("createBinaryProfile()") ConditionProfile instanceofProfile, |
849 |
| - @Cached("create()") BranchProfile gotException) { |
850 |
| - Channel channel = getResources().getFileChannel(fd, channelClassProfile); |
| 843 | + abstract ByteSequenceStorage execute(Channel channel, int size); |
| 844 | + |
| 845 | + @Specialization |
| 846 | + ByteSequenceStorage readSeekable(SeekableByteChannel channel, int size) { |
| 847 | + long availableSize; |
851 | 848 | try {
|
852 |
| - ByteSequenceStorage bytes = doRead(channel, MAX_READ, channelClassProfile, instanceofProfile); |
853 |
| - return new OpaqueBytes(Arrays.copyOf(bytes.getInternalByteArray(), bytes.length())); |
| 849 | + availableSize = channel.size() - channel.position(); |
854 | 850 | } catch (IOException e) {
|
855 | 851 | gotException.enter();
|
856 | 852 | throw raise(OSError, e.getMessage());
|
857 | 853 | }
|
| 854 | + if (availableSize > ReadNode.MAX_READ) { |
| 855 | + availableSize = ReadNode.MAX_READ; |
| 856 | + } |
| 857 | + int sz = (int) Math.min(availableSize, size); |
| 858 | + return readReadable(channel, sz); |
858 | 859 | }
|
859 | 860 |
|
860 |
| - @Specialization(guards = "!readOpaque(frame)") |
861 |
| - Object read(@SuppressWarnings("unused") VirtualFrame frame, int fd, long requestedSize, |
862 |
| - @Cached("createClassProfile()") ValueProfile channelClassProfile, |
863 |
| - @Cached("createBinaryProfile()") ConditionProfile instanceofProfile, |
864 |
| - @Cached("create()") BranchProfile gotException) { |
865 |
| - Channel channel = getResources().getFileChannel(fd, channelClassProfile); |
866 |
| - try { |
867 |
| - ByteSequenceStorage array = doRead(channel, (int) requestedSize, channelClassProfile, instanceofProfile); |
868 |
| - return factory().createBytes(array); |
869 |
| - } catch (IOException e) { |
870 |
| - gotException.enter(); |
871 |
| - throw raise(OSError, e.getMessage()); |
| 861 | + @Specialization |
| 862 | + ByteSequenceStorage readReadable(ReadableByteChannel channel, int size) { |
| 863 | + int sz = Math.min(size, ReadNode.MAX_READ); |
| 864 | + ByteBuffer dst = allocateBuffer(sz); |
| 865 | + int readSize = readIntoBuffer(channel, dst); |
| 866 | + byte[] array; |
| 867 | + if (readSize <= 0) { |
| 868 | + array = new byte[0]; |
| 869 | + readSize = 0; |
| 870 | + } else { |
| 871 | + array = getByteBufferArray(dst); |
872 | 872 | }
|
| 873 | + ByteSequenceStorage byteSequenceStorage = new ByteSequenceStorage(array); |
| 874 | + byteSequenceStorage.setNewLength(readSize); |
| 875 | + return byteSequenceStorage; |
873 | 876 | }
|
874 | 877 |
|
875 |
| - private ByteSequenceStorage doRead(Channel channel, int requestedSize, ValueProfile channelClassProfile, ConditionProfile instanceofProfile) throws IOException { |
876 |
| - if (instanceofProfile.profile(channelClassProfile.profile(channel) instanceof ReadableByteChannel)) { |
877 |
| - ReadableByteChannel readableChannel = (ReadableByteChannel) channel; |
878 |
| - int sz = Math.min(requestedSize, MAX_READ); |
879 |
| - ByteBuffer dst = allocateBuffer(sz); |
880 |
| - int readSize = readIntoBuffer(readableChannel, dst); |
881 |
| - byte[] array; |
882 |
| - if (readSize <= 0) { |
883 |
| - array = new byte[0]; |
884 |
| - readSize = 0; |
885 |
| - } else { |
886 |
| - array = getByteBufferArray(dst); |
887 |
| - } |
888 |
| - ByteSequenceStorage byteSequenceStorage = new ByteSequenceStorage(array); |
889 |
| - byteSequenceStorage.setNewLength(readSize); |
890 |
| - return byteSequenceStorage; |
| 878 | + @Specialization |
| 879 | + ByteSequenceStorage readGeneric(Channel channel, int size) { |
| 880 | + if (channel instanceof SeekableByteChannel) { |
| 881 | + return readSeekable((SeekableByteChannel) channel, size); |
| 882 | + } else if (channel instanceof ReadableByteChannel) { |
| 883 | + return readReadable((ReadableByteChannel) channel, size); |
| 884 | + } else { |
| 885 | + throw raise(OSError, "file not opened for reading"); |
891 | 886 | }
|
892 |
| - throw raise(OSError, "file not opened for reading"); |
893 | 887 | }
|
894 | 888 |
|
895 | 889 | @TruffleBoundary(allowInlining = true)
|
896 | 890 | private static byte[] getByteBufferArray(ByteBuffer dst) {
|
897 | 891 | return dst.array();
|
898 | 892 | }
|
899 | 893 |
|
900 |
| - @TruffleBoundary(allowInlining = true, transferToInterpreterOnException = false) |
901 |
| - private static int readIntoBuffer(ReadableByteChannel readableChannel, ByteBuffer dst) throws IOException { |
902 |
| - return readableChannel.read(dst); |
| 894 | + @TruffleBoundary(allowInlining = true) |
| 895 | + private int readIntoBuffer(ReadableByteChannel readableChannel, ByteBuffer dst) { |
| 896 | + try { |
| 897 | + return readableChannel.read(dst); |
| 898 | + } catch (IOException e) { |
| 899 | + gotException.enter(); |
| 900 | + throw raise(OSError, e.getMessage()); |
| 901 | + } |
903 | 902 | }
|
904 | 903 |
|
905 | 904 | @TruffleBoundary(allowInlining = true)
|
906 | 905 | private static ByteBuffer allocateBuffer(int sz) {
|
907 | 906 | return ByteBuffer.allocate(sz);
|
908 | 907 | }
|
909 | 908 |
|
| 909 | + public static ReadFromChannelNode create() { |
| 910 | + return ReadFromChannelNodeGen.create(); |
| 911 | + } |
| 912 | + } |
| 913 | + |
| 914 | + @Builtin(name = "read", fixedNumOfPositionalArgs = 2) |
| 915 | + @GenerateNodeFactory |
| 916 | + @TypeSystemReference(PythonArithmeticTypes.class) |
| 917 | + public abstract static class ReadNode extends PythonFileNode { |
| 918 | + private static final int MAX_READ = Integer.MAX_VALUE / 2; |
| 919 | + |
| 920 | + @Specialization(guards = "readOpaque(frame)") |
| 921 | + Object readOpaque(@SuppressWarnings("unused") VirtualFrame frame, int fd, @SuppressWarnings("unused") Object requestedSize, |
| 922 | + @Cached("createClassProfile()") ValueProfile channelClassProfile, |
| 923 | + @Cached("create()") ReadFromChannelNode readNode) { |
| 924 | + Channel channel = getResources().getFileChannel(fd, channelClassProfile); |
| 925 | + ByteSequenceStorage bytes = readNode.execute(channel, MAX_READ); |
| 926 | + return new OpaqueBytes(Arrays.copyOf(bytes.getInternalByteArray(), bytes.length())); |
| 927 | + } |
| 928 | + |
| 929 | + @Specialization(guards = "!readOpaque(frame)") |
| 930 | + Object read(@SuppressWarnings("unused") VirtualFrame frame, int fd, long requestedSize, |
| 931 | + @Cached("createClassProfile()") ValueProfile channelClassProfile, |
| 932 | + @Cached("create()") BranchProfile tooLarge, |
| 933 | + @Cached("create()") ReadFromChannelNode readNode) { |
| 934 | + int size; |
| 935 | + try { |
| 936 | + size = Math.toIntExact(requestedSize); |
| 937 | + } catch (ArithmeticException e) { |
| 938 | + tooLarge.enter(); |
| 939 | + size = MAX_READ; |
| 940 | + } |
| 941 | + Channel channel = getResources().getFileChannel(fd, channelClassProfile); |
| 942 | + ByteSequenceStorage array = readNode.execute(channel, size); |
| 943 | + return factory().createBytes(array); |
| 944 | + } |
| 945 | + |
910 | 946 | /**
|
911 | 947 | * @param frame - only used so the DSL sees this as a dynamic check
|
912 | 948 | */
|
|
0 commit comments