22from typing import Optional , Tuple , cast
33
44from .._backends .auto import AsyncBackend , AsyncLock , AsyncSocketStream , AutoBackend
5+ from .._exceptions import ConnectError , ConnectTimeout
56from .._types import URL , Headers , Origin , TimeoutDict
6- from .._utils import get_logger , url_to_origin
7+ from .._utils import exponential_backoff , get_logger , url_to_origin
78from .base import (
89 AsyncByteStream ,
910 AsyncHTTPTransport ,
1415
1516logger = get_logger (__name__ )
1617
18+ RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc.
19+
1720
1821class AsyncHTTPConnection (AsyncHTTPTransport ):
1922 def __init__ (
@@ -24,6 +27,7 @@ def __init__(
2427 ssl_context : SSLContext = None ,
2528 socket : AsyncSocketStream = None ,
2629 local_address : str = None ,
30+ retries : int = 0 ,
2731 backend : AsyncBackend = None ,
2832 ):
2933 self .origin = origin
@@ -32,6 +36,7 @@ def __init__(
3236 self .ssl_context = SSLContext () if ssl_context is None else ssl_context
3337 self .socket = socket
3438 self .local_address = local_address
39+ self .retries = retries
3540
3641 if self .http2 :
3742 self .ssl_context .set_alpn_protocols (["http/1.1" , "h2" ])
@@ -103,22 +108,34 @@ async def _open_socket(self, timeout: TimeoutDict = None) -> AsyncSocketStream:
103108 scheme , hostname , port = self .origin
104109 timeout = {} if timeout is None else timeout
105110 ssl_context = self .ssl_context if scheme == b"https" else None
106- try :
107- if self .uds is None :
108- return await self .backend .open_tcp_stream (
109- hostname ,
110- port ,
111- ssl_context ,
112- timeout ,
113- local_address = self .local_address ,
114- )
115- else :
116- return await self .backend .open_uds_stream (
117- self .uds , hostname , ssl_context , timeout
118- )
119- except Exception : # noqa: PIE786
120- self .connect_failed = True
121- raise
111+
112+ retries_left = self .retries
113+ delays = exponential_backoff (factor = RETRIES_BACKOFF_FACTOR )
114+
115+ while True :
116+ try :
117+ if self .uds is None :
118+ return await self .backend .open_tcp_stream (
119+ hostname ,
120+ port ,
121+ ssl_context ,
122+ timeout ,
123+ local_address = self .local_address ,
124+ )
125+ else :
126+ return await self .backend .open_uds_stream (
127+ self .uds , hostname , ssl_context , timeout
128+ )
129+ except (ConnectError , ConnectTimeout ):
130+ if retries_left <= 0 :
131+ self .connect_failed = True
132+ raise
133+ retries_left -= 1
134+ delay = next (delays )
135+ await self .backend .sleep (delay )
136+ except Exception : # noqa: PIE786
137+ self .connect_failed = True
138+ raise
122139
123140 def _create_connection (self , socket : AsyncSocketStream ) -> None :
124141 http_version = socket .get_http_version ()
0 commit comments