Skip to content

Commit 7344c48

Browse files
add more thread safety tests
1 parent 03ede5a commit 7344c48

File tree

1 file changed

+86
-0
lines changed

1 file changed

+86
-0
lines changed

Lib/test/test_asyncio/test_free_threading.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77

88
threading_helper.requires_working_threading(module=True)
99

10+
class MyException(Exception):
11+
pass
12+
13+
1014
def tearDownModule():
1115
asyncio._set_event_loop_policy(None)
1216

@@ -53,6 +57,88 @@ def runner():
5357
with threading_helper.start_threads(threads):
5458
pass
5559

60+
def test_run_coroutine_threadsafe(self) -> None:
61+
results = []
62+
63+
def in_thread(loop: asyncio.AbstractEventLoop):
64+
coro = asyncio.sleep(0.1, result=42)
65+
fut = asyncio.run_coroutine_threadsafe(coro, loop)
66+
result = fut.result()
67+
self.assertEqual(result, 42)
68+
results.append(result)
69+
70+
async def main():
71+
loop = asyncio.get_running_loop()
72+
async with asyncio.TaskGroup() as tg:
73+
for _ in range(10):
74+
tg.create_task(asyncio.to_thread(in_thread, loop))
75+
self.assertEqual(results, [42] * 10)
76+
77+
with asyncio.Runner() as r:
78+
loop = r.get_loop()
79+
loop.set_task_factory(self.factory)
80+
r.run(main())
81+
82+
def test_run_coroutine_threadsafe_exception_caught(self) -> None:
83+
exc = MyException("test")
84+
85+
async def coro():
86+
await asyncio.sleep(0.1)
87+
raise exc
88+
89+
def in_thread(loop: asyncio.AbstractEventLoop):
90+
fut = asyncio.run_coroutine_threadsafe(coro(), loop)
91+
self.assertEqual(fut.exception(), exc)
92+
return exc
93+
94+
async def main():
95+
loop = asyncio.get_running_loop()
96+
tasks = []
97+
async with asyncio.TaskGroup() as tg:
98+
for _ in range(10):
99+
task = tg.create_task(asyncio.to_thread(in_thread, loop))
100+
tasks.append(task)
101+
for task in tasks:
102+
self.assertEqual(await task, exc)
103+
104+
with asyncio.Runner() as r:
105+
loop = r.get_loop()
106+
loop.set_task_factory(self.factory)
107+
r.run(main())
108+
109+
def test_run_coroutine_threadsafe_exception_uncaught(self) -> None:
110+
async def coro():
111+
await asyncio.sleep(1)
112+
raise MyException("test")
113+
114+
def in_thread(loop: asyncio.AbstractEventLoop):
115+
fut = asyncio.run_coroutine_threadsafe(coro(), loop)
116+
return fut.result()
117+
118+
async def main():
119+
loop = asyncio.get_running_loop()
120+
tasks = []
121+
try:
122+
async with asyncio.TaskGroup() as tg:
123+
for _ in range(10):
124+
task = tg.create_task(asyncio.to_thread(in_thread, loop))
125+
tasks.append(task)
126+
except ExceptionGroup:
127+
for task in tasks:
128+
try:
129+
await task
130+
except (MyException, asyncio.CancelledError):
131+
pass
132+
else:
133+
self.fail("Task should have raised an exception")
134+
else:
135+
self.fail("TaskGroup should have raised an exception")
136+
137+
with asyncio.Runner() as r:
138+
loop = r.get_loop()
139+
loop.set_task_factory(self.factory)
140+
r.run(main())
141+
56142

57143
class TestPyFreeThreading(TestFreeThreading, TestCase):
58144
all_tasks = staticmethod(asyncio.tasks._py_all_tasks)

0 commit comments

Comments
 (0)