11# Adapted with permission from the EdgeDB project;
22# license: PSFL.
33
4+ import weakref
5+ import sys
46import gc
57import asyncio
68import contextvars
@@ -27,7 +29,25 @@ def get_error_types(eg):
2729 return {type (exc ) for exc in eg .exceptions }
2830
2931
30- class TestTaskGroup (unittest .IsolatedAsyncioTestCase ):
32+ def set_gc_state (enabled ):
33+ was_enabled = gc .isenabled ()
34+ if enabled :
35+ gc .enable ()
36+ else :
37+ gc .disable ()
38+ return was_enabled
39+
40+
41+ @contextlib .contextmanager
42+ def disable_gc ():
43+ was_enabled = set_gc_state (enabled = False )
44+ try :
45+ yield
46+ finally :
47+ set_gc_state (enabled = was_enabled )
48+
49+
50+ class BaseTestTaskGroup :
3151
3252 async def test_taskgroup_01 (self ):
3353
@@ -880,6 +900,30 @@ async def coro_fn():
880900 self .assertIsInstance (exc , _Done )
881901 self .assertListEqual (gc .get_referrers (exc ), [])
882902
903+
904+ async def test_exception_refcycles_parent_task_wr (self ):
905+ """Test that TaskGroup deletes self._parent_task and create_task() deletes task"""
906+ tg = asyncio .TaskGroup ()
907+ exc = None
908+
909+ class _Done (Exception ):
910+ pass
911+
912+ async def coro_fn ():
913+ async with tg :
914+ raise _Done
915+
916+ with disable_gc ():
917+ try :
918+ async with asyncio .TaskGroup () as tg2 :
919+ task_wr = weakref .ref (tg2 .create_task (coro_fn ()))
920+ except* _Done as excs :
921+ exc = excs .exceptions [0 ].exceptions [0 ]
922+
923+ self .assertIsNone (task_wr ())
924+ self .assertIsInstance (exc , _Done )
925+ self .assertListEqual (gc .get_referrers (exc ), [])
926+
883927 async def test_exception_refcycles_propagate_cancellation_error (self ):
884928 """Test that TaskGroup deletes propagate_cancellation_error"""
885929 tg = asyncio .TaskGroup ()
@@ -912,6 +956,81 @@ class MyKeyboardInterrupt(KeyboardInterrupt):
912956 self .assertIsNotNone (exc )
913957 self .assertListEqual (gc .get_referrers (exc ), [])
914958
959+ async def test_cancels_task_if_created_during_creation (self ):
960+ # regression test for gh-128550
961+ ran = False
962+ class MyError (Exception ):
963+ pass
964+
965+ exc = None
966+ try :
967+ async with asyncio .TaskGroup () as tg :
968+ async def third_task ():
969+ raise MyError ("third task failed" )
970+
971+ async def second_task ():
972+ nonlocal ran
973+ tg .create_task (third_task ())
974+ with self .assertRaises (asyncio .CancelledError ):
975+ await asyncio .sleep (0 ) # eager tasks cancel here
976+ await asyncio .sleep (0 ) # lazy tasks cancel here
977+ ran = True
978+
979+ tg .create_task (second_task ())
980+ except* MyError as excs :
981+ exc = excs .exceptions [0 ]
982+
983+ self .assertTrue (ran )
984+ self .assertIsInstance (exc , MyError )
985+
986+ async def test_cancellation_does_not_leak_out_of_tg (self ):
987+ class MyError (Exception ):
988+ pass
989+
990+ async def throw_error ():
991+ raise MyError
992+
993+ try :
994+ async with asyncio .TaskGroup () as tg :
995+ tg .create_task (throw_error ())
996+ except* MyError :
997+ pass
998+ else :
999+ self .fail ("should have raised one MyError in group" )
1000+
1001+ # if this test fails this current task will be cancelled
1002+ # outside the task group and inside unittest internals
1003+ # we yield to the event loop with sleep(0) so that
1004+ # cancellation happens here and error is more understandable
1005+ await asyncio .sleep (0 )
1006+
1007+
1008+ if sys .platform == "win32" :
1009+ EventLoop = asyncio .ProactorEventLoop
1010+ else :
1011+ EventLoop = asyncio .SelectorEventLoop
1012+
1013+
1014+ class IsolatedAsyncioTestCase (unittest .IsolatedAsyncioTestCase ):
1015+ loop_factory = None
1016+
1017+ def _setupAsyncioRunner (self ):
1018+ assert self ._asyncioRunner is None , 'asyncio runner is already initialized'
1019+ runner = asyncio .Runner (debug = True , loop_factory = self .loop_factory )
1020+ self ._asyncioRunner = runner
1021+
1022+
1023+ class TestTaskGroup (BaseTestTaskGroup , IsolatedAsyncioTestCase ):
1024+ loop_factory = EventLoop
1025+
1026+
1027+ class TestEagerTaskTaskGroup (BaseTestTaskGroup , IsolatedAsyncioTestCase ):
1028+ @staticmethod
1029+ def loop_factory ():
1030+ loop = EventLoop ()
1031+ loop .set_task_factory (asyncio .eager_task_factory )
1032+ return loop
1033+
9151034
9161035if __name__ == "__main__" :
9171036 unittest .main ()
0 commit comments