@@ -124,6 +124,9 @@ abstract static class AcceptNode extends PythonUnaryBuiltinNode {
124
124
@ Specialization
125
125
@ TruffleBoundary
126
126
Object accept (PSocket socket ) {
127
+ if (socket .getServerSocket () == null ) {
128
+ throw raiseOSError (null , OSErrorEnum .EINVAL );
129
+ }
127
130
try {
128
131
SocketChannel acceptSocket = SocketUtils .accept (this , socket );
129
132
if (acceptSocket == null ) {
@@ -338,6 +341,9 @@ Object listen(PSocket socket, PNone backlog) {
338
341
abstract static class RecvNode extends PythonTernaryClinicBuiltinNode {
339
342
@ Specialization
340
343
Object recv (VirtualFrame frame , PSocket socket , int bufsize , int flags ) {
344
+ if (socket .getSocket () == null ) {
345
+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
346
+ }
341
347
ByteBuffer readBytes = PythonUtils .allocateByteBuffer (bufsize );
342
348
try {
343
349
int length = SocketUtils .recv (this , socket , readBytes );
@@ -384,6 +390,9 @@ Object recvInto(VirtualFrame frame, PSocket socket, PMemoryView buffer, Object f
384
390
@ CachedLibrary (limit = "getCallSiteInlineCacheMaxDepth()" ) PythonObjectLibrary lib ,
385
391
@ Cached ("create(__LEN__)" ) LookupAndCallUnaryNode callLen ,
386
392
@ Cached ("create(__SETITEM__)" ) LookupAndCallTernaryNode setItem ) {
393
+ if (socket .getSocket () == null ) {
394
+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
395
+ }
387
396
int bufferLen = lib .asSizeWithState (callLen .executeObject (frame , buffer ), PArguments .getThreadState (frame ));
388
397
byte [] targetBuffer = new byte [bufferLen ];
389
398
ByteBuffer byteBuffer = PythonUtils .wrapByteBuffer (targetBuffer );
@@ -410,6 +419,9 @@ Object recvInto(VirtualFrame frame, PSocket socket, PByteArray buffer, Object fl
410
419
@ Cached ("createBinaryProfile()" ) ConditionProfile byteStorage ,
411
420
@ Cached SequenceStorageNodes .LenNode lenNode ,
412
421
@ Cached ("createSetItem()" ) SequenceStorageNodes .SetItemNode setItem ) {
422
+ if (socket .getSocket () == null ) {
423
+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
424
+ }
413
425
SequenceStorage storage = buffer .getSequenceStorage ();
414
426
int bufferLen = lenNode .execute (storage );
415
427
if (byteStorage .profile (storage instanceof ByteSequenceStorage )) {
@@ -470,17 +482,14 @@ Object send(VirtualFrame frame, PSocket socket, PBytes bytes, Object flags,
470
482
@ Cached SequenceStorageNodes .ToByteArrayNode toBytes ) {
471
483
// TODO: do not ignore flags
472
484
if (socket .getSocket () == null ) {
473
- throw raise (OSError );
474
- }
475
-
476
- if (!socket .isOpen ()) {
477
- throw raise (OSError );
485
+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
478
486
}
479
-
480
487
int written ;
481
488
ByteBuffer buffer = PythonUtils .wrapByteBuffer (toBytes .execute (bytes .getSequenceStorage ()));
482
489
try {
483
490
written = SocketUtils .send (this , socket , buffer );
491
+ } catch (NotYetConnectedException e ) {
492
+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
484
493
} catch (IOException e ) {
485
494
throw raise (OSError );
486
495
}
@@ -500,6 +509,9 @@ Object sendAll(VirtualFrame frame, PSocket socket, PBytesLike bytes, Object flag
500
509
@ Cached SequenceStorageNodes .ToByteArrayNode toBytes ,
501
510
@ Cached ConditionProfile hasTimeoutProfile ) {
502
511
// TODO: do not ignore flags
512
+ if (socket .getSocket () == null ) {
513
+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
514
+ }
503
515
ByteBuffer buffer = PythonUtils .wrapByteBuffer (toBytes .execute (bytes .getSequenceStorage ()));
504
516
long timeoutMillis = socket .getTimeoutInMilliseconds ();
505
517
TimeoutHelper timeoutHelper = null ;
@@ -513,6 +525,8 @@ Object sendAll(VirtualFrame frame, PSocket socket, PBytesLike bytes, Object flag
513
525
int written ;
514
526
try {
515
527
written = SocketUtils .send (this , socket , buffer , timeoutMillis );
528
+ } catch (NotYetConnectedException e ) {
529
+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
516
530
} catch (IOException e ) {
517
531
throw raise (OSError );
518
532
}
@@ -606,24 +620,29 @@ Object setTimeout(PSocket socket, Object secondsObj,
606
620
@ GenerateNodeFactory
607
621
abstract static class shutdownNode extends PythonBinaryBuiltinNode {
608
622
@ Specialization
609
- @ TruffleBoundary
610
- Object family (PSocket socket , int how ) {
611
- if (socket .getSocket () != null ) {
612
- try {
613
- if (how == 0 || how == 2 ) {
614
- socket .getSocket ().shutdownInput ();
615
- }
616
- if (how == 1 || how == 2 ) {
617
- socket .getSocket ().shutdownOutput ();
618
- }
619
- } catch (IOException e ) {
620
- throw raise (OSError );
621
- }
622
- } else {
623
+ Object family (VirtualFrame frame , PSocket socket , int how ) {
624
+ if (socket .getSocket () == null ) {
625
+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
626
+ }
627
+ try {
628
+ shutdown (socket , how );
629
+ } catch (NotYetConnectedException e ) {
630
+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
631
+ } catch (IOException e ) {
623
632
throw raise (OSError );
624
633
}
625
634
return PNone .NO_VALUE ;
626
635
}
636
+
637
+ @ TruffleBoundary
638
+ private static void shutdown (PSocket socket , int how ) throws IOException {
639
+ if (how == 0 || how == 2 ) {
640
+ socket .getSocket ().shutdownInput ();
641
+ }
642
+ if (how == 1 || how == 2 ) {
643
+ socket .getSocket ().shutdownOutput ();
644
+ }
645
+ }
627
646
}
628
647
629
648
// family
0 commit comments