|  | 
|  | 1 | +import asyncio | 
|  | 2 | +import unittest | 
|  | 3 | +from threading import Thread | 
|  | 4 | +from unittest import TestCase | 
|  | 5 | + | 
|  | 6 | +from test.support import threading_helper | 
|  | 7 | + | 
|  | 8 | +threading_helper.requires_working_threading(module=True) | 
|  | 9 | + | 
|  | 10 | +def tearDownModule(): | 
|  | 11 | +    asyncio._set_event_loop_policy(None) | 
|  | 12 | + | 
|  | 13 | + | 
|  | 14 | +class TestFreeThreading: | 
|  | 15 | +    def test_all_tasks_race(self) -> None: | 
|  | 16 | +        async def main(): | 
|  | 17 | +            loop = asyncio.get_running_loop() | 
|  | 18 | +            future = loop.create_future() | 
|  | 19 | + | 
|  | 20 | +            async def coro(): | 
|  | 21 | +                await future | 
|  | 22 | + | 
|  | 23 | +            tasks = set() | 
|  | 24 | + | 
|  | 25 | +            async with asyncio.TaskGroup() as tg: | 
|  | 26 | +                for _ in range(100): | 
|  | 27 | +                    tasks.add(tg.create_task(coro())) | 
|  | 28 | + | 
|  | 29 | +                all_tasks = self.all_tasks(loop) | 
|  | 30 | +                self.assertEqual(len(all_tasks), 101) | 
|  | 31 | + | 
|  | 32 | +                for task in all_tasks: | 
|  | 33 | +                    self.assertEqual(task.get_loop(), loop) | 
|  | 34 | +                    self.assertFalse(task.done()) | 
|  | 35 | + | 
|  | 36 | +                current = self.current_task() | 
|  | 37 | +                self.assertEqual(current.get_loop(), loop) | 
|  | 38 | +                self.assertSetEqual(all_tasks, tasks | {current}) | 
|  | 39 | +                future.set_result(None) | 
|  | 40 | + | 
|  | 41 | +        def runner(): | 
|  | 42 | +            with asyncio.Runner() as runner: | 
|  | 43 | +                loop = runner.get_loop() | 
|  | 44 | +                loop.set_task_factory(self.factory) | 
|  | 45 | +                runner.run(main()) | 
|  | 46 | + | 
|  | 47 | +        threads = [] | 
|  | 48 | + | 
|  | 49 | +        for _ in range(10): | 
|  | 50 | +            thread = Thread(target=runner) | 
|  | 51 | +            threads.append(thread) | 
|  | 52 | + | 
|  | 53 | +        with threading_helper.start_threads(threads): | 
|  | 54 | +            pass | 
|  | 55 | + | 
|  | 56 | + | 
|  | 57 | +class TestPyFreeThreading(TestFreeThreading, TestCase): | 
|  | 58 | +    all_tasks = staticmethod(asyncio.tasks._py_all_tasks) | 
|  | 59 | +    current_task = staticmethod(asyncio.tasks._py_current_task) | 
|  | 60 | + | 
|  | 61 | +    def factory(self, loop, coro, context=None): | 
|  | 62 | +        return asyncio.tasks._PyTask(coro, loop=loop, context=context) | 
|  | 63 | + | 
|  | 64 | + | 
|  | 65 | +@unittest.skipUnless(hasattr(asyncio.tasks, "_c_all_tasks"), "requires _asyncio") | 
|  | 66 | +class TestCFreeThreading(TestFreeThreading, TestCase): | 
|  | 67 | +    all_tasks = staticmethod(getattr(asyncio.tasks, "_c_all_tasks", None)) | 
|  | 68 | +    current_task = staticmethod(getattr(asyncio.tasks, "_c_current_task", None)) | 
|  | 69 | + | 
|  | 70 | +    def factory(self, loop, coro, context=None): | 
|  | 71 | +        return asyncio.tasks._CTask(coro, loop=loop, context=context) | 
|  | 72 | + | 
|  | 73 | + | 
|  | 74 | +class TestEagerPyFreeThreading(TestPyFreeThreading): | 
|  | 75 | +    def factory(self, loop, coro, context=None): | 
|  | 76 | +        return asyncio.tasks._PyTask(coro, loop=loop, context=context, eager_start=True) | 
|  | 77 | + | 
|  | 78 | + | 
|  | 79 | +@unittest.skipUnless(hasattr(asyncio.tasks, "_c_all_tasks"), "requires _asyncio") | 
|  | 80 | +class TestEagerCFreeThreading(TestCFreeThreading, TestCase): | 
|  | 81 | +    def factory(self, loop, coro, context=None): | 
|  | 82 | +        return asyncio.tasks._CTask(coro, loop=loop, context=context, eager_start=True) | 
0 commit comments