Skip to content

Commit b2691f2

Browse files
Fix connecting to npipe://, tcp://, and unix:// urls (#8632)
Co-authored-by: Sam Bull <[email protected]>
1 parent e0ff524 commit b2691f2

File tree

5 files changed

+128
-11
lines changed

5 files changed

+128
-11
lines changed

CHANGES/8632.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fixed connecting to ``npipe://``, ``tcp://``, and ``unix://`` urls -- by :user:`bdraco`.

aiohttp/client.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,13 @@
8080
ClientWebSocketResponse,
8181
ClientWSTimeout,
8282
)
83-
from .connector import BaseConnector, NamedPipeConnector, TCPConnector, UnixConnector
83+
from .connector import (
84+
HTTP_AND_EMPTY_SCHEMA_SET,
85+
BaseConnector,
86+
NamedPipeConnector,
87+
TCPConnector,
88+
UnixConnector,
89+
)
8490
from .cookiejar import CookieJar
8591
from .helpers import (
8692
_SENTINEL,
@@ -210,9 +216,6 @@ class ClientTimeout:
210216

211217
# https://www.rfc-editor.org/rfc/rfc9110#section-9.2.2
212218
IDEMPOTENT_METHODS = frozenset({"GET", "HEAD", "OPTIONS", "TRACE", "PUT", "DELETE"})
213-
HTTP_SCHEMA_SET = frozenset({"http", "https", ""})
214-
WS_SCHEMA_SET = frozenset({"ws", "wss"})
215-
ALLOWED_PROTOCOL_SCHEMA_SET = HTTP_SCHEMA_SET | WS_SCHEMA_SET
216219

217220
_RetType = TypeVar("_RetType")
218221
_CharsetResolver = Callable[[ClientResponse, bytes], str]
@@ -466,7 +469,8 @@ async def _request(
466469
except ValueError as e:
467470
raise InvalidUrlClientError(str_or_url) from e
468471

469-
if url.scheme not in ALLOWED_PROTOCOL_SCHEMA_SET:
472+
assert self._connector is not None
473+
if url.scheme not in self._connector.allowed_protocol_schema_set:
470474
raise NonHttpUrlClientError(url)
471475

472476
skip_headers = set(self._skip_auto_headers)
@@ -597,7 +601,6 @@ async def _request(
597601
real_timeout.connect,
598602
ceil_threshold=real_timeout.ceil_threshold,
599603
):
600-
assert self._connector is not None
601604
conn = await self._connector.connect(
602605
req, traces=traces, timeout=real_timeout
603606
)
@@ -693,7 +696,7 @@ async def _request(
693696
) from e
694697

695698
scheme = parsed_redirect_url.scheme
696-
if scheme not in HTTP_SCHEMA_SET:
699+
if scheme not in HTTP_AND_EMPTY_SCHEMA_SET:
697700
resp.close()
698701
raise NonHttpUrlRedirectClientError(r_url)
699702
elif not scheme:

aiohttp/connector.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,14 @@
6464
SSLContext = object # type: ignore[misc,assignment]
6565

6666

67+
EMPTY_SCHEMA_SET = frozenset({""})
68+
HTTP_SCHEMA_SET = frozenset({"http", "https"})
69+
WS_SCHEMA_SET = frozenset({"ws", "wss"})
70+
71+
HTTP_AND_EMPTY_SCHEMA_SET = HTTP_SCHEMA_SET | EMPTY_SCHEMA_SET
72+
HIGH_LEVEL_SCHEMA_SET = HTTP_AND_EMPTY_SCHEMA_SET | WS_SCHEMA_SET
73+
74+
6775
__all__ = ("BaseConnector", "TCPConnector", "UnixConnector", "NamedPipeConnector")
6876

6977

@@ -190,6 +198,8 @@ class BaseConnector:
190198
# abort transport after 2 seconds (cleanup broken connections)
191199
_cleanup_closed_period = 2.0
192200

201+
allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET
202+
193203
def __init__(
194204
self,
195205
*,
@@ -741,6 +751,8 @@ class TCPConnector(BaseConnector):
741751
loop - Optional event loop.
742752
"""
743753

754+
allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"})
755+
744756
def __init__(
745757
self,
746758
*,
@@ -1342,6 +1354,8 @@ class UnixConnector(BaseConnector):
13421354
loop - Optional event loop.
13431355
"""
13441356

1357+
allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"unix"})
1358+
13451359
def __init__(
13461360
self,
13471361
path: str,
@@ -1396,6 +1410,8 @@ class NamedPipeConnector(BaseConnector):
13961410
loop - Optional event loop.
13971411
"""
13981412

1413+
allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"npipe"})
1414+
13991415
def __init__(
14001416
self,
14011417
path: str,

tests/test_client_session.py

Lines changed: 67 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from aiohttp.client import ClientSession
2929
from aiohttp.client_proto import ResponseHandler
3030
from aiohttp.client_reqrep import ClientRequest, ConnectionKey
31-
from aiohttp.connector import BaseConnector, Connection, TCPConnector
31+
from aiohttp.connector import BaseConnector, Connection, TCPConnector, UnixConnector
3232
from aiohttp.pytest_plugin import AiohttpClient
3333
from aiohttp.test_utils import make_mocked_coro
3434
from aiohttp.tracing import Trace
@@ -536,15 +536,78 @@ async def test_ws_connect_allowed_protocols(
536536
hdrs.CONNECTION: "upgrade",
537537
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
538538
}
539-
resp.url = URL(f"{protocol}://example.com")
539+
resp.url = URL(f"{protocol}://example")
540540
resp.cookies = SimpleCookie()
541541
resp.start = mock.AsyncMock()
542542

543543
req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True)
544544
req_factory = mock.Mock(return_value=req)
545545
req.send = mock.AsyncMock(return_value=resp)
546+
# BaseConnector allows all high level protocols by default
547+
connector = BaseConnector()
546548

547-
session = await create_session(request_class=req_factory)
549+
session = await create_session(connector=connector, request_class=req_factory)
550+
551+
connections = []
552+
assert session._connector is not None
553+
original_connect = session._connector.connect
554+
555+
async def connect(
556+
req: ClientRequest, traces: List[Trace], timeout: aiohttp.ClientTimeout
557+
) -> Connection:
558+
conn = await original_connect(req, traces, timeout)
559+
connections.append(conn)
560+
return conn
561+
562+
async def create_connection(
563+
req: object, traces: object, timeout: object
564+
) -> ResponseHandler:
565+
return create_mocked_conn()
566+
567+
connector = session._connector
568+
with mock.patch.object(connector, "connect", connect), mock.patch.object(
569+
connector, "_create_connection", create_connection
570+
), mock.patch.object(connector, "_release"), mock.patch(
571+
"aiohttp.client.os"
572+
) as m_os:
573+
m_os.urandom.return_value = key_data
574+
await session.ws_connect(f"{protocol}://example")
575+
576+
# normally called during garbage collection. triggers an exception
577+
# if the connection wasn't already closed
578+
for c in connections:
579+
c.close()
580+
c.__del__()
581+
582+
await session.close()
583+
584+
585+
@pytest.mark.parametrize("protocol", ["http", "https", "ws", "wss", "unix"])
586+
async def test_ws_connect_unix_socket_allowed_protocols(
587+
create_session: Callable[..., Awaitable[ClientSession]],
588+
create_mocked_conn: Callable[[], ResponseHandler],
589+
protocol: str,
590+
ws_key: bytes,
591+
key_data: bytes,
592+
) -> None:
593+
resp = mock.create_autospec(aiohttp.ClientResponse)
594+
resp.status = 101
595+
resp.headers = {
596+
hdrs.UPGRADE: "websocket",
597+
hdrs.CONNECTION: "upgrade",
598+
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
599+
}
600+
resp.url = URL(f"{protocol}://example")
601+
resp.cookies = SimpleCookie()
602+
resp.start = mock.AsyncMock()
603+
604+
req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True)
605+
req_factory = mock.Mock(return_value=req)
606+
req.send = mock.AsyncMock(return_value=resp)
607+
# UnixConnector allows all high level protocols by default and unix sockets
608+
session = await create_session(
609+
connector=UnixConnector(path=""), request_class=req_factory
610+
)
548611

549612
connections = []
550613
assert session._connector is not None
@@ -569,7 +632,7 @@ async def create_connection(
569632
"aiohttp.client.os"
570633
) as m_os:
571634
m_os.urandom.return_value = key_data
572-
await session.ws_connect(f"{protocol}://example.com")
635+
await session.ws_connect(f"{protocol}://example")
573636

574637
# normally called during garbage collection. triggers an exception
575638
# if the connection wasn't already closed

tests/test_connector.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1636,6 +1636,11 @@ async def test_tcp_connector_ctor(loop: asyncio.AbstractEventLoop) -> None:
16361636
assert conn.family == 0
16371637

16381638

1639+
async def test_tcp_connector_allowed_protocols(loop: asyncio.AbstractEventLoop) -> None:
1640+
conn = aiohttp.TCPConnector()
1641+
assert conn.allowed_protocol_schema_set == {"", "tcp", "http", "https", "ws", "wss"}
1642+
1643+
16391644
async def test_invalid_ssl_param() -> None:
16401645
with pytest.raises(TypeError):
16411646
aiohttp.TCPConnector(ssl=object()) # type: ignore[arg-type]
@@ -1819,6 +1824,19 @@ async def test_ctor_with_default_loop(loop: asyncio.AbstractEventLoop) -> None:
18191824
assert loop is conn._loop
18201825

18211826

1827+
async def test_base_connector_allows_high_level_protocols(
1828+
loop: asyncio.AbstractEventLoop,
1829+
) -> None:
1830+
conn = aiohttp.BaseConnector()
1831+
assert conn.allowed_protocol_schema_set == {
1832+
"",
1833+
"http",
1834+
"https",
1835+
"ws",
1836+
"wss",
1837+
}
1838+
1839+
18221840
async def test_connect_with_limit(
18231841
loop: asyncio.AbstractEventLoop, key: ConnectionKey
18241842
) -> None:
@@ -2621,6 +2639,14 @@ async def handler(request: web.Request) -> web.Response:
26212639

26222640
connector = aiohttp.UnixConnector(unix_sockname)
26232641
assert unix_sockname == connector.path
2642+
assert connector.allowed_protocol_schema_set == {
2643+
"",
2644+
"http",
2645+
"https",
2646+
"ws",
2647+
"wss",
2648+
"unix",
2649+
}
26242650

26252651
session = ClientSession(connector=connector)
26262652
r = await session.get(url)
@@ -2648,6 +2674,14 @@ async def handler(request: web.Request) -> web.Response:
26482674

26492675
connector = aiohttp.NamedPipeConnector(pipe_name)
26502676
assert pipe_name == connector.path
2677+
assert connector.allowed_protocol_schema_set == {
2678+
"",
2679+
"http",
2680+
"https",
2681+
"ws",
2682+
"wss",
2683+
"npipe",
2684+
}
26512685

26522686
session = ClientSession(connector=connector)
26532687
r = await session.get(url)

0 commit comments

Comments
 (0)