1919import contextvars
2020import heapq
2121import threading
22- from collections .abc import Callable
22+ from collections .abc import Callable , Generator
2323from contextlib import contextmanager
2424from typing import Any , NamedTuple
2525
@@ -55,7 +55,8 @@ def __lt__(self, other: Any) -> bool:
5555
5656
5757@contextmanager
58- def priority_context (priority : int ):
58+ def priority_context (priority : int ) -> Generator [None , None , None ]:
59+ """Set the priority for all PrioritySemaphore use in this context."""
5960 token = _priority .set (priority )
6061 try :
6162 yield None
@@ -67,7 +68,8 @@ def priority_context(priority: int):
6768
6869
6970class PrioritySemaphore :
70- """
71+ """A Semaphore with priority-based aquisition ordering.
72+
7173 Provides a semaphore with similar semantics as asyncio.Semaphore,
7274 but using an underlying priority. priority is shared within a context
7375 manager's logical scope, but the context can be nested safely.
@@ -76,11 +78,11 @@ class PrioritySemaphore:
7678
7779 context manager use:
7880
79- sem = PrioritySemaphore(1)
81+ >>> sem = PrioritySemaphore(1)
82+ >>> with priority_ctx(10):
83+ async with sem:
84+ ...
8085
81- with priority_ctx(10):
82- async with sem:
83- ...
8486 """
8587
8688 _loop : asyncio .AbstractEventLoop | None = None
@@ -93,12 +95,14 @@ def _get_loop(self) -> asyncio.AbstractEventLoop:
9395 if self ._loop is None :
9496 self ._loop = loop
9597 if loop is not self ._loop :
96- raise RuntimeError (f"{ self !r} is bound to a different event loop" )
98+ msg = f"{ self !r} is bound to a different event loop"
99+ raise RuntimeError (msg )
97100 return loop
98101
99102 def __init__ (self , value : int = 1 ):
100103 if value < 0 :
101- raise ValueError ("Semaphore initial value must be >= 0" )
104+ msg = "Semaphore initial value must be >= 0"
105+ raise ValueError (msg )
102106 self ._waiters : list [PriorityWaiter ] | None = None
103107 self ._value : int = value
104108
@@ -120,9 +124,8 @@ def locked(self) -> bool:
120124 async def __aenter__ (self ):
121125 prio = _priority .get ()
122126 await self .acquire (prio )
123- return
124127
125- async def __aexit__ (self , * dont_care : Any ):
128+ async def __aexit__ (self , * dont_care : object ):
126129 self .release ()
127130
128131 async def acquire (self , priority : int = _default ) -> bool :
@@ -174,6 +177,6 @@ def _maybe_wake(self) -> None:
174177 heapq .heappush (self ._waiters , next_waiter )
175178 break
176179
177- def release (self ):
180+ def release (self ) -> None :
178181 self ._value += 1
179182 self ._maybe_wake ()
0 commit comments