|
46 | 46 | import java.net.SocketAddress;
|
47 | 47 | import java.net.SocketException;
|
48 | 48 | import java.nio.ByteBuffer;
|
| 49 | +import java.nio.channels.NotYetConnectedException; |
49 | 50 | import java.nio.channels.ServerSocketChannel;
|
50 | 51 | import java.nio.channels.SocketChannel;
|
51 | 52 | import java.util.Arrays;
|
|
61 | 62 | import com.oracle.graal.python.builtins.objects.bytes.PBytes;
|
62 | 63 | import com.oracle.graal.python.builtins.objects.bytes.PIBytesLike;
|
63 | 64 | import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes;
|
| 65 | +import com.oracle.graal.python.builtins.objects.exception.OSErrorEnum; |
64 | 66 | import com.oracle.graal.python.builtins.objects.tuple.PTuple;
|
65 | 67 | import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
|
66 | 68 | import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
|
67 | 69 | import com.oracle.graal.python.nodes.function.builtins.PythonBinaryBuiltinNode;
|
68 | 70 | import com.oracle.graal.python.nodes.function.builtins.PythonTernaryBuiltinNode;
|
69 | 71 | import com.oracle.graal.python.nodes.function.builtins.PythonUnaryBuiltinNode;
|
| 72 | +import com.oracle.graal.python.runtime.sequence.storage.ByteSequenceStorage; |
| 73 | +import com.oracle.graal.python.runtime.sequence.storage.SequenceStorage; |
70 | 74 | import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
|
71 | 75 | import com.oracle.truffle.api.dsl.Cached;
|
72 | 76 | import com.oracle.truffle.api.dsl.GenerateNodeFactory;
|
73 | 77 | import com.oracle.truffle.api.dsl.NodeFactory;
|
74 | 78 | import com.oracle.truffle.api.dsl.Specialization;
|
| 79 | +import com.oracle.truffle.api.frame.VirtualFrame; |
| 80 | +import com.oracle.truffle.api.profiles.ConditionProfile; |
75 | 81 |
|
76 | 82 | @CoreFunctions(extendClasses = PythonBuiltinClassType.PSocket)
|
77 | 83 | @SuppressWarnings("unused")
|
@@ -351,30 +357,49 @@ Object recvFrom(PSocket socket, int bufsize, PNone flags) {
|
351 | 357 | // recv_into(bufsize[, flags])
|
352 | 358 | @Builtin(name = "recv_into", minNumOfPositionalArgs = 1, maxNumOfPositionalArgs = 3)
|
353 | 359 | @GenerateNodeFactory
|
354 |
| - abstract static class RecvIntoNode extends PythonBuiltinNode { |
355 |
| - @Specialization |
356 |
| - @TruffleBoundary |
357 |
| - Object recvInto(PSocket socket, PByteArray buffer) { |
358 |
| - byte[] targetBuffer = new byte[buffer.getSequenceStorage().length()]; |
359 |
| - |
360 |
| - int length = fillBuffer(socket, targetBuffer); |
361 |
| - // TODO: seems dirty, is there a better way to fill a byte array? |
362 |
| - |
363 |
| - for (int i = 0; i < length; i++) { |
364 |
| - buffer.getSequenceStorage().insertItem(i, targetBuffer[i]); |
| 360 | + abstract static class RecvIntoNode extends PythonTernaryBuiltinNode { |
| 361 | + @Specialization |
| 362 | + Object recvInto(VirtualFrame frame, PSocket socket, PByteArray buffer, Object flags, |
| 363 | + @Cached ConditionProfile byteStorage, |
| 364 | + @Cached SequenceStorageNodes.LenNode lenNode, |
| 365 | + @Cached SequenceStorageNodes.SetItemNode setItem) { |
| 366 | + SequenceStorage storage = buffer.getSequenceStorage(); |
| 367 | + int bufferLen = lenNode.execute(storage); |
| 368 | + if (byteStorage.profile(storage instanceof ByteSequenceStorage)) { |
| 369 | + ByteBuffer byteBuffer = ((ByteSequenceStorage) storage).getBufferView(); |
| 370 | + try { |
| 371 | + return fillBuffer(socket, byteBuffer); |
| 372 | + } catch (NotYetConnectedException e) { |
| 373 | + throw raiseOSError(frame, OSErrorEnum.ENOTCONN, e); |
| 374 | + } catch (IOException e) { |
| 375 | + throw raiseOSError(frame, OSErrorEnum.EBADF, e); |
| 376 | + } |
| 377 | + } else { |
| 378 | + byte[] targetBuffer = new byte[bufferLen]; |
| 379 | + ByteBuffer byteBuffer = ByteBuffer.wrap(targetBuffer); |
| 380 | + int length; |
| 381 | + try { |
| 382 | + length = fillBuffer(socket, byteBuffer); |
| 383 | + } catch (NotYetConnectedException e) { |
| 384 | + throw raiseOSError(frame, OSErrorEnum.ENOTCONN, e); |
| 385 | + } catch (IOException e) { |
| 386 | + throw raiseOSError(frame, OSErrorEnum.EBADF, e); |
| 387 | + } |
| 388 | + SequenceStorage newStorage = storage; |
| 389 | + for (int i = 0; i < length; i++) { |
| 390 | + newStorage = setItem.execute(newStorage, i, targetBuffer[i]); |
| 391 | + } |
| 392 | + if (newStorage != storage) { |
| 393 | + buffer.setSequenceStorage(newStorage); |
| 394 | + } |
| 395 | + return length; |
365 | 396 | }
|
366 |
| - |
367 |
| - return length; |
368 | 397 | }
|
369 | 398 |
|
370 |
| - int fillBuffer(PSocket socket, byte[] buffer) { |
371 |
| - ByteBuffer byteBuffer = ByteBuffer.wrap(buffer); |
| 399 | + @TruffleBoundary |
| 400 | + private static int fillBuffer(PSocket socket, ByteBuffer byteBuffer) throws IOException { |
372 | 401 | SocketChannel nativeSocket = socket.getSocket();
|
373 |
| - try { |
374 |
| - return nativeSocket.read(byteBuffer); |
375 |
| - } catch (IOException e) { |
376 |
| - throw raise(PythonBuiltinClassType.OSError); |
377 |
| - } |
| 402 | + return nativeSocket.read(byteBuffer); |
378 | 403 | }
|
379 | 404 | }
|
380 | 405 |
|
|
0 commit comments