2828from aiohttp .client import ClientSession
2929from aiohttp .client_proto import ResponseHandler
3030from aiohttp .client_reqrep import ClientRequest , ConnectionKey
31- from aiohttp .connector import BaseConnector , Connection , TCPConnector
31+ from aiohttp .connector import BaseConnector , Connection , TCPConnector , UnixConnector
3232from aiohttp .pytest_plugin import AiohttpClient
3333from aiohttp .test_utils import make_mocked_coro
3434from aiohttp .tracing import Trace
@@ -536,15 +536,78 @@ async def test_ws_connect_allowed_protocols(
536536 hdrs .CONNECTION : "upgrade" ,
537537 hdrs .SEC_WEBSOCKET_ACCEPT : ws_key ,
538538 }
539- resp .url = URL (f"{ protocol } ://example.com " )
539+ resp .url = URL (f"{ protocol } ://example" )
540540 resp .cookies = SimpleCookie ()
541541 resp .start = mock .AsyncMock ()
542542
543543 req = mock .create_autospec (aiohttp .ClientRequest , spec_set = True )
544544 req_factory = mock .Mock (return_value = req )
545545 req .send = mock .AsyncMock (return_value = resp )
546+ # BaseConnector allows all high level protocols by default
547+ connector = BaseConnector ()
546548
547- session = await create_session (request_class = req_factory )
549+ session = await create_session (connector = connector , request_class = req_factory )
550+
551+ connections = []
552+ assert session ._connector is not None
553+ original_connect = session ._connector .connect
554+
555+ async def connect (
556+ req : ClientRequest , traces : List [Trace ], timeout : aiohttp .ClientTimeout
557+ ) -> Connection :
558+ conn = await original_connect (req , traces , timeout )
559+ connections .append (conn )
560+ return conn
561+
562+ async def create_connection (
563+ req : object , traces : object , timeout : object
564+ ) -> ResponseHandler :
565+ return create_mocked_conn ()
566+
567+ connector = session ._connector
568+ with mock .patch .object (connector , "connect" , connect ), mock .patch .object (
569+ connector , "_create_connection" , create_connection
570+ ), mock .patch .object (connector , "_release" ), mock .patch (
571+ "aiohttp.client.os"
572+ ) as m_os :
573+ m_os .urandom .return_value = key_data
574+ await session .ws_connect (f"{ protocol } ://example" )
575+
576+ # normally called during garbage collection. triggers an exception
577+ # if the connection wasn't already closed
578+ for c in connections :
579+ c .close ()
580+ c .__del__ ()
581+
582+ await session .close ()
583+
584+
585+ @pytest .mark .parametrize ("protocol" , ["http" , "https" , "ws" , "wss" , "unix" ])
586+ async def test_ws_connect_unix_socket_allowed_protocols (
587+ create_session : Callable [..., Awaitable [ClientSession ]],
588+ create_mocked_conn : Callable [[], ResponseHandler ],
589+ protocol : str ,
590+ ws_key : bytes ,
591+ key_data : bytes ,
592+ ) -> None :
593+ resp = mock .create_autospec (aiohttp .ClientResponse )
594+ resp .status = 101
595+ resp .headers = {
596+ hdrs .UPGRADE : "websocket" ,
597+ hdrs .CONNECTION : "upgrade" ,
598+ hdrs .SEC_WEBSOCKET_ACCEPT : ws_key ,
599+ }
600+ resp .url = URL (f"{ protocol } ://example" )
601+ resp .cookies = SimpleCookie ()
602+ resp .start = mock .AsyncMock ()
603+
604+ req = mock .create_autospec (aiohttp .ClientRequest , spec_set = True )
605+ req_factory = mock .Mock (return_value = req )
606+ req .send = mock .AsyncMock (return_value = resp )
607+ # UnixConnector allows all high level protocols by default and unix sockets
608+ session = await create_session (
609+ connector = UnixConnector (path = "" ), request_class = req_factory
610+ )
548611
549612 connections = []
550613 assert session ._connector is not None
@@ -569,7 +632,7 @@ async def create_connection(
569632 "aiohttp.client.os"
570633 ) as m_os :
571634 m_os .urandom .return_value = key_data
572- await session .ws_connect (f"{ protocol } ://example.com " )
635+ await session .ws_connect (f"{ protocol } ://example" )
573636
574637 # normally called during garbage collection. triggers an exception
575638 # if the connection wasn't already closed
0 commit comments