44import anyio
55
66from .._exceptions import (
7- BrokenSocketError ,
87 ConnectError ,
98 ConnectTimeout ,
109 ReadError ,
@@ -83,9 +82,6 @@ async def start_tls(
8382 return AnyIOStream (ssl_stream )
8483
8584 def get_extra_info (self , info : str ) -> typing .Any :
86- if info == "is_readable" :
87- sock = self ._stream .extra (anyio .abc .SocketAttribute .raw_socket , None )
88- return is_socket_readable (sock )
8985 if info == "ssl_object" :
9086 return self ._stream .extra (anyio .streams .tls .TLSAttribute .ssl_object , None )
9187 if info == "client_addr" :
@@ -94,6 +90,9 @@ def get_extra_info(self, info: str) -> typing.Any:
9490 return self ._stream .extra (anyio .abc .SocketAttribute .remote_address , None )
9591 if info == "socket" :
9692 return self ._stream .extra (anyio .abc .SocketAttribute .raw_socket , None )
93+ if info == "is_readable" :
94+ sock = self ._stream .extra (anyio .abc .SocketAttribute .raw_socket , None )
95+ return is_socket_readable (sock )
9796 return None
9897
9998
@@ -106,6 +105,8 @@ async def connect_tcp(
106105 local_address : typing .Optional [str ] = None ,
107106 socket_options : typing .Optional [typing .Iterable [SOCKET_OPTION ]] = None ,
108107 ) -> AsyncNetworkStream :
108+ if socket_options is None :
109+ socket_options = [] # pragma: no cover
109110 exc_map = {
110111 TimeoutError : ConnectTimeout ,
111112 OSError : ConnectError ,
@@ -119,15 +120,18 @@ async def connect_tcp(
119120 local_host = local_address ,
120121 )
121122 # By default TCP sockets opened in `asyncio` include TCP_NODELAY.
122- self ._set_socket_options (stream , socket_options )
123+ for option in socket_options :
124+ stream ._raw_socket .setsockopt (* option ) # type: ignore[attr-defined] # pragma: no cover
123125 return AnyIOStream (stream )
124126
125127 async def connect_unix_socket (
126128 self ,
127129 path : str ,
128130 timeout : typing .Optional [float ] = None ,
129131 socket_options : typing .Optional [typing .Iterable [SOCKET_OPTION ]] = None ,
130- ) -> AsyncNetworkStream :
132+ ) -> AsyncNetworkStream : # pragma: nocover
133+ if socket_options is None :
134+ socket_options = []
131135 exc_map = {
132136 TimeoutError : ConnectTimeout ,
133137 OSError : ConnectError ,
@@ -136,23 +140,9 @@ async def connect_unix_socket(
136140 with map_exceptions (exc_map ):
137141 with anyio .fail_after (timeout ):
138142 stream : anyio .abc .ByteStream = await anyio .connect_unix (path )
139- self ._set_socket_options (stream , socket_options )
143+ for option in socket_options :
144+ stream ._raw_socket .setsockopt (* option ) # type: ignore[attr-defined] # pragma: no cover
140145 return AnyIOStream (stream )
141146
142147 async def sleep (self , seconds : float ) -> None :
143148 await anyio .sleep (seconds ) # pragma: nocover
144-
145- def _set_socket_options (
146- self ,
147- stream : anyio .abc .ByteStream ,
148- socket_options : typing .Optional [typing .Iterable [SOCKET_OPTION ]] = None ,
149- ) -> None :
150- if not socket_options :
151- return
152-
153- sock = stream .extra (anyio .abc .SocketAttribute .raw_socket , None )
154- if sock is None :
155- raise BrokenSocketError () # pragma: nocover
156-
157- for option in socket_options :
158- sock .setsockopt (* option )
0 commit comments