1515from functools import partial
1616import sys
1717
18- from ._typing import Protocol , AsyncContextManager , ContextManager , T
18+ from ._typing import Protocol , AsyncContextManager , ContextManager , T , C
1919from ._core import awaitify
2020from ._utility import public_module , slot_get as _slot_get
2121
2222
23+ AnyContextManager = Union [AsyncContextManager [T ], ContextManager [T ]]
24+
25+
2326# typing.AsyncContextManager uses contextlib.AbstractAsyncContextManager if available,
2427# and a custom implementation otherwise. No need to replicate it.
2528AbstractContextManager = AsyncContextManager
2629
2730
2831class ACloseable (Protocol ):
29- async def aclose (self ):
32+ async def aclose (self ) -> None :
3033 """Asynchronously close this object"""
3134
3235
@@ -58,29 +61,31 @@ async def Context(*args, **kwargs):
5861 """
5962
6063 @wraps (func )
61- def helper (* args , ** kwds ) :
64+ def helper (* args : Any , ** kwds : Any ) -> AsyncContextManager [ T ] :
6265 return _AsyncGeneratorContextManager (func , args , kwds )
6366
6467 return helper
6568
6669
67- class _AsyncGeneratorContextManager :
68- def __init__ (self , func , args , kwds ):
70+ class _AsyncGeneratorContextManager (Generic [T ]):
71+ def __init__ (
72+ self , func : Callable [..., AsyncGenerator [T , None ]], args : Any , kwds : Any
73+ ):
6974 self .gen = func (* args , ** kwds )
7075 self .__doc__ = getattr (func , "__doc__" , type (self ).__doc__ )
7176
72- async def __aenter__ (self ):
77+ async def __aenter__ (self ) -> T :
7378 try :
7479 return await self .gen .__anext__ ()
7580 except StopAsyncIteration :
7681 raise RuntimeError ("generator did not yield to __aenter__" ) from None
7782
78- async def __aexit__ (self , exc_type , exc_val , exc_tb ) :
83+ async def __aexit__ (self , exc_type : Any , exc_val : Any , exc_tb : Any ) -> bool :
7984 if exc_type is None :
8085 try :
8186 await self .gen .__anext__ ()
8287 except StopAsyncIteration :
83- return
88+ return False
8489 else :
8590 raise RuntimeError ("generator did not stop after __aexit__" )
8691 else :
@@ -99,6 +104,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
99104 except exc_type as exc :
100105 if exc is not exc_val :
101106 raise
107+ return False
102108 else :
103109 raise RuntimeError ("generator did not stop after throw() in __aexit__" )
104110
@@ -134,8 +140,9 @@ def __init__(self, thing: AC):
134140 async def __aenter__ (self ) -> AC :
135141 return self .thing
136142
137- async def __aexit__ (self , exc_type , exc_val , exc_tb ) :
143+ async def __aexit__ (self , exc_type : Any , exc_val : Any , exc_tb : Any ) -> bool :
138144 await self .thing .aclose ()
145+ return False
139146
140147
141148closing = Closing
@@ -175,7 +182,7 @@ def __init__(self: "NullContext[None]", enter_result: None = ...) -> None:
175182 def __init__ (self : "NullContext[T]" , enter_result : T ) -> None :
176183 ...
177184
178- def __init__ (self , enter_result = None ):
185+ def __init__ (self , enter_result : Optional [ T ] = None ):
179186 self .enter_result = enter_result
180187
181188 @overload
@@ -186,11 +193,11 @@ async def __aenter__(self: "NullContext[None]") -> None:
186193 async def __aenter__ (self : "NullContext[T]" ) -> T :
187194 ...
188195
189- async def __aenter__ (self ):
196+ async def __aenter__ (self ) -> Optional [ T ] :
190197 return self .enter_result
191198
192- async def __aexit__ (self , exc_type , exc_val , exc_tb ) :
193- pass
199+ async def __aexit__ (self , exc_type : Any , exc_val : Any , exc_tb : Any ) -> bool :
200+ return False
194201
195202
196203nullcontext = NullContext
@@ -199,8 +206,8 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
199206SE = TypeVar (
200207 "SE" ,
201208 bound = Union [
202- AsyncContextManager ,
203- ContextManager ,
209+ AsyncContextManager [ Any ] ,
210+ ContextManager [ Any ] ,
204211 Callable [[Any , BaseException , Any ], Optional [bool ]],
205212 Callable [[Any , BaseException , Any ], Awaitable [Optional [bool ]]],
206213 ],
@@ -228,11 +235,13 @@ class ExitStack:
228235 There are no separate methods to distinguish async and regular arguments.
229236 """
230237
231- def __init__ (self ):
238+ def __init__ (self ) -> None :
232239 self ._exit_callbacks : Deque [Callable [..., Awaitable [Optional [bool ]]]] = deque ()
233240
234241 @staticmethod
235- async def _aexit_callback (callback , exc_type , exc_val , tb ):
242+ async def _aexit_callback (
243+ callback : Callable [[], Awaitable [Any ]], exc_type : Any , exc_val : Any , tb : Any
244+ ) -> bool :
236245 """Invoke a callback as if it were an ``__aexit__`` method"""
237246 await callback ()
238247 return False # callbacks never suppress exceptions
@@ -298,7 +307,7 @@ def push(self, exit: SE) -> SE:
298307 self ._exit_callbacks .append (aexit )
299308 return exit
300309
301- def callback (self , callback : Callable , * args , ** kwargs ) :
310+ def callback (self , callback : C , * args : Any , ** kwargs : Any ) -> C :
302311 """
303312 Registers an arbitrary callback to be called with arguments on unwinding
304313
@@ -312,11 +321,11 @@ def callback(self, callback: Callable, *args, **kwargs):
312321 This method does not change its argument, and can be used as a context manager.
313322 """
314323 self ._exit_callbacks .append (
315- partial (self ._aexit_callback , awaitify ( partial (callback , * args , ** kwargs ) ))
324+ partial (self ._aexit_callback , partial (awaitify ( callback ) , * args , ** kwargs ))
316325 )
317326 return callback
318327
319- async def enter_context (self , cm : AsyncContextManager ) :
328+ async def enter_context (self , cm : AnyContextManager [ T ]) -> T :
320329 """
321330 Enter the supplied context manager, and register it for exit if successful
322331
@@ -353,9 +362,9 @@ async def enter_context(self, cm: AsyncContextManager):
353362 else :
354363 context_value = await _slot_get (cm , "__aenter__" )()
355364 self ._exit_callbacks .append (aexit )
356- return context_value
365+ return context_value # type: ignore
357366
358- async def aclose (self ):
367+ async def aclose (self ) -> None :
359368 """
360369 Immediately unwind the context stack
361370
@@ -371,7 +380,7 @@ def _stitch_context(
371380 exception : BaseException ,
372381 context : BaseException ,
373382 base_context : Optional [BaseException ],
374- ):
383+ ) -> None :
375384 """
376385 Emulate that `exception` was caused by an unhandled `context`
377386
@@ -392,10 +401,10 @@ def _stitch_context(
392401 # we expect it to reference
393402 exception .__context__ = context
394403
395- async def __aenter__ (self ):
404+ async def __aenter__ (self ) -> "ExitStack" :
396405 return self
397406
398- async def __aexit__ (self , exc_type , exc_val , tb ) :
407+ async def __aexit__ (self , exc_type : Any , exc_val : Any , tb : Any ) -> bool :
399408 received_exc = exc_type is not None
400409 # Even if we don't handle an exception *right now*, we may be part
401410 # of an exception handler unwinding gracefully. This is our __context__.
0 commit comments