13
13
import urllib .parse
14
14
from typing import Iterable , List , Optional , Union
15
15
16
+ import outcome
16
17
import trio
17
18
import trio .abc
18
19
from wsproto import ConnectionType , WSConnection
44
45
logger = logging .getLogger ('trio-websocket' )
45
46
46
47
48
+ class TrioWebsocketInternalError (Exception ):
49
+ ...
50
+
51
+
47
52
def _ignore_cancel (exc ):
48
53
return None if isinstance (exc , trio .Cancelled ) else exc
49
54
@@ -125,10 +130,10 @@ async def open_websocket(
125
130
client-side timeout (:exc:`ConnectionTimeout`, :exc:`DisconnectionTimeout`),
126
131
or server rejection (:exc:`ConnectionRejected`) during handshakes.
127
132
'''
128
- async with trio .open_nursery () as new_nursery :
133
+ async def open_connection ( nursery : trio .Nursery ) -> WebSocketConnection :
129
134
try :
130
135
with trio .fail_after (connect_timeout ):
131
- connection = await connect_websocket (new_nursery , host , port ,
136
+ return await connect_websocket (nursery , host , port ,
132
137
resource , use_ssl = use_ssl , subprotocols = subprotocols ,
133
138
extra_headers = extra_headers ,
134
139
message_queue_size = message_queue_size ,
@@ -137,14 +142,59 @@ async def open_websocket(
137
142
raise ConnectionTimeout from None
138
143
except OSError as e :
139
144
raise HandshakeError from e
145
+
146
+ async def close_connection (connection : WebSocketConnection ) -> None :
140
147
try :
141
- yield connection
142
- finally :
143
- try :
144
- with trio .fail_after (disconnect_timeout ):
145
- await connection .aclose ()
146
- except trio .TooSlowError :
147
- raise DisconnectionTimeout from None
148
+ with trio .fail_after (disconnect_timeout ):
149
+ await connection .aclose ()
150
+ except trio .TooSlowError :
151
+ raise DisconnectionTimeout from None
152
+
153
+ connection : WebSocketConnection | None = None
154
+ result2 : outcome .Maybe [None ] | None = None
155
+ user_error = None
156
+
157
+ try :
158
+ async with trio .open_nursery () as new_nursery :
159
+ result = await outcome .acapture (open_connection , new_nursery )
160
+
161
+ if isinstance (result , outcome .Value ):
162
+ connection = result .unwrap ()
163
+ try :
164
+ yield connection
165
+ except BaseException as e :
166
+ user_error = e
167
+ raise
168
+ finally :
169
+ result2 = await outcome .acapture (close_connection , connection )
170
+ # This exception handler should only be entered if:
171
+ # 1. The _reader_task started in connect_websocket raises
172
+ # 2. User code raises an exception
173
+ except BaseExceptionGroup as e :
174
+ # user_error, or exception bubbling up from _reader_task
175
+ if len (e .exceptions ) == 1 :
176
+ raise e .exceptions [0 ]
177
+ # if the group contains two exceptions, one being Cancelled, and the other
178
+ # is user_error => drop Cancelled and raise user_error
179
+ # This Cancelled should only have been able to come from _reader_task
180
+ if (
181
+ len (e .exceptions ) == 2
182
+ and user_error is not None
183
+ and user_error in e .exceptions
184
+ and any (isinstance (exc , trio .Cancelled ) for exc in e .exceptions )
185
+ ):
186
+ raise user_error # pylint: disable=raise-missing-from,raising-bad-type
187
+ raise TrioWebsocketInternalError from e # pragma: no cover
188
+ ## TODO: handle keyboardinterrupt?
189
+
190
+ finally :
191
+ if result2 is not None :
192
+ result2 .unwrap ()
193
+
194
+
195
+ # error setting up, unwrap that exception
196
+ if connection is None :
197
+ result .unwrap ()
148
198
149
199
150
200
async def connect_websocket (nursery , host , port , resource , * , use_ssl ,
0 commit comments