Skip to content

Commit 7eb779b

Browse files
committed
handle strict_exception_groups=True by unwrapping user exceptions from within exceptiongroups
revert making close_connection CS shielded, as that would be a behaviour change causing very long stalls with the default timeout of 60s add comment for pylint disable move RaisesGroup import
1 parent 8f04a5c commit 7eb779b

File tree

2 files changed

+89
-10
lines changed

2 files changed

+89
-10
lines changed

tests/test_connection.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@
3232
from __future__ import annotations
3333

3434
from functools import partial, wraps
35+
import re
3536
import ssl
37+
import sys
3638
from unittest.mock import patch
3739

3840
import attr
@@ -48,6 +50,13 @@
4850
except ImportError:
4951
from trio.hazmat import current_task # type: ignore # pylint: disable=ungrouped-imports
5052

53+
54+
# only available on trio>=0.25, we don't use it when testing lower versions
55+
try:
56+
from trio.testing import RaisesGroup
57+
except ImportError:
58+
pass
59+
5160
from trio_websocket import (
5261
connect_websocket,
5362
connect_websocket_url,
@@ -66,6 +75,9 @@
6675
wrap_server_stream
6776
)
6877

78+
if sys.version_info < (3, 11):
79+
from exceptiongroup import BaseExceptionGroup # pylint: disable=redefined-builtin
80+
6981
WS_PROTO_VERSION = tuple(map(int, wsproto.__version__.split('.')))
7082

7183
HOST = '127.0.0.1'
@@ -427,6 +439,9 @@ async def handler(request):
427439
assert header_key == b'x-test-header'
428440
assert header_value == b'My test header'
429441

442+
def _trio_default_loose() -> bool:
443+
assert re.match(r'^0\.\d\d\.', trio.__version__), "unexpected trio versioning scheme"
444+
return int(trio.__version__[2:4]) < 25
430445

431446
@fail_after(1)
432447
async def test_handshake_exception_before_accept() -> None:
@@ -436,14 +451,28 @@ async def test_handshake_exception_before_accept() -> None:
436451
async def handler(request):
437452
raise ValueError()
438453

439-
with pytest.raises(ValueError):
454+
# pylint fails to resolve that BaseExceptionGroup will always be available
455+
with pytest.raises((BaseExceptionGroup, ValueError)) as exc: # pylint: disable=possibly-used-before-assignment
440456
async with trio.open_nursery() as nursery:
441457
server = await nursery.start(serve_websocket, handler, HOST, 0,
442458
None)
443459
async with open_websocket(HOST, server.port, RESOURCE,
444460
use_ssl=False):
445461
pass
446462

463+
if _trio_default_loose():
464+
assert isinstance(exc.value, ValueError)
465+
else:
466+
# there's 4 levels of nurseries opened, leading to 4 nested groups:
467+
# 1. this test
468+
# 2. WebSocketServer.run
469+
# 3. trio.serve_listeners
470+
# 4. WebSocketServer._handle_connection
471+
assert RaisesGroup(
472+
RaisesGroup(
473+
RaisesGroup(
474+
RaisesGroup(ValueError)))).matches(exc.value)
475+
447476

448477
@fail_after(1)
449478
async def test_reject_handshake(nursery):

trio_websocket/_impl.py

Lines changed: 59 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import urllib.parse
1414
from typing import Iterable, List, Optional, Union
1515

16+
import outcome
1617
import trio
1718
import trio.abc
1819
from wsproto import ConnectionType, WSConnection
@@ -44,6 +45,10 @@
4445
logger = logging.getLogger('trio-websocket')
4546

4647

48+
class TrioWebsocketInternalError(Exception):
49+
...
50+
51+
4752
def _ignore_cancel(exc):
4853
return None if isinstance(exc, trio.Cancelled) else exc
4954

@@ -125,10 +130,10 @@ async def open_websocket(
125130
client-side timeout (:exc:`ConnectionTimeout`, :exc:`DisconnectionTimeout`),
126131
or server rejection (:exc:`ConnectionRejected`) during handshakes.
127132
'''
128-
async with trio.open_nursery() as new_nursery:
133+
async def open_connection(nursery: trio.Nursery) -> WebSocketConnection:
129134
try:
130135
with trio.fail_after(connect_timeout):
131-
connection = await connect_websocket(new_nursery, host, port,
136+
return await connect_websocket(nursery, host, port,
132137
resource, use_ssl=use_ssl, subprotocols=subprotocols,
133138
extra_headers=extra_headers,
134139
message_queue_size=message_queue_size,
@@ -137,14 +142,59 @@ async def open_websocket(
137142
raise ConnectionTimeout from None
138143
except OSError as e:
139144
raise HandshakeError from e
145+
146+
async def close_connection(connection: WebSocketConnection) -> None:
140147
try:
141-
yield connection
142-
finally:
143-
try:
144-
with trio.fail_after(disconnect_timeout):
145-
await connection.aclose()
146-
except trio.TooSlowError:
147-
raise DisconnectionTimeout from None
148+
with trio.fail_after(disconnect_timeout):
149+
await connection.aclose()
150+
except trio.TooSlowError:
151+
raise DisconnectionTimeout from None
152+
153+
connection: WebSocketConnection|None=None
154+
result2: outcome.Maybe[None] | None = None
155+
user_error = None
156+
157+
try:
158+
async with trio.open_nursery() as new_nursery:
159+
result = await outcome.acapture(open_connection, new_nursery)
160+
161+
if isinstance(result, outcome.Value):
162+
connection = result.unwrap()
163+
try:
164+
yield connection
165+
except BaseException as e:
166+
user_error = e
167+
raise
168+
finally:
169+
result2 = await outcome.acapture(close_connection, connection)
170+
# This exception handler should only be entered if:
171+
# 1. The _reader_task started in connect_websocket raises
172+
# 2. User code raises an exception
173+
except BaseExceptionGroup as e:
174+
# user_error, or exception bubbling up from _reader_task
175+
if len(e.exceptions) == 1:
176+
raise e.exceptions[0]
177+
# if the group contains two exceptions, one being Cancelled, and the other
178+
# is user_error => drop Cancelled and raise user_error
179+
# This Cancelled should only have been able to come from _reader_task
180+
if (
181+
len(e.exceptions) == 2
182+
and user_error is not None
183+
and user_error in e.exceptions
184+
and any(isinstance(exc, trio.Cancelled) for exc in e.exceptions)
185+
):
186+
raise user_error # pylint: disable=raise-missing-from,raising-bad-type
187+
raise TrioWebsocketInternalError from e # pragma: no cover
188+
## TODO: handle keyboardinterrupt?
189+
190+
finally:
191+
if result2 is not None:
192+
result2.unwrap()
193+
194+
195+
# error setting up, unwrap that exception
196+
if connection is None:
197+
result.unwrap()
148198

149199

150200
async def connect_websocket(nursery, host, port, resource, *, use_ssl,

0 commit comments

Comments
 (0)