Skip to content

Commit 8b1c1d7

Browse files
committed
Add utility to wait for several events.
Indeed, trio has no equivalent of asyncio.wait.
1 parent cb673c3 commit 8b1c1d7

File tree

3 files changed

+143
-0
lines changed

3 files changed

+143
-0
lines changed

src/websockets/trio/utils.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import sys
2+
3+
import trio
4+
5+
6+
if sys.version_info[:2] < (3, 11): # pragma: no cover
7+
from exceptiongroup import BaseExceptionGroup
8+
9+
10+
__all__ = ["race_events"]
11+
12+
13+
# Based on https://trio.readthedocs.io/en/stable/reference-core.html#custom-supervisors
14+
15+
16+
async def jockey(event: trio.Event, cancel_scope: trio.CancelScope) -> None:
17+
await event.wait()
18+
cancel_scope.cancel()
19+
20+
21+
async def race_events(*events: trio.Event) -> None:
22+
"""
23+
Wait for any of the given events to be set.
24+
25+
Args:
26+
*events: The events to wait for.
27+
28+
"""
29+
if not events:
30+
raise ValueError("no events provided")
31+
32+
try:
33+
async with trio.open_nursery() as nursery:
34+
for event in events:
35+
nursery.start_soon(jockey, event, nursery.cancel_scope)
36+
except BaseExceptionGroup as exc:
37+
try:
38+
trio._util.raise_single_exception_from_group(exc)
39+
except trio._util.MultipleExceptionError: # pragma: no cover
40+
raise AssertionError(
41+
"race_events should be canceled; please file a bug report"
42+
) from exc

tests/trio/test_utils.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import trio.testing
2+
3+
from websockets.trio.utils import *
4+
5+
from .utils import IsolatedTrioTestCase
6+
7+
8+
class UtilsTests(IsolatedTrioTestCase):
9+
async def test_race_events(self):
10+
event1 = trio.Event()
11+
event2 = trio.Event()
12+
done = trio.Event()
13+
14+
async def waiter():
15+
await race_events(event1, event2)
16+
done.set()
17+
18+
async with trio.open_nursery() as nursery:
19+
nursery.start_soon(waiter)
20+
await trio.testing.wait_all_tasks_blocked()
21+
self.assertFalse(done.is_set())
22+
23+
event1.set()
24+
await trio.testing.wait_all_tasks_blocked()
25+
self.assertTrue(done.is_set())
26+
27+
async def test_race_events_cancelled(self):
28+
event1 = trio.Event()
29+
event2 = trio.Event()
30+
31+
async def waiter():
32+
with trio.move_on_after(0):
33+
await race_events(event1, event2)
34+
35+
async with trio.open_nursery() as nursery:
36+
nursery.start_soon(waiter)
37+
38+
async def test_race_events_no_events(self):
39+
with self.assertRaises(ValueError):
40+
await race_events()

tests/trio/utils.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import functools
2+
import inspect
3+
import sys
4+
import unittest
5+
6+
import trio.testing
7+
8+
9+
if sys.version_info[:2] < (3, 11): # pragma: no cover
10+
from exceptiongroup import BaseExceptionGroup
11+
12+
13+
class IsolatedTrioTestCase(unittest.TestCase):
14+
"""
15+
Wrap test coroutines with :func:`trio.testing.trio_test` automatically.
16+
17+
Create a nursery for each test, available in the :attr:`nursery` attribute.
18+
19+
:meth:`asyncSetUp` and :meth:`asyncTearDown` are supported, similar to
20+
:class:`unittest.IsolatedAsyncioTestCase`, but ``addAsyncCleanup`` isn't.
21+
22+
"""
23+
24+
def __init_subclass__(cls, **kwargs):
25+
super().__init_subclass__(**kwargs)
26+
for name in unittest.defaultTestLoader.getTestCaseNames(cls):
27+
test = getattr(cls, name)
28+
if getattr(test, "converted_to_trio", False):
29+
return
30+
assert inspect.iscoroutinefunction(test)
31+
setattr(cls, name, cls.convert_to_trio(test))
32+
33+
@staticmethod
34+
def convert_to_trio(test):
35+
@trio.testing.trio_test
36+
@functools.wraps(test)
37+
async def new_test(self, *args, **kwargs):
38+
try:
39+
# Provide a nursery so it's easy to start tasks.
40+
async with trio.open_nursery() as self.nursery:
41+
await self.asyncSetUp()
42+
try:
43+
return await test(self, *args, **kwargs)
44+
finally:
45+
await self.asyncTearDown()
46+
except BaseExceptionGroup as exc:
47+
# Unwrap exceptions like unittest.SkipTest. Multiple exceptions
48+
# could occur is a test fails with multiple errors; this is OK.
49+
try:
50+
trio._util.raise_single_exception_from_group(exc)
51+
except trio._util.MultipleExceptionError: # pragma: no cover
52+
raise
53+
54+
new_test.converted_to_trio = True
55+
return new_test
56+
57+
async def asyncSetUp(self):
58+
pass
59+
60+
async def asyncTearDown(self):
61+
pass

0 commit comments

Comments
 (0)