1- import asyncio
1+ import functools
22from contextlib import (
33 asynccontextmanager , AbstractAsyncContextManager ,
44 AsyncExitStack , nullcontext , aclosing , contextmanager )
88
99from test .test_contextlib import TestBaseExitStack
1010
11- support .requires_working_socket (module = True )
1211
13- def tearDownModule ():
14- asyncio ._set_event_loop_policy (None )
12+ def _run_async_fn (async_fn , / , * args , ** kwargs ):
13+ coro = async_fn (* args , ** kwargs )
14+ try :
15+ coro .send (None )
16+ except StopIteration as e :
17+ return e .value
18+ else :
19+ raise AssertionError ("coroutine did not complete" )
20+ finally :
21+ coro .close ()
1522
1623
17- class TestAbstractAsyncContextManager (unittest .IsolatedAsyncioTestCase ):
24+ def _async_test (async_fn ):
25+ """Decorator to turn an async function into a synchronous function"""
26+ @functools .wraps (async_fn )
27+ def wrapper (* args , ** kwargs ):
28+ return _run_async_fn (async_fn , * args , ** kwargs )
1829
30+ return wrapper
31+
32+
33+ class TestAbstractAsyncContextManager (unittest .TestCase ):
34+
35+ @_async_test
1936 async def test_enter (self ):
2037 class DefaultEnter (AbstractAsyncContextManager ):
2138 async def __aexit__ (self , * args ):
@@ -27,6 +44,7 @@ async def __aexit__(self, *args):
2744 async with manager as context :
2845 self .assertIs (manager , context )
2946
47+ @_async_test
3048 async def test_slots (self ):
3149 class DefaultAsyncContextManager (AbstractAsyncContextManager ):
3250 __slots__ = ()
@@ -38,6 +56,7 @@ async def __aexit__(self, *args):
3856 manager = DefaultAsyncContextManager ()
3957 manager .var = 42
4058
59+ @_async_test
4160 async def test_async_gen_propagates_generator_exit (self ):
4261 # A regression test for https://bugs.python.org/issue33786.
4362
@@ -88,8 +107,9 @@ class NoneAexit(ManagerFromScratch):
88107 self .assertFalse (issubclass (NoneAexit , AbstractAsyncContextManager ))
89108
90109
91- class AsyncContextManagerTestCase (unittest .IsolatedAsyncioTestCase ):
110+ class AsyncContextManagerTestCase (unittest .TestCase ):
92111
112+ @_async_test
93113 async def test_contextmanager_plain (self ):
94114 state = []
95115 @asynccontextmanager
@@ -103,6 +123,7 @@ async def woohoo():
103123 state .append (x )
104124 self .assertEqual (state , [1 , 42 , 999 ])
105125
126+ @_async_test
106127 async def test_contextmanager_finally (self ):
107128 state = []
108129 @asynccontextmanager
@@ -120,6 +141,7 @@ async def woohoo():
120141 raise ZeroDivisionError ()
121142 self .assertEqual (state , [1 , 42 , 999 ])
122143
144+ @_async_test
123145 async def test_contextmanager_traceback (self ):
124146 @asynccontextmanager
125147 async def f ():
@@ -175,6 +197,7 @@ class StopAsyncIterationSubclass(StopAsyncIteration):
175197 self .assertEqual (frames [0 ].name , 'test_contextmanager_traceback' )
176198 self .assertEqual (frames [0 ].line , 'raise stop_exc' )
177199
200+ @_async_test
178201 async def test_contextmanager_no_reraise (self ):
179202 @asynccontextmanager
180203 async def whee ():
@@ -184,6 +207,7 @@ async def whee():
184207 # Calling __aexit__ should not result in an exception
185208 self .assertFalse (await ctx .__aexit__ (TypeError , TypeError ("foo" ), None ))
186209
210+ @_async_test
187211 async def test_contextmanager_trap_yield_after_throw (self ):
188212 @asynccontextmanager
189213 async def whoo ():
@@ -199,6 +223,7 @@ async def whoo():
199223 # The "gen" attribute is an implementation detail.
200224 self .assertFalse (ctx .gen .ag_suspended )
201225
226+ @_async_test
202227 async def test_contextmanager_trap_no_yield (self ):
203228 @asynccontextmanager
204229 async def whoo ():
@@ -208,6 +233,7 @@ async def whoo():
208233 with self .assertRaises (RuntimeError ):
209234 await ctx .__aenter__ ()
210235
236+ @_async_test
211237 async def test_contextmanager_trap_second_yield (self ):
212238 @asynccontextmanager
213239 async def whoo ():
@@ -221,6 +247,7 @@ async def whoo():
221247 # The "gen" attribute is an implementation detail.
222248 self .assertFalse (ctx .gen .ag_suspended )
223249
250+ @_async_test
224251 async def test_contextmanager_non_normalised (self ):
225252 @asynccontextmanager
226253 async def whoo ():
@@ -234,6 +261,7 @@ async def whoo():
234261 with self .assertRaises (SyntaxError ):
235262 await ctx .__aexit__ (RuntimeError , None , None )
236263
264+ @_async_test
237265 async def test_contextmanager_except (self ):
238266 state = []
239267 @asynccontextmanager
@@ -251,6 +279,7 @@ async def woohoo():
251279 raise ZeroDivisionError (999 )
252280 self .assertEqual (state , [1 , 42 , 999 ])
253281
282+ @_async_test
254283 async def test_contextmanager_except_stopiter (self ):
255284 @asynccontextmanager
256285 async def woohoo ():
@@ -277,6 +306,7 @@ class StopAsyncIterationSubclass(StopAsyncIteration):
277306 else :
278307 self .fail (f'{ stop_exc } was suppressed' )
279308
309+ @_async_test
280310 async def test_contextmanager_wrap_runtimeerror (self ):
281311 @asynccontextmanager
282312 async def woohoo ():
@@ -321,12 +351,14 @@ def test_contextmanager_doc_attrib(self):
321351 self .assertEqual (baz .__doc__ , "Whee!" )
322352
323353 @support .requires_docstrings
354+ @_async_test
324355 async def test_instance_docstring_given_cm_docstring (self ):
325356 baz = self ._create_contextmanager_attribs ()(None )
326357 self .assertEqual (baz .__doc__ , "Whee!" )
327358 async with baz :
328359 pass # suppress warning
329360
361+ @_async_test
330362 async def test_keywords (self ):
331363 # Ensure no keyword arguments are inhibited
332364 @asynccontextmanager
@@ -335,6 +367,7 @@ async def woohoo(self, func, args, kwds):
335367 async with woohoo (self = 11 , func = 22 , args = 33 , kwds = 44 ) as target :
336368 self .assertEqual (target , (11 , 22 , 33 , 44 ))
337369
370+ @_async_test
338371 async def test_recursive (self ):
339372 depth = 0
340373 ncols = 0
@@ -361,6 +394,7 @@ async def recursive():
361394 self .assertEqual (ncols , 10 )
362395 self .assertEqual (depth , 0 )
363396
397+ @_async_test
364398 async def test_decorator (self ):
365399 entered = False
366400
@@ -379,6 +413,7 @@ async def test():
379413 await test ()
380414 self .assertFalse (entered )
381415
416+ @_async_test
382417 async def test_decorator_with_exception (self ):
383418 entered = False
384419
@@ -401,6 +436,7 @@ async def test():
401436 await test ()
402437 self .assertFalse (entered )
403438
439+ @_async_test
404440 async def test_decorating_method (self ):
405441
406442 @asynccontextmanager
@@ -435,14 +471,15 @@ async def method(self, a, b, c=None):
435471 self .assertEqual (test .b , 2 )
436472
437473
438- class AclosingTestCase (unittest .IsolatedAsyncioTestCase ):
474+ class AclosingTestCase (unittest .TestCase ):
439475
440476 @support .requires_docstrings
441477 def test_instance_docs (self ):
442478 cm_docstring = aclosing .__doc__
443479 obj = aclosing (None )
444480 self .assertEqual (obj .__doc__ , cm_docstring )
445481
482+ @_async_test
446483 async def test_aclosing (self ):
447484 state = []
448485 class C :
@@ -454,6 +491,7 @@ async def aclose(self):
454491 self .assertEqual (x , y )
455492 self .assertEqual (state , [1 ])
456493
494+ @_async_test
457495 async def test_aclosing_error (self ):
458496 state = []
459497 class C :
@@ -467,6 +505,7 @@ async def aclose(self):
467505 1 / 0
468506 self .assertEqual (state , [1 ])
469507
508+ @_async_test
470509 async def test_aclosing_bpo41229 (self ):
471510 state = []
472511
@@ -492,45 +531,27 @@ async def agenfunc():
492531 self .assertEqual (state , [1 ])
493532
494533
495- class TestAsyncExitStack (TestBaseExitStack , unittest .IsolatedAsyncioTestCase ):
534+ class TestAsyncExitStack (TestBaseExitStack , unittest .TestCase ):
496535 class SyncAsyncExitStack (AsyncExitStack ):
497- @staticmethod
498- def run_coroutine (coro ):
499- loop = asyncio .new_event_loop ()
500- t = loop .create_task (coro )
501- t .add_done_callback (lambda f : loop .stop ())
502- loop .run_forever ()
503-
504- exc = t .exception ()
505- if not exc :
506- return t .result ()
507- else :
508- context = exc .__context__
509-
510- try :
511- raise exc
512- except :
513- exc .__context__ = context
514- raise exc
515536
516537 def close (self ):
517- return self . run_coroutine (self .aclose () )
538+ return _run_async_fn (self .aclose )
518539
519540 def __enter__ (self ):
520- return self . run_coroutine (self .__aenter__ () )
541+ return _run_async_fn (self .__aenter__ )
521542
522543 def __exit__ (self , * exc_details ):
523- return self . run_coroutine (self .__aexit__ ( * exc_details ) )
544+ return _run_async_fn (self .__aexit__ , * exc_details )
524545
525546 exit_stack = SyncAsyncExitStack
526547 callback_error_internal_frames = [
527- ('__exit__' , 'return self.run_coroutine(self.__aexit__(*exc_details))' ),
528- ('run_coroutine' , 'raise exc' ),
529- ('run_coroutine' , 'raise exc' ),
548+ ('__exit__' , 'return _run_async_fn(self.__aexit__, *exc_details)' ),
549+ ('_run_async_fn' , 'coro.send(None)' ),
530550 ('__aexit__' , 'raise exc' ),
531551 ('__aexit__' , 'cb_suppress = cb(*exc_details)' ),
532552 ]
533553
554+ @_async_test
534555 async def test_async_callback (self ):
535556 expected = [
536557 ((), {}),
@@ -573,6 +594,7 @@ async def _exit(*args, **kwds):
573594 stack .push_async_callback (callback = _exit , arg = 3 )
574595 self .assertEqual (result , [])
575596
597+ @_async_test
576598 async def test_async_push (self ):
577599 exc_raised = ZeroDivisionError
578600 async def _expect_exc (exc_type , exc , exc_tb ):
@@ -608,6 +630,7 @@ async def __aexit__(self, *exc_details):
608630 self .assertIs (stack ._exit_callbacks [- 1 ][1 ], _expect_exc )
609631 1 / 0
610632
633+ @_async_test
611634 async def test_enter_async_context (self ):
612635 class TestCM (object ):
613636 async def __aenter__ (self ):
@@ -629,6 +652,7 @@ async def _exit():
629652
630653 self .assertEqual (result , [1 , 2 , 3 , 4 ])
631654
655+ @_async_test
632656 async def test_enter_async_context_errors (self ):
633657 class LacksEnterAndExit :
634658 pass
@@ -648,6 +672,7 @@ async def __aenter__(self):
648672 await stack .enter_async_context (LacksExit ())
649673 self .assertFalse (stack ._exit_callbacks )
650674
675+ @_async_test
651676 async def test_async_exit_exception_chaining (self ):
652677 # Ensure exception chaining matches the reference behaviour
653678 async def raise_exc (exc ):
@@ -679,6 +704,7 @@ async def suppress_exc(*exc_details):
679704 self .assertIsInstance (inner_exc , ValueError )
680705 self .assertIsInstance (inner_exc .__context__ , ZeroDivisionError )
681706
707+ @_async_test
682708 async def test_async_exit_exception_explicit_none_context (self ):
683709 # Ensure AsyncExitStack chaining matches actual nested `with` statements
684710 # regarding explicit __context__ = None.
@@ -713,6 +739,7 @@ async def my_cm_with_exit_stack():
713739 else :
714740 self .fail ("Expected IndexError, but no exception was raised" )
715741
742+ @_async_test
716743 async def test_instance_bypass_async (self ):
717744 class Example (object ): pass
718745 cm = Example ()
@@ -725,7 +752,8 @@ class Example(object): pass
725752 self .assertIs (stack ._exit_callbacks [- 1 ][1 ], cm )
726753
727754
728- class TestAsyncNullcontext (unittest .IsolatedAsyncioTestCase ):
755+ class TestAsyncNullcontext (unittest .TestCase ):
756+ @_async_test
729757 async def test_async_nullcontext (self ):
730758 class C :
731759 pass
0 commit comments