11# Adapted with permission from the EdgeDB project;
22# license: PSFL.
33
4+ import weakref
5+ import sys
46import gc
57import asyncio
68import contextvars
@@ -27,7 +29,25 @@ def get_error_types(eg):
2729 return {type (exc ) for exc in eg .exceptions }
2830
2931
30- class TestTaskGroup (unittest .IsolatedAsyncioTestCase ):
32+ def set_gc_state (enabled ):
33+ was_enabled = gc .isenabled ()
34+ if enabled :
35+ gc .enable ()
36+ else :
37+ gc .disable ()
38+ return was_enabled
39+
40+
41+ @contextlib .contextmanager
42+ def disable_gc ():
43+ was_enabled = set_gc_state (enabled = False )
44+ try :
45+ yield
46+ finally :
47+ set_gc_state (enabled = was_enabled )
48+
49+
50+ class BaseTestTaskGroup :
3151
3252 async def test_taskgroup_01 (self ):
3353
@@ -820,8 +840,82 @@ async def test_taskgroup_without_parent_task(self):
820840 coro = asyncio .sleep (0 )
821841 with self .assertRaisesRegex (RuntimeError , "has not been entered" ):
822842 tg .create_task (coro )
823- # We still have to await coro to avoid a warning
824- await coro
843+
844+ async def test_coro_closed_when_tg_closed (self ):
845+ async def run_coro_after_tg_closes ():
846+ async with taskgroups .TaskGroup () as tg :
847+ pass
848+ coro = asyncio .sleep (0 )
849+ with self .assertRaisesRegex (RuntimeError , "is finished" ):
850+ tg .create_task (coro )
851+
852+ await run_coro_after_tg_closes ()
853+
854+ async def test_cancelling_level_preserved (self ):
855+ async def raise_after (t , e ):
856+ await asyncio .sleep (t )
857+ raise e ()
858+
859+ try :
860+ async with asyncio .TaskGroup () as tg :
861+ tg .create_task (raise_after (0.0 , RuntimeError ))
862+ except* RuntimeError :
863+ pass
864+ self .assertEqual (asyncio .current_task ().cancelling (), 0 )
865+
866+ async def test_nested_groups_both_cancelled (self ):
867+ async def raise_after (t , e ):
868+ await asyncio .sleep (t )
869+ raise e ()
870+
871+ try :
872+ async with asyncio .TaskGroup () as outer_tg :
873+ try :
874+ async with asyncio .TaskGroup () as inner_tg :
875+ inner_tg .create_task (raise_after (0 , RuntimeError ))
876+ outer_tg .create_task (raise_after (0 , ValueError ))
877+ except* RuntimeError :
878+ pass
879+ else :
880+ self .fail ("RuntimeError not raised" )
881+ self .assertEqual (asyncio .current_task ().cancelling (), 1 )
882+ except* ValueError :
883+ pass
884+ else :
885+ self .fail ("ValueError not raised" )
886+ self .assertEqual (asyncio .current_task ().cancelling (), 0 )
887+
888+ async def test_error_and_cancel (self ):
889+ event = asyncio .Event ()
890+
891+ async def raise_error ():
892+ event .set ()
893+ await asyncio .sleep (0 )
894+ raise RuntimeError ()
895+
896+ async def inner ():
897+ try :
898+ async with taskgroups .TaskGroup () as tg :
899+ tg .create_task (raise_error ())
900+ await asyncio .sleep (1 )
901+ self .fail ("Sleep in group should have been cancelled" )
902+ except* RuntimeError :
903+ self .assertEqual (asyncio .current_task ().cancelling (), 1 )
904+ self .assertEqual (asyncio .current_task ().cancelling (), 1 )
905+ await asyncio .sleep (1 )
906+ self .fail ("Sleep after group should have been cancelled" )
907+
908+ async def outer ():
909+ t = asyncio .create_task (inner ())
910+ await event .wait ()
911+ self .assertEqual (t .cancelling (), 0 )
912+ t .cancel ()
913+ self .assertEqual (t .cancelling (), 1 )
914+ with self .assertRaises (asyncio .CancelledError ):
915+ await t
916+ self .assertTrue (t .cancelled ())
917+
918+ await outer ()
825919
826920 async def test_exception_refcycles_direct (self ):
827921 """Test that TaskGroup doesn't keep a reference to the raised ExceptionGroup"""
@@ -880,6 +974,30 @@ async def coro_fn():
880974 self .assertIsInstance (exc , _Done )
881975 self .assertListEqual (gc .get_referrers (exc ), [])
882976
977+
978+ async def test_exception_refcycles_parent_task_wr (self ):
979+ """Test that TaskGroup deletes self._parent_task and create_task() deletes task"""
980+ tg = asyncio .TaskGroup ()
981+ exc = None
982+
983+ class _Done (Exception ):
984+ pass
985+
986+ async def coro_fn ():
987+ async with tg :
988+ raise _Done
989+
990+ with disable_gc ():
991+ try :
992+ async with asyncio .TaskGroup () as tg2 :
993+ task_wr = weakref .ref (tg2 .create_task (coro_fn ()))
994+ except* _Done as excs :
995+ exc = excs .exceptions [0 ].exceptions [0 ]
996+
997+ self .assertIsNone (task_wr ())
998+ self .assertIsInstance (exc , _Done )
999+ self .assertListEqual (gc .get_referrers (exc ), [])
1000+
8831001 async def test_exception_refcycles_propagate_cancellation_error (self ):
8841002 """Test that TaskGroup deletes propagate_cancellation_error"""
8851003 tg = asyncio .TaskGroup ()
@@ -912,6 +1030,32 @@ class MyKeyboardInterrupt(KeyboardInterrupt):
9121030 self .assertIsNotNone (exc )
9131031 self .assertListEqual (gc .get_referrers (exc ), [])
9141032
1033+ if sys .platform == "win32" :
1034+ EventLoop = asyncio .ProactorEventLoop
1035+ else :
1036+ EventLoop = asyncio .SelectorEventLoop
1037+
1038+
1039+ class IsolatedAsyncioTestCase (unittest .IsolatedAsyncioTestCase ):
1040+ loop_factory = None
1041+
1042+ def _setupAsyncioRunner (self ):
1043+ assert self ._asyncioRunner is None , 'asyncio runner is already initialized'
1044+ runner = asyncio .Runner (debug = True , loop_factory = self .loop_factory )
1045+ self ._asyncioRunner = runner
1046+
1047+
1048+ class TestTaskGroup (BaseTestTaskGroup , IsolatedAsyncioTestCase ):
1049+ loop_factory = EventLoop
1050+
1051+
1052+ class TestEagerTaskTaskGroup (BaseTestTaskGroup , IsolatedAsyncioTestCase ):
1053+ @staticmethod
1054+ def loop_factory ():
1055+ loop = EventLoop ()
1056+ loop .set_task_factory (asyncio .eager_task_factory )
1057+ return loop
1058+
9151059
9161060if __name__ == "__main__" :
9171061 unittest .main ()
0 commit comments