11# Adapted with permission from the EdgeDB project; 
22# license: PSFL. 
33
4+ import  weakref 
5+ import  sys 
46import  gc 
57import  asyncio 
68import  contextvars 
@@ -27,7 +29,25 @@ def get_error_types(eg):
2729    return  {type (exc ) for  exc  in  eg .exceptions }
2830
2931
30- class  TestTaskGroup (unittest .IsolatedAsyncioTestCase ):
32+ def  set_gc_state (enabled ):
33+     was_enabled  =  gc .isenabled ()
34+     if  enabled :
35+         gc .enable ()
36+     else :
37+         gc .disable ()
38+     return  was_enabled 
39+ 
40+ 
41+ @contextlib .contextmanager  
42+ def  disable_gc ():
43+     was_enabled  =  set_gc_state (enabled = False )
44+     try :
45+         yield 
46+     finally :
47+         set_gc_state (enabled = was_enabled )
48+ 
49+ 
50+ class  BaseTestTaskGroup :
3151
3252    async  def  test_taskgroup_01 (self ):
3353
@@ -880,6 +900,30 @@ async def coro_fn():
880900        self .assertIsInstance (exc , _Done )
881901        self .assertListEqual (gc .get_referrers (exc ), [])
882902
903+ 
904+     async  def  test_exception_refcycles_parent_task_wr (self ):
905+         """Test that TaskGroup deletes self._parent_task and create_task() deletes task""" 
906+         tg  =  asyncio .TaskGroup ()
907+         exc  =  None 
908+ 
909+         class  _Done (Exception ):
910+             pass 
911+ 
912+         async  def  coro_fn ():
913+             async  with  tg :
914+                 raise  _Done 
915+ 
916+         with  disable_gc ():
917+             try :
918+                 async  with  asyncio .TaskGroup () as  tg2 :
919+                     task_wr  =  weakref .ref (tg2 .create_task (coro_fn ()))
920+             except* _Done  as  excs :
921+                 exc  =  excs .exceptions [0 ].exceptions [0 ]
922+ 
923+         self .assertIsNone (task_wr ())
924+         self .assertIsInstance (exc , _Done )
925+         self .assertListEqual (gc .get_referrers (exc ), [])
926+ 
883927    async  def  test_exception_refcycles_propagate_cancellation_error (self ):
884928        """Test that TaskGroup deletes propagate_cancellation_error""" 
885929        tg  =  asyncio .TaskGroup ()
@@ -912,6 +956,81 @@ class MyKeyboardInterrupt(KeyboardInterrupt):
912956        self .assertIsNotNone (exc )
913957        self .assertListEqual (gc .get_referrers (exc ), [])
914958
959+     async  def  test_cancels_task_if_created_during_creation (self ):
960+         # regression test for gh-128550 
961+         ran  =  False 
962+         class  MyError (Exception ):
963+             pass 
964+ 
965+         exc  =  None 
966+         try :
967+             async  with  asyncio .TaskGroup () as  tg :
968+                 async  def  third_task ():
969+                     raise  MyError ("third task failed" )
970+ 
971+                 async  def  second_task ():
972+                     nonlocal  ran 
973+                     tg .create_task (third_task ())
974+                     with  self .assertRaises (asyncio .CancelledError ):
975+                         await  asyncio .sleep (0 )  # eager tasks cancel here 
976+                         await  asyncio .sleep (0 )  # lazy tasks cancel here 
977+                     ran  =  True 
978+ 
979+                 tg .create_task (second_task ())
980+         except* MyError  as  excs :
981+             exc  =  excs .exceptions [0 ]
982+ 
983+         self .assertTrue (ran )
984+         self .assertIsInstance (exc , MyError )
985+ 
986+     async  def  test_cancellation_does_not_leak_out_of_tg (self ):
987+         class  MyError (Exception ):
988+             pass 
989+ 
990+         async  def  throw_error ():
991+             raise  MyError 
992+ 
993+         try :
994+             async  with  asyncio .TaskGroup () as  tg :
995+                 tg .create_task (throw_error ())
996+         except* MyError :
997+             pass 
998+         else :
999+             self .fail ("should have raised one MyError in group" )
1000+ 
1001+         # if this test fails this current task will be cancelled 
1002+         # outside the task group and inside unittest internals 
1003+         # we yield to the event loop with sleep(0) so that 
1004+         # cancellation happens here and error is more understandable 
1005+         await  asyncio .sleep (0 )
1006+ 
1007+ 
1008+ if  sys .platform  ==  "win32" :
1009+     EventLoop  =  asyncio .ProactorEventLoop 
1010+ else :
1011+     EventLoop  =  asyncio .SelectorEventLoop 
1012+ 
1013+ 
1014+ class  IsolatedAsyncioTestCase (unittest .IsolatedAsyncioTestCase ):
1015+     loop_factory  =  None 
1016+ 
1017+     def  _setupAsyncioRunner (self ):
1018+         assert  self ._asyncioRunner  is  None , 'asyncio runner is already initialized' 
1019+         runner  =  asyncio .Runner (debug = True , loop_factory = self .loop_factory )
1020+         self ._asyncioRunner  =  runner 
1021+ 
1022+ 
1023+ class  TestTaskGroup (BaseTestTaskGroup , IsolatedAsyncioTestCase ):
1024+     loop_factory  =  EventLoop 
1025+ 
1026+ 
1027+ class  TestEagerTaskTaskGroup (BaseTestTaskGroup , IsolatedAsyncioTestCase ):
1028+     @staticmethod  
1029+     def  loop_factory ():
1030+         loop  =  EventLoop ()
1031+         loop .set_task_factory (asyncio .eager_task_factory )
1032+         return  loop 
1033+ 
9151034
9161035if  __name__  ==  "__main__" :
9171036    unittest .main ()
0 commit comments