44import io
55import json
66from http .cookies import SimpleCookie
7- from typing import Any , List
7+ from typing import Any , Awaitable , Callable , List
88from unittest import mock
99from uuid import uuid4
1010
1616import aiohttp
1717from aiohttp import client , hdrs , web
1818from aiohttp .client import ClientSession
19+ from aiohttp .client_proto import ResponseHandler
1920from aiohttp .client_reqrep import ClientRequest
20- from aiohttp .connector import BaseConnector , TCPConnector
21+ from aiohttp .connector import BaseConnector , Connection , TCPConnector , UnixConnector
2122from aiohttp .helpers import DEBUG
2223from aiohttp .test_utils import make_mocked_coro
24+ from aiohttp .tracing import Trace
2325
2426
2527@pytest .fixture
@@ -487,15 +489,17 @@ async def test_ws_connect_allowed_protocols(
487489 hdrs .CONNECTION : "upgrade" ,
488490 hdrs .SEC_WEBSOCKET_ACCEPT : ws_key ,
489491 }
490- resp .url = URL (f"{ protocol } ://example.com " )
492+ resp .url = URL (f"{ protocol } ://example" )
491493 resp .cookies = SimpleCookie ()
492494 resp .start = mock .AsyncMock ()
493495
494496 req = mock .create_autospec (aiohttp .ClientRequest , spec_set = True )
495497 req_factory = mock .Mock (return_value = req )
496498 req .send = mock .AsyncMock (return_value = resp )
499+ # BaseConnector allows all high level protocols by default
500+ connector = BaseConnector ()
497501
498- session = await create_session (request_class = req_factory )
502+ session = await create_session (connector = connector , request_class = req_factory )
499503
500504 connections = []
501505 original_connect = session ._connector .connect
@@ -515,7 +519,68 @@ async def create_connection(req, traces, timeout):
515519 "aiohttp.client.os"
516520 ) as m_os :
517521 m_os .urandom .return_value = key_data
518- await session .ws_connect (f"{ protocol } ://example.com" )
522+ await session .ws_connect (f"{ protocol } ://example" )
523+
524+ # normally called during garbage collection. triggers an exception
525+ # if the connection wasn't already closed
526+ for c in connections :
527+ c .close ()
528+ c .__del__ ()
529+
530+ await session .close ()
531+
532+
533+ @pytest .mark .parametrize ("protocol" , ["http" , "https" , "ws" , "wss" , "unix" ])
534+ async def test_ws_connect_unix_socket_allowed_protocols (
535+ create_session : Callable [..., Awaitable [ClientSession ]],
536+ create_mocked_conn : Callable [[], ResponseHandler ],
537+ protocol : str ,
538+ ws_key : bytes ,
539+ key_data : bytes ,
540+ ) -> None :
541+ resp = mock .create_autospec (aiohttp .ClientResponse )
542+ resp .status = 101
543+ resp .headers = {
544+ hdrs .UPGRADE : "websocket" ,
545+ hdrs .CONNECTION : "upgrade" ,
546+ hdrs .SEC_WEBSOCKET_ACCEPT : ws_key ,
547+ }
548+ resp .url = URL (f"{ protocol } ://example" )
549+ resp .cookies = SimpleCookie ()
550+ resp .start = mock .AsyncMock ()
551+
552+ req = mock .create_autospec (aiohttp .ClientRequest , spec_set = True )
553+ req_factory = mock .Mock (return_value = req )
554+ req .send = mock .AsyncMock (return_value = resp )
555+ # UnixConnector allows all high level protocols by default and unix sockets
556+ session = await create_session (
557+ connector = UnixConnector (path = "" ), request_class = req_factory
558+ )
559+
560+ connections = []
561+ assert session ._connector is not None
562+ original_connect = session ._connector .connect
563+
564+ async def connect (
565+ req : ClientRequest , traces : List [Trace ], timeout : aiohttp .ClientTimeout
566+ ) -> Connection :
567+ conn = await original_connect (req , traces , timeout )
568+ connections .append (conn )
569+ return conn
570+
571+ async def create_connection (
572+ req : object , traces : object , timeout : object
573+ ) -> ResponseHandler :
574+ return create_mocked_conn ()
575+
576+ connector = session ._connector
577+ with mock .patch .object (connector , "connect" , connect ), mock .patch .object (
578+ connector , "_create_connection" , create_connection
579+ ), mock .patch .object (connector , "_release" ), mock .patch (
580+ "aiohttp.client.os"
581+ ) as m_os :
582+ m_os .urandom .return_value = key_data
583+ await session .ws_connect (f"{ protocol } ://example" )
519584
520585 # normally called during garbage collection. triggers an exception
521586 # if the connection wasn't already closed
0 commit comments