Skip to content

Commit f9f59af

Browse files
Add connect retries (#221)
* Add connect retries * Update tests/async_tests/test_retries.py Co-authored-by: Jamie Hewland <[email protected]> * Unasync Co-authored-by: Jamie Hewland <[email protected]>
1 parent 11f537e commit f9f59af

File tree

15 files changed

+457
-42
lines changed

15 files changed

+457
-42
lines changed

httpcore/_async/connection.py

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
from typing import Optional, Tuple, cast
33

44
from .._backends.auto import AsyncBackend, AsyncLock, AsyncSocketStream, AutoBackend
5+
from .._exceptions import ConnectError, ConnectTimeout
56
from .._types import URL, Headers, Origin, TimeoutDict
6-
from .._utils import get_logger, url_to_origin
7+
from .._utils import exponential_backoff, get_logger, url_to_origin
78
from .base import (
89
AsyncByteStream,
910
AsyncHTTPTransport,
@@ -14,6 +15,8 @@
1415

1516
logger = get_logger(__name__)
1617

18+
RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc.
19+
1720

1821
class AsyncHTTPConnection(AsyncHTTPTransport):
1922
def __init__(
@@ -24,6 +27,7 @@ def __init__(
2427
ssl_context: SSLContext = None,
2528
socket: AsyncSocketStream = None,
2629
local_address: str = None,
30+
retries: int = 0,
2731
backend: AsyncBackend = None,
2832
):
2933
self.origin = origin
@@ -32,6 +36,7 @@ def __init__(
3236
self.ssl_context = SSLContext() if ssl_context is None else ssl_context
3337
self.socket = socket
3438
self.local_address = local_address
39+
self.retries = retries
3540

3641
if self.http2:
3742
self.ssl_context.set_alpn_protocols(["http/1.1", "h2"])
@@ -103,22 +108,34 @@ async def _open_socket(self, timeout: TimeoutDict = None) -> AsyncSocketStream:
103108
scheme, hostname, port = self.origin
104109
timeout = {} if timeout is None else timeout
105110
ssl_context = self.ssl_context if scheme == b"https" else None
106-
try:
107-
if self.uds is None:
108-
return await self.backend.open_tcp_stream(
109-
hostname,
110-
port,
111-
ssl_context,
112-
timeout,
113-
local_address=self.local_address,
114-
)
115-
else:
116-
return await self.backend.open_uds_stream(
117-
self.uds, hostname, ssl_context, timeout
118-
)
119-
except Exception: # noqa: PIE786
120-
self.connect_failed = True
121-
raise
111+
112+
retries_left = self.retries
113+
delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR)
114+
115+
while True:
116+
try:
117+
if self.uds is None:
118+
return await self.backend.open_tcp_stream(
119+
hostname,
120+
port,
121+
ssl_context,
122+
timeout,
123+
local_address=self.local_address,
124+
)
125+
else:
126+
return await self.backend.open_uds_stream(
127+
self.uds, hostname, ssl_context, timeout
128+
)
129+
except (ConnectError, ConnectTimeout):
130+
if retries_left <= 0:
131+
self.connect_failed = True
132+
raise
133+
retries_left -= 1
134+
delay = next(delays)
135+
await self.backend.sleep(delay)
136+
except Exception: # noqa: PIE786
137+
self.connect_failed = True
138+
raise
122139

123140
def _create_connection(self, socket: AsyncSocketStream) -> None:
124141
http_version = socket.get_http_version()

httpcore/_async/connection_pool.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,18 @@
11
import warnings
22
from ssl import SSLContext
3-
from typing import AsyncIterator, Callable, Dict, List, Optional, Set, Tuple, cast
3+
from typing import (
4+
AsyncIterator,
5+
Callable,
6+
Dict,
7+
List,
8+
Optional,
9+
Set,
10+
Tuple,
11+
Union,
12+
cast,
13+
)
414

5-
from .._backends.auto import AsyncLock, AsyncSemaphore
15+
from .._backends.auto import AsyncBackend, AsyncLock, AsyncSemaphore
616
from .._backends.base import lookup_async_backend
717
from .._exceptions import LocalProtocolError, PoolTimeout, UnsupportedProtocol
818
from .._threadlock import ThreadLock
@@ -84,6 +94,8 @@ class AsyncConnectionPool(AsyncHTTPTransport):
8494
`local_address="0.0.0.0"` will connect using an `AF_INET` address (IPv4),
8595
while using `local_address="::"` will connect using an `AF_INET6` address
8696
(IPv6).
97+
* **retries** - `int` - The maximum number of retries when trying to establish a
98+
connection.
8799
* **backend** - `str` - A name indicating which concurrency backend to use.
88100
"""
89101

@@ -96,8 +108,9 @@ def __init__(
96108
http2: bool = False,
97109
uds: str = None,
98110
local_address: str = None,
111+
retries: int = 0,
99112
max_keepalive: int = None,
100-
backend: str = "auto",
113+
backend: Union[AsyncBackend, str] = "auto",
101114
):
102115
if max_keepalive is not None:
103116
warnings.warn(
@@ -106,16 +119,20 @@ def __init__(
106119
)
107120
max_keepalive_connections = max_keepalive
108121

122+
if isinstance(backend, str):
123+
backend = lookup_async_backend(backend)
124+
109125
self._ssl_context = SSLContext() if ssl_context is None else ssl_context
110126
self._max_connections = max_connections
111127
self._max_keepalive_connections = max_keepalive_connections
112128
self._keepalive_expiry = keepalive_expiry
113129
self._http2 = http2
114130
self._uds = uds
115131
self._local_address = local_address
132+
self._retries = retries
116133
self._connections: Dict[Origin, Set[AsyncHTTPConnection]] = {}
117134
self._thread_lock = ThreadLock()
118-
self._backend = lookup_async_backend(backend)
135+
self._backend = backend
119136
self._next_keepalive_check = 0.0
120137

121138
if http2:
@@ -157,6 +174,7 @@ def _create_connection(
157174
uds=self._uds,
158175
ssl_context=self._ssl_context,
159176
local_address=self._local_address,
177+
retries=self._retries,
160178
backend=self._backend,
161179
)
162180

httpcore/_backends/anyio.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,3 +192,6 @@ def create_semaphore(self, max_value: int, exc_class: type) -> AsyncSemaphore:
192192

193193
async def time(self) -> float:
194194
return await anyio.current_time()
195+
196+
async def sleep(self, seconds: float) -> None:
197+
await anyio.sleep(seconds)

httpcore/_backends/asyncio.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,3 +282,6 @@ def create_semaphore(self, max_value: int, exc_class: type) -> AsyncSemaphore:
282282
async def time(self) -> float:
283283
loop = asyncio.get_event_loop()
284284
return loop.time()
285+
286+
async def sleep(self, seconds: float) -> None:
287+
await asyncio.sleep(seconds)

httpcore/_backends/auto.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,6 @@ def create_semaphore(self, max_value: int, exc_class: type) -> AsyncSemaphore:
6262

6363
async def time(self) -> float:
6464
return await self.backend.time()
65+
66+
async def sleep(self, seconds: float) -> None:
67+
await self.backend.sleep(seconds)

httpcore/_backends/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,6 @@ def create_semaphore(self, max_value: int, exc_class: type) -> AsyncSemaphore:
132132

133133
async def time(self) -> float:
134134
raise NotImplementedError() # pragma: no cover
135+
136+
async def sleep(self, seconds: float) -> None:
137+
raise NotImplementedError() # pragma: no cover

httpcore/_backends/curio.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,3 +204,6 @@ def create_semaphore(self, max_value: int, exc_class: type) -> AsyncSemaphore:
204204

205205
async def time(self) -> float:
206206
return await curio.clock()
207+
208+
async def sleep(self, seconds: float) -> None:
209+
await curio.sleep(seconds)

httpcore/_backends/sync.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,3 +173,6 @@ def create_semaphore(self, max_value: int, exc_class: type) -> SyncSemaphore:
173173

174174
def time(self) -> float:
175175
return time.monotonic()
176+
177+
def sleep(self, seconds: float) -> None:
178+
time.sleep(seconds)

httpcore/_backends/trio.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,3 +200,6 @@ def create_semaphore(self, max_value: int, exc_class: type) -> AsyncSemaphore:
200200

201201
async def time(self) -> float:
202202
return trio.current_time()
203+
204+
async def sleep(self, seconds: float) -> None:
205+
await trio.sleep(seconds)

httpcore/_sync/connection.py

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
from typing import Optional, Tuple, cast
33

44
from .._backends.sync import SyncBackend, SyncLock, SyncSocketStream, SyncBackend
5+
from .._exceptions import ConnectError, ConnectTimeout
56
from .._types import URL, Headers, Origin, TimeoutDict
6-
from .._utils import get_logger, url_to_origin
7+
from .._utils import exponential_backoff, get_logger, url_to_origin
78
from .base import (
89
SyncByteStream,
910
SyncHTTPTransport,
@@ -14,6 +15,8 @@
1415

1516
logger = get_logger(__name__)
1617

18+
RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc.
19+
1720

1821
class SyncHTTPConnection(SyncHTTPTransport):
1922
def __init__(
@@ -24,6 +27,7 @@ def __init__(
2427
ssl_context: SSLContext = None,
2528
socket: SyncSocketStream = None,
2629
local_address: str = None,
30+
retries: int = 0,
2731
backend: SyncBackend = None,
2832
):
2933
self.origin = origin
@@ -32,6 +36,7 @@ def __init__(
3236
self.ssl_context = SSLContext() if ssl_context is None else ssl_context
3337
self.socket = socket
3438
self.local_address = local_address
39+
self.retries = retries
3540

3641
if self.http2:
3742
self.ssl_context.set_alpn_protocols(["http/1.1", "h2"])
@@ -103,22 +108,34 @@ def _open_socket(self, timeout: TimeoutDict = None) -> SyncSocketStream:
103108
scheme, hostname, port = self.origin
104109
timeout = {} if timeout is None else timeout
105110
ssl_context = self.ssl_context if scheme == b"https" else None
106-
try:
107-
if self.uds is None:
108-
return self.backend.open_tcp_stream(
109-
hostname,
110-
port,
111-
ssl_context,
112-
timeout,
113-
local_address=self.local_address,
114-
)
115-
else:
116-
return self.backend.open_uds_stream(
117-
self.uds, hostname, ssl_context, timeout
118-
)
119-
except Exception: # noqa: PIE786
120-
self.connect_failed = True
121-
raise
111+
112+
retries_left = self.retries
113+
delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR)
114+
115+
while True:
116+
try:
117+
if self.uds is None:
118+
return self.backend.open_tcp_stream(
119+
hostname,
120+
port,
121+
ssl_context,
122+
timeout,
123+
local_address=self.local_address,
124+
)
125+
else:
126+
return self.backend.open_uds_stream(
127+
self.uds, hostname, ssl_context, timeout
128+
)
129+
except (ConnectError, ConnectTimeout):
130+
if retries_left <= 0:
131+
self.connect_failed = True
132+
raise
133+
retries_left -= 1
134+
delay = next(delays)
135+
self.backend.sleep(delay)
136+
except Exception: # noqa: PIE786
137+
self.connect_failed = True
138+
raise
122139

123140
def _create_connection(self, socket: SyncSocketStream) -> None:
124141
http_version = socket.get_http_version()

0 commit comments

Comments
 (0)