Skip to content

Commit bd937c5

Browse files
committed
Add not-connected checks for sockets
1 parent 9dc0318 commit bd937c5

File tree

1 file changed

+39
-20
lines changed
  • graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/socket

1 file changed

+39
-20
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/socket/SocketBuiltins.java

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,9 @@ abstract static class AcceptNode extends PythonUnaryBuiltinNode {
124124
@Specialization
125125
@TruffleBoundary
126126
Object accept(PSocket socket) {
127+
if (socket.getServerSocket() == null) {
128+
throw raiseOSError(null, OSErrorEnum.EINVAL);
129+
}
127130
try {
128131
SocketChannel acceptSocket = SocketUtils.accept(this, socket);
129132
if (acceptSocket == null) {
@@ -338,6 +341,9 @@ Object listen(PSocket socket, PNone backlog) {
338341
abstract static class RecvNode extends PythonTernaryClinicBuiltinNode {
339342
@Specialization
340343
Object recv(VirtualFrame frame, PSocket socket, int bufsize, int flags) {
344+
if (socket.getSocket() == null) {
345+
throw raiseOSError(frame, OSErrorEnum.ENOTCONN);
346+
}
341347
ByteBuffer readBytes = PythonUtils.allocateByteBuffer(bufsize);
342348
try {
343349
int length = SocketUtils.recv(this, socket, readBytes);
@@ -384,6 +390,9 @@ Object recvInto(VirtualFrame frame, PSocket socket, PMemoryView buffer, Object f
384390
@CachedLibrary(limit = "getCallSiteInlineCacheMaxDepth()") PythonObjectLibrary lib,
385391
@Cached("create(__LEN__)") LookupAndCallUnaryNode callLen,
386392
@Cached("create(__SETITEM__)") LookupAndCallTernaryNode setItem) {
393+
if (socket.getSocket() == null) {
394+
throw raiseOSError(frame, OSErrorEnum.ENOTCONN);
395+
}
387396
int bufferLen = lib.asSizeWithState(callLen.executeObject(frame, buffer), PArguments.getThreadState(frame));
388397
byte[] targetBuffer = new byte[bufferLen];
389398
ByteBuffer byteBuffer = PythonUtils.wrapByteBuffer(targetBuffer);
@@ -410,6 +419,9 @@ Object recvInto(VirtualFrame frame, PSocket socket, PByteArray buffer, Object fl
410419
@Cached("createBinaryProfile()") ConditionProfile byteStorage,
411420
@Cached SequenceStorageNodes.LenNode lenNode,
412421
@Cached("createSetItem()") SequenceStorageNodes.SetItemNode setItem) {
422+
if (socket.getSocket() == null) {
423+
throw raiseOSError(frame, OSErrorEnum.ENOTCONN);
424+
}
413425
SequenceStorage storage = buffer.getSequenceStorage();
414426
int bufferLen = lenNode.execute(storage);
415427
if (byteStorage.profile(storage instanceof ByteSequenceStorage)) {
@@ -470,17 +482,14 @@ Object send(VirtualFrame frame, PSocket socket, PBytes bytes, Object flags,
470482
@Cached SequenceStorageNodes.ToByteArrayNode toBytes) {
471483
// TODO: do not ignore flags
472484
if (socket.getSocket() == null) {
473-
throw raise(OSError);
474-
}
475-
476-
if (!socket.isOpen()) {
477-
throw raise(OSError);
485+
throw raiseOSError(frame, OSErrorEnum.ENOTCONN);
478486
}
479-
480487
int written;
481488
ByteBuffer buffer = PythonUtils.wrapByteBuffer(toBytes.execute(bytes.getSequenceStorage()));
482489
try {
483490
written = SocketUtils.send(this, socket, buffer);
491+
} catch (NotYetConnectedException e) {
492+
throw raiseOSError(frame, OSErrorEnum.ENOTCONN);
484493
} catch (IOException e) {
485494
throw raise(OSError);
486495
}
@@ -500,6 +509,9 @@ Object sendAll(VirtualFrame frame, PSocket socket, PBytesLike bytes, Object flag
500509
@Cached SequenceStorageNodes.ToByteArrayNode toBytes,
501510
@Cached ConditionProfile hasTimeoutProfile) {
502511
// TODO: do not ignore flags
512+
if (socket.getSocket() == null) {
513+
throw raiseOSError(frame, OSErrorEnum.ENOTCONN);
514+
}
503515
ByteBuffer buffer = PythonUtils.wrapByteBuffer(toBytes.execute(bytes.getSequenceStorage()));
504516
long timeoutMillis = socket.getTimeoutInMilliseconds();
505517
TimeoutHelper timeoutHelper = null;
@@ -513,6 +525,8 @@ Object sendAll(VirtualFrame frame, PSocket socket, PBytesLike bytes, Object flag
513525
int written;
514526
try {
515527
written = SocketUtils.send(this, socket, buffer, timeoutMillis);
528+
} catch (NotYetConnectedException e) {
529+
throw raiseOSError(frame, OSErrorEnum.ENOTCONN);
516530
} catch (IOException e) {
517531
throw raise(OSError);
518532
}
@@ -606,24 +620,29 @@ Object setTimeout(PSocket socket, Object secondsObj,
606620
@GenerateNodeFactory
607621
abstract static class shutdownNode extends PythonBinaryBuiltinNode {
608622
@Specialization
609-
@TruffleBoundary
610-
Object family(PSocket socket, int how) {
611-
if (socket.getSocket() != null) {
612-
try {
613-
if (how == 0 || how == 2) {
614-
socket.getSocket().shutdownInput();
615-
}
616-
if (how == 1 || how == 2) {
617-
socket.getSocket().shutdownOutput();
618-
}
619-
} catch (IOException e) {
620-
throw raise(OSError);
621-
}
622-
} else {
623+
Object family(VirtualFrame frame, PSocket socket, int how) {
624+
if (socket.getSocket() == null) {
625+
throw raiseOSError(frame, OSErrorEnum.ENOTCONN);
626+
}
627+
try {
628+
shutdown(socket, how);
629+
} catch (NotYetConnectedException e) {
630+
throw raiseOSError(frame, OSErrorEnum.ENOTCONN);
631+
} catch (IOException e) {
623632
throw raise(OSError);
624633
}
625634
return PNone.NO_VALUE;
626635
}
636+
637+
@TruffleBoundary
638+
private static void shutdown(PSocket socket, int how) throws IOException {
639+
if (how == 0 || how == 2) {
640+
socket.getSocket().shutdownInput();
641+
}
642+
if (how == 1 || how == 2) {
643+
socket.getSocket().shutdownOutput();
644+
}
645+
}
627646
}
628647

629648
// family

0 commit comments

Comments
 (0)