1
+ import asyncio
2
+ import functools
1
3
import io
2
4
import json
3
5
import logging
@@ -44,6 +46,7 @@ def __init__(
44
46
auth : Optional [BasicAuth ] = None ,
45
47
ssl : Union [SSLContext , bool , Fingerprint ] = False ,
46
48
timeout : Optional [int ] = None ,
49
+ ssl_close_timeout : Optional [Union [int , float ]] = 10 ,
47
50
client_session_args : Optional [Dict [str , Any ]] = None ,
48
51
) -> None :
49
52
"""Initialize the transport with the given aiohttp parameters.
@@ -53,6 +56,8 @@ def __init__(
53
56
:param cookies: Dict of HTTP cookies.
54
57
:param auth: BasicAuth object to enable Basic HTTP auth if needed
55
58
:param ssl: ssl_context of the connection. Use ssl=False to disable encryption
59
+ :param ssl_close_timeout: Timeout in seconds to wait for the ssl connection
60
+ to close properly
56
61
:param client_session_args: Dict of extra args passed to
57
62
`aiohttp.ClientSession`_
58
63
@@ -65,6 +70,7 @@ def __init__(
65
70
self .auth : Optional [BasicAuth ] = auth
66
71
self .ssl : Union [SSLContext , bool , Fingerprint ] = ssl
67
72
self .timeout : Optional [int ] = timeout
73
+ self .ssl_close_timeout : Optional [Union [int , float ]] = ssl_close_timeout
68
74
self .client_session_args = client_session_args
69
75
self .session : Optional [aiohttp .ClientSession ] = None
70
76
@@ -100,6 +106,59 @@ async def connect(self) -> None:
100
106
else :
101
107
raise TransportAlreadyConnected ("Transport is already connected" )
102
108
109
+ @staticmethod
110
+ def create_aiohttp_closed_event (session ) -> asyncio .Event :
111
+ """Work around aiohttp issue that doesn't properly close transports on exit.
112
+
113
+ See https://github.com/aio-libs/aiohttp/issues/1925#issuecomment-639080209
114
+
115
+ Returns:
116
+ An event that will be set once all transports have been properly closed.
117
+ """
118
+
119
+ ssl_transports = 0
120
+ all_is_lost = asyncio .Event ()
121
+
122
+ def connection_lost (exc , orig_lost ):
123
+ nonlocal ssl_transports
124
+
125
+ try :
126
+ orig_lost (exc )
127
+ finally :
128
+ ssl_transports -= 1
129
+ if ssl_transports == 0 :
130
+ all_is_lost .set ()
131
+
132
+ def eof_received (orig_eof_received ):
133
+ try :
134
+ orig_eof_received ()
135
+ except AttributeError : # pragma: no cover
136
+ # It may happen that eof_received() is called after
137
+ # _app_protocol and _transport are set to None.
138
+ pass
139
+
140
+ for conn in session .connector ._conns .values ():
141
+ for handler , _ in conn :
142
+ proto = getattr (handler .transport , "_ssl_protocol" , None )
143
+ if proto is None :
144
+ continue
145
+
146
+ ssl_transports += 1
147
+ orig_lost = proto .connection_lost
148
+ orig_eof_received = proto .eof_received
149
+
150
+ proto .connection_lost = functools .partial (
151
+ connection_lost , orig_lost = orig_lost
152
+ )
153
+ proto .eof_received = functools .partial (
154
+ eof_received , orig_eof_received = orig_eof_received
155
+ )
156
+
157
+ if ssl_transports == 0 :
158
+ all_is_lost .set ()
159
+
160
+ return all_is_lost
161
+
103
162
async def close (self ) -> None :
104
163
"""Coroutine which will close the aiohttp session.
105
164
@@ -108,7 +167,12 @@ async def close(self) -> None:
108
167
when you exit the async context manager.
109
168
"""
110
169
if self .session is not None :
170
+ closed_event = self .create_aiohttp_closed_event (self .session )
111
171
await self .session .close ()
172
+ try :
173
+ await asyncio .wait_for (closed_event .wait (), self .ssl_close_timeout )
174
+ except asyncio .TimeoutError :
175
+ pass
112
176
self .session = None
113
177
114
178
async def execute (
0 commit comments