Skip to content

Commit 3b9bb1c

Browse files
TimMenningerbdraco
andauthored
Replace tcp_sockopts with socket_factory (#10534)
Instead of TCPConnector taking a list of sockopts to be applied sockets created, take a socket_factory callback that allows the caller to implement socket creation entirely. Fixes #10520 <!-- Thank you for your contribution! --> ## What do these changes do? Replace `tcp_sockopts` parameter with a `socket_factory` parameter that is a callback allowing the caller to own socket creation. If passed, all sockets created by `TCPConnector` are expected to come from the `socket_factory` callback. <!-- Please give a short brief about these changes. --> ## Are there changes in behavior for the user? The only users to experience a change in behavior are those who are using the un-released `tcp_sockopts` argument to `TCPConnector`. However, using unreleased code comes with caveat emptor, and is why I felt entitled to remove the option entirely without warning. <!-- Outline any notable behaviour for the end users. --> ## Is it a substantial burden for the maintainers to support this? The burden will be minimal and would only arise if `aiohappyeyeballs` changes their interface. <!-- Stop right there! Pause. Just for a minute... Can you think of anything obvious that would complicate the ongoing development of this project? Try to consider if you'd be able to maintain it throughout the next 5 years. Does it seem viable? Tell us your thoughts! We'd very much love to hear what the consequences of merging this patch might be... This will help us assess if your change is something we'd want to entertain early in the review process. Thank you in advance! --> ## Related issue number <!-- Are there any issues opened that will be resolved by merging this change? --> <!-- Remember to prefix with 'Fixes' if it should close the issue (e.g. 'Fixes #123'). --> ## Checklist - [x] I think the code is well written - [x] Unit tests for the changes exist - [x] Documentation reflects the changes - [x] If you provide code modification, please add yourself to `CONTRIBUTORS.txt` * The format is &lt;Name&gt; &lt;Surname&gt;. * Please keep alphabetical order, the file is sorted by names. - [x] Add a new news fragment into the `CHANGES/` folder * name it `<issue_or_pr_num>.<type>.rst` (e.g. `588.bugfix.rst`) * if you don't have an issue number, change it to the pull request number after creating the PR * `.bugfix`: A bug fix for something the maintainers deemed an improper undesired behavior that got corrected to match pre-agreed expectations. * `.feature`: A new behavior, public APIs. That sort of stuff. * `.deprecation`: A declaration of future API removals and breaking changes in behavior. * `.breaking`: When something public is removed in a breaking way. Could be deprecated in an earlier release. * `.doc`: Notable updates to the documentation structure or build process. * `.packaging`: Notes for downstreams about unobvious side effects and tooling. Changes in the test invocation considerations and runtime assumptions. * `.contrib`: Stuff that affects the contributor experience. e.g. Running tests, building the docs, setting up the development environment. * `.misc`: Changes that are hard to assign to any of the above categories. * Make sure to use full sentences with correct case and punctuation, for example: ```rst Fixed issue with non-ascii contents in doctest text files -- by :user:`contributor-gh-handle`. ``` Use the past tense or the present tense a non-imperative mood, referring to what's changed compared to the last released version of this project. --------- Co-authored-by: J. Nick Koston <[email protected]>
1 parent 492f63d commit 3b9bb1c

File tree

10 files changed

+113
-44
lines changed

10 files changed

+113
-44
lines changed

CHANGES/10474.feature.rst

Lines changed: 0 additions & 2 deletions
This file was deleted.

CHANGES/10520.feature.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Added ``socket_factory`` to :py:class:`aiohttp.TCPConnector` to allow specifying custom socket options
2+
-- by :user:`TimMenninger`.

aiohttp/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
WSServerHandshakeError,
4848
request,
4949
)
50+
from .connector import AddrInfoType, SocketFactoryType
5051
from .cookiejar import CookieJar, DummyCookieJar
5152
from .formdata import FormData
5253
from .helpers import BasicAuth, ChainMapProxy, ETag
@@ -112,6 +113,7 @@
112113
__all__: Tuple[str, ...] = (
113114
"hdrs",
114115
# client
116+
"AddrInfoType",
115117
"BaseConnector",
116118
"ClientConnectionError",
117119
"ClientConnectionResetError",
@@ -146,6 +148,7 @@
146148
"ServerDisconnectedError",
147149
"ServerFingerprintMismatch",
148150
"ServerTimeoutError",
151+
"SocketFactoryType",
149152
"SocketTimeoutError",
150153
"TCPConnector",
151154
"TooManyRedirects",

aiohttp/connector.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
DefaultDict,
2121
Deque,
2222
Dict,
23-
Iterable,
2423
Iterator,
2524
List,
2625
Literal,
@@ -34,6 +33,7 @@
3433
)
3534

3635
import aiohappyeyeballs
36+
from aiohappyeyeballs import AddrInfoType, SocketFactoryType
3737

3838
from . import hdrs, helpers
3939
from .abc import AbstractResolver, ResolveResult
@@ -96,7 +96,14 @@
9696
# which first appeared in Python 3.12.7 and 3.13.1
9797

9898

99-
__all__ = ("BaseConnector", "TCPConnector", "UnixConnector", "NamedPipeConnector")
99+
__all__ = (
100+
"BaseConnector",
101+
"TCPConnector",
102+
"UnixConnector",
103+
"NamedPipeConnector",
104+
"AddrInfoType",
105+
"SocketFactoryType",
106+
)
100107

101108

102109
if TYPE_CHECKING:
@@ -826,8 +833,9 @@ class TCPConnector(BaseConnector):
826833
the happy eyeballs algorithm, set to None.
827834
interleave - “First Address Family Count” as defined in RFC 8305
828835
loop - Optional event loop.
829-
tcp_sockopts - List of tuples of sockopts applied to underlying
830-
socket
836+
socket_factory - A SocketFactoryType function that, if supplied,
837+
will be used to create sockets given an
838+
AddrInfoType.
831839
"""
832840

833841
allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"})
@@ -849,7 +857,7 @@ def __init__(
849857
timeout_ceil_threshold: float = 5,
850858
happy_eyeballs_delay: Optional[float] = 0.25,
851859
interleave: Optional[int] = None,
852-
tcp_sockopts: Iterable[Tuple[int, int, Union[int, Buffer]]] = [],
860+
socket_factory: Optional[SocketFactoryType] = None,
853861
):
854862
super().__init__(
855863
keepalive_timeout=keepalive_timeout,
@@ -880,7 +888,7 @@ def __init__(
880888
self._happy_eyeballs_delay = happy_eyeballs_delay
881889
self._interleave = interleave
882890
self._resolve_host_tasks: Set["asyncio.Task[List[ResolveResult]]"] = set()
883-
self._tcp_sockopts = tcp_sockopts
891+
self._socket_factory = socket_factory
884892

885893
def _close_immediately(self) -> List[Awaitable[object]]:
886894
for fut in chain.from_iterable(self._throttle_dns_futures.values()):
@@ -1105,7 +1113,7 @@ def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]:
11051113
async def _wrap_create_connection(
11061114
self,
11071115
*args: Any,
1108-
addr_infos: List[aiohappyeyeballs.AddrInfoType],
1116+
addr_infos: List[AddrInfoType],
11091117
req: ClientRequest,
11101118
timeout: "ClientTimeout",
11111119
client_error: Type[Exception] = ClientConnectorError,
@@ -1122,9 +1130,8 @@ async def _wrap_create_connection(
11221130
happy_eyeballs_delay=self._happy_eyeballs_delay,
11231131
interleave=self._interleave,
11241132
loop=self._loop,
1133+
socket_factory=self._socket_factory,
11251134
)
1126-
for sockopt in self._tcp_sockopts:
1127-
sock.setsockopt(*sockopt)
11281135
connection = await self._loop.create_connection(
11291136
*args, **kwargs, sock=sock
11301137
)
@@ -1256,13 +1263,13 @@ async def _start_tls_connection(
12561263

12571264
def _convert_hosts_to_addr_infos(
12581265
self, hosts: List[ResolveResult]
1259-
) -> List[aiohappyeyeballs.AddrInfoType]:
1266+
) -> List[AddrInfoType]:
12601267
"""Converts the list of hosts to a list of addr_infos.
12611268
12621269
The list of hosts is the result of a DNS lookup. The list of
12631270
addr_infos is the result of a call to `socket.getaddrinfo()`.
12641271
"""
1265-
addr_infos: List[aiohappyeyeballs.AddrInfoType] = []
1272+
addr_infos: List[AddrInfoType] = []
12661273
for hinfo in hosts:
12671274
host = hinfo["host"]
12681275
is_ipv6 = ":" in host

docs/client_advanced.rst

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -468,19 +468,26 @@ If your HTTP server uses UNIX domain sockets you can use
468468
session = aiohttp.ClientSession(connector=conn)
469469

470470

471-
Setting socket options
471+
Custom socket creation
472472
^^^^^^^^^^^^^^^^^^^^^^
473473

474-
Socket options passed to the :class:`~aiohttp.TCPConnector` will be passed
475-
to the underlying socket when creating a connection. For example, we may
476-
want to change the conditions under which we consider a connection dead.
477-
The following would change that to 9*7200 = 18 hours::
474+
If the default socket is insufficient for your use case, pass an optional
475+
`socket_factory` to the :class:`~aiohttp.TCPConnector`, which implements
476+
`SocketFactoryType`. This will be used to create all sockets for the
477+
lifetime of the class object. For example, we may want to change the
478+
conditions under which we consider a connection dead. The following would
479+
make all sockets respect 9*7200 = 18 hours::
478480

479481
import socket
480482

481-
conn = aiohttp.TCPConnector(tcp_sockopts=[(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True),
482-
(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 7200),
483-
(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 9) ])
483+
def socket_factory(addr_info):
484+
family, type_, proto, _, _, _ = addr_info
485+
sock = socket.socket(family=family, type=type_, proto=proto)
486+
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True)
487+
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 7200)
488+
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 9)
489+
return sock
490+
conn = aiohttp.TCPConnector(socket_factory=socket_factory)
484491

485492

486493
Named pipes in Windows

docs/client_reference.rst

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1122,14 +1122,42 @@ is controlled by *force_close* constructor's parameter).
11221122
overridden in subclasses.
11231123

11241124

1125+
.. autodata:: AddrInfoType
1126+
1127+
.. note::
1128+
1129+
Refer to :py:data:`aiohappyeyeballs.AddrInfoType` for more info.
1130+
1131+
.. warning::
1132+
1133+
Be sure to use ``aiohttp.AddrInfoType`` rather than
1134+
``aiohappyeyeballs.AddrInfoType`` to avoid import breakage, as
1135+
it is likely to be removed from ``aiohappyeyeballs`` in the
1136+
future.
1137+
1138+
1139+
.. autodata:: SocketFactoryType
1140+
1141+
.. note::
1142+
1143+
Refer to :py:data:`aiohappyeyeballs.SocketFactoryType` for more info.
1144+
1145+
.. warning::
1146+
1147+
Be sure to use ``aiohttp.SocketFactoryType`` rather than
1148+
``aiohappyeyeballs.SocketFactoryType`` to avoid import breakage,
1149+
as it is likely to be removed from ``aiohappyeyeballs`` in the
1150+
future.
1151+
1152+
11251153
.. class:: TCPConnector(*, ssl=True, verify_ssl=True, fingerprint=None, \
11261154
use_dns_cache=True, ttl_dns_cache=10, \
11271155
family=0, ssl_context=None, local_addr=None, \
11281156
resolver=None, keepalive_timeout=sentinel, \
11291157
force_close=False, limit=100, limit_per_host=0, \
11301158
enable_cleanup_closed=False, timeout_ceil_threshold=5, \
11311159
happy_eyeballs_delay=0.25, interleave=None, loop=None, \
1132-
tcp_sockopts=[])
1160+
socket_factory=None)
11331161

11341162
Connector for working with *HTTP* and *HTTPS* via *TCP* sockets.
11351163

@@ -1250,9 +1278,9 @@ is controlled by *force_close* constructor's parameter).
12501278

12511279
.. versionadded:: 3.10
12521280

1253-
:param list tcp_sockopts: options applied to the socket when a connection is
1254-
created. This should be a list of 3-tuples, each a ``(level, optname, value)``.
1255-
Each tuple is deconstructed and passed verbatim to ``<socket>.setsockopt``.
1281+
:param :py:data:``SocketFactoryType`` socket_factory: This function takes an
1282+
:py:data:``AddrInfoType`` and is used in lieu of ``socket.socket()`` when
1283+
creating TCP connections.
12561284

12571285
.. versionadded:: 3.12
12581286

docs/conf.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
# ones.
5454
extensions = [
5555
# stdlib-party extensions:
56+
"sphinx.ext.autodoc",
5657
"sphinx.ext.extlinks",
5758
"sphinx.ext.graphviz",
5859
"sphinx.ext.intersphinx",
@@ -82,6 +83,7 @@
8283
"aiohttpsession": ("https://aiohttp-session.readthedocs.io/en/stable/", None),
8384
"aiohttpdemos": ("https://aiohttp-demos.readthedocs.io/en/latest/", None),
8485
"aiojobs": ("https://aiojobs.readthedocs.io/en/stable/", None),
86+
"aiohappyeyeballs": ("https://aiohappyeyeballs.readthedocs.io/en/stable/", None),
8587
}
8688

8789
# Add any paths that contain templates here, relative to this directory.
@@ -425,6 +427,7 @@
425427
("py:class", "cgi.FieldStorage"), # undocumented
426428
("py:meth", "aiohttp.web.UrlDispatcher.register_resource"), # undocumented
427429
("py:func", "aiohttp_debugtoolbar.setup"), # undocumented
430+
("py:class", "socket.SocketKind"), # undocumented
428431
]
429432

430433
# -- Options for towncrier_draft extension -----------------------------------

requirements/runtime-deps.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Extracted from `setup.cfg` via `make sync-direct-runtime-deps`
22

33
aiodns >= 3.2.0; sys_platform=="linux" or sys_platform=="darwin"
4-
aiohappyeyeballs >= 2.3.0
4+
aiohappyeyeballs >= 2.5.0
55
aiosignal >= 1.1.2
66
async-timeout >= 4.0, < 6.0 ; python_version < "3.11"
77
Brotli; platform_python_implementation == 'CPython'

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ zip_safe = False
5151
include_package_data = True
5252

5353
install_requires =
54-
aiohappyeyeballs >= 2.3.0
54+
aiohappyeyeballs >= 2.5.0
5555
aiosignal >= 1.1.2
5656
async-timeout >= 4.0, < 6.0 ; python_version < "3.11"
5757
frozenlist >= 1.1.1

tests/test_connector.py

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from unittest import mock
2727

2828
import pytest
29-
from aiohappyeyeballs import AddrInfoType
3029
from pytest_mock import MockerFixture
3130
from yarl import URL
3231

@@ -44,6 +43,7 @@
4443
from aiohttp.connector import (
4544
_SSL_CONTEXT_UNVERIFIED,
4645
_SSL_CONTEXT_VERIFIED,
46+
AddrInfoType,
4747
Connection,
4848
TCPConnector,
4949
_DNSCacheTable,
@@ -3822,27 +3822,48 @@ def test_connect() -> Literal[True]:
38223822
assert raw_response_list == [True, True]
38233823

38243824

3825-
async def test_tcp_connector_setsockopts(
3825+
async def test_tcp_connector_socket_factory(
38263826
loop: asyncio.AbstractEventLoop, start_connection: mock.AsyncMock
38273827
) -> None:
3828-
"""Check that sockopts get passed to socket"""
3829-
conn = aiohttp.TCPConnector(
3830-
tcp_sockopts=[(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 2)]
3831-
)
3832-
3833-
with mock.patch.object(
3834-
conn._loop, "create_connection", autospec=True, spec_set=True
3835-
) as create_connection:
3836-
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
3837-
start_connection.return_value = s
3838-
create_connection.return_value = mock.Mock(), mock.Mock()
3828+
"""Check that socket factory is called"""
3829+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
3830+
start_connection.return_value = s
38393831

3840-
req = ClientRequest("GET", URL("https://127.0.0.1:443"), loop=loop)
3832+
local_addr = None
3833+
socket_factory: Callable[[AddrInfoType], socket.socket] = lambda _: s
3834+
happy_eyeballs_delay = 0.123
3835+
interleave = 3
3836+
conn = aiohttp.TCPConnector(
3837+
interleave=interleave,
3838+
local_addr=local_addr,
3839+
happy_eyeballs_delay=happy_eyeballs_delay,
3840+
socket_factory=socket_factory,
3841+
)
38413842

3843+
with mock.patch.object(
3844+
conn._loop,
3845+
"create_connection",
3846+
autospec=True,
3847+
spec_set=True,
3848+
return_value=(mock.Mock(), mock.Mock()),
3849+
):
3850+
host = "127.0.0.1"
3851+
port = 443
3852+
req = ClientRequest("GET", URL(f"https://{host}:{port}"), loop=loop)
38423853
with closing(await conn.connect(req, [], ClientTimeout())):
3843-
assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT) == 2
3844-
3845-
await conn.close()
3854+
pass
3855+
await conn.close()
3856+
3857+
start_connection.assert_called_with(
3858+
addr_infos=[
3859+
(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", (host, port))
3860+
],
3861+
local_addr_infos=local_addr,
3862+
happy_eyeballs_delay=happy_eyeballs_delay,
3863+
interleave=interleave,
3864+
loop=loop,
3865+
socket_factory=socket_factory,
3866+
)
38463867

38473868

38483869
def test_default_ssl_context_creation_without_ssl() -> None:

0 commit comments

Comments
 (0)