Skip to content

Commit 08dfb4f

Browse files
ChrisLoveringNumerlorMarkKoz
committed
Add lock utils
This includes some additional function utils too. Co-authored-by: Numerlor <[email protected]> Co-authored-by: MarkKoz <[email protected]>
1 parent c7b6140 commit 08dfb4f

File tree

3 files changed

+250
-1
lines changed

3 files changed

+250
-1
lines changed

pydis_core/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
error_handling,
1111
function,
1212
interactions,
13+
lock,
1314
logging,
1415
members,
1516
messages,
@@ -47,6 +48,7 @@ def apply_monkey_patches() -> None:
4748
error_handling,
4849
function,
4950
interactions,
51+
lock,
5052
logging,
5153
members,
5254
messages,

pydis_core/utils/function.py

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,113 @@
33
from __future__ import annotations
44

55
import functools
6+
import inspect
67
import types
78
import typing
89
from collections.abc import Callable, Sequence, Set
910

10-
__all__ = ["GlobalNameConflictError", "command_wraps", "update_wrapper_globals"]
11+
__all__ = [
12+
"GlobalNameConflictError",
13+
"command_wraps",
14+
"get_arg_value",
15+
"get_arg_value_wrapper",
16+
"get_bound_args",
17+
"update_wrapper_globals",
18+
]
1119

1220

1321
if typing.TYPE_CHECKING:
1422
_P = typing.ParamSpec("_P")
1523
_R = typing.TypeVar("_R")
1624

25+
Argument = int | str
26+
BoundArgs = typing.OrderedDict[str, typing.Any]
27+
Decorator = typing.Callable[[typing.Callable], typing.Callable]
28+
ArgValGetter = typing.Callable[[BoundArgs], typing.Any]
29+
1730

1831
class GlobalNameConflictError(Exception):
1932
"""Raised on a conflict between the globals used to resolve annotations of a wrapped function and its wrapper."""
2033

2134

35+
def get_arg_value(name_or_pos: Argument, arguments: BoundArgs) -> typing.Any:
36+
"""
37+
Return a value from `arguments` based on a name or position.
38+
39+
Arguments:
40+
arguments: An ordered mapping of parameter names to argument values.
41+
Returns:
42+
Value from `arguments` based on a name or position.
43+
Raises:
44+
TypeError: `name_or_pos` isn't a str or int.
45+
ValueError: `name_or_pos` does not match any argument.
46+
"""
47+
if isinstance(name_or_pos, int):
48+
# Convert arguments to a tuple to make them indexable.
49+
arg_values = tuple(arguments.items())
50+
arg_pos = name_or_pos
51+
52+
try:
53+
_name, value = arg_values[arg_pos]
54+
return value
55+
except IndexError:
56+
raise ValueError(f"Argument position {arg_pos} is out of bounds.")
57+
elif isinstance(name_or_pos, str):
58+
arg_name = name_or_pos
59+
try:
60+
return arguments[arg_name]
61+
except KeyError:
62+
raise ValueError(f"Argument {arg_name!r} doesn't exist.")
63+
else:
64+
raise TypeError("'arg' must either be an int (positional index) or a str (keyword).")
65+
66+
67+
def get_arg_value_wrapper(
68+
decorator_func: typing.Callable[[ArgValGetter], Decorator],
69+
name_or_pos: Argument,
70+
func: typing.Callable[[typing.Any], typing.Any] | None = None,
71+
) -> Decorator:
72+
"""
73+
Call `decorator_func` with the value of the arg at the given name/position.
74+
75+
Arguments:
76+
decorator_func: A function that must accept a callable as a parameter to which it will pass a mapping of
77+
parameter names to argument values of the function it's decorating.
78+
name_or_pos: The name/position of the arg to get the value from.
79+
func: An optional callable which will return a new value given the argument's value.
80+
81+
Returns:
82+
The decorator returned by `decorator_func`.
83+
"""
84+
def wrapper(args: BoundArgs) -> typing.Any:
85+
value = get_arg_value(name_or_pos, args)
86+
if func:
87+
value = func(value)
88+
return value
89+
90+
return decorator_func(wrapper)
91+
92+
93+
def get_bound_args(func: typing.Callable, args: tuple, kwargs: dict[str, typing.Any]) -> BoundArgs:
94+
"""
95+
Bind `args` and `kwargs` to `func` and return a mapping of parameter names to argument values.
96+
97+
Default parameter values are also set.
98+
99+
Args:
100+
args: The arguments to bind to ``func``
101+
kwargs: The keyword arguments to bind to ``func``
102+
func: The function to bind ``args`` and ``kwargs`` to
103+
Returns:
104+
A mapping of parameter names to argument values.
105+
"""
106+
sig = inspect.signature(func)
107+
bound_args = sig.bind(*args, **kwargs)
108+
bound_args.apply_defaults()
109+
110+
return bound_args.arguments
111+
112+
22113
def update_wrapper_globals(
23114
wrapper: Callable[_P, _R],
24115
wrapped: Callable[_P, _R],

pydis_core/utils/lock.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
import asyncio
2+
import inspect
3+
import types
4+
from collections import defaultdict
5+
from collections.abc import Awaitable, Callable, Hashable
6+
from functools import partial
7+
from typing import Any
8+
from weakref import WeakValueDictionary
9+
10+
from pydis_core.utils import function
11+
from pydis_core.utils.function import command_wraps
12+
from pydis_core.utils.logging import get_logger
13+
14+
log = get_logger(__name__)
15+
__lock_dicts = defaultdict(WeakValueDictionary)
16+
17+
_IdCallableReturn = Hashable | Awaitable[Hashable]
18+
_IdCallable = Callable[[function.BoundArgs], _IdCallableReturn]
19+
ResourceId = Hashable | _IdCallable
20+
21+
22+
class LockedResourceError(RuntimeError):
23+
"""
24+
Exception raised when an operation is attempted on a locked resource.
25+
26+
Attributes:
27+
type (str): Name of the locked resource's type
28+
id (typing.Hashable): ID of the locked resource
29+
"""
30+
31+
def __init__(self, resource_type: str, resource_id: Hashable):
32+
self.type = resource_type
33+
self.id = resource_id
34+
35+
super().__init__(
36+
f"Cannot operate on {self.type.lower()} `{self.id}`; "
37+
"it is currently locked and in use by another operation."
38+
)
39+
40+
41+
class SharedEvent:
42+
"""
43+
Context manager managing an internal event exposed through the wait coro.
44+
45+
While any code is executing in this context manager, the underlying event will not be set;
46+
when all of the holders finish the event will be set.
47+
"""
48+
49+
def __init__(self):
50+
self._active_count = 0
51+
self._event = asyncio.Event()
52+
self._event.set()
53+
54+
def __enter__(self):
55+
"""Increment the count of the active holders and clear the internal event."""
56+
self._active_count += 1
57+
self._event.clear()
58+
59+
def __exit__(self, _exc_type, _exc_val, _exc_tb): # noqa: ANN001
60+
"""Decrement the count of the active holders; if 0 is reached set the internal event."""
61+
self._active_count -= 1
62+
if not self._active_count:
63+
self._event.set()
64+
65+
async def wait(self) -> None:
66+
"""Wait for all active holders to exit."""
67+
await self._event.wait()
68+
69+
70+
def lock(
71+
namespace: Hashable,
72+
resource_id: ResourceId,
73+
*,
74+
raise_error: bool = False,
75+
wait: bool = False,
76+
) -> Callable:
77+
"""
78+
Turn the decorated coroutine function into a mutually exclusive operation on a `resource_id`.
79+
80+
If decorating a command, this decorator must go before (below) the `command` decorator.
81+
82+
Arguments:
83+
namespace (typing.Hashable): An identifier used to prevent collisions among resource IDs.
84+
resource_id: identifies a resource on which to perform a mutually exclusive operation.
85+
It may also be a callable or awaitable which will return the resource ID given an ordered
86+
mapping of the parameters' names to arguments' values.
87+
raise_error (bool): If True, raise `LockedResourceError` if the lock cannot be acquired.
88+
wait (bool): If True, wait until the lock becomes available. Otherwise, if any other mutually
89+
exclusive function currently holds the lock for a resource, do not run the decorated function
90+
and return None.
91+
92+
Raises:
93+
:exc:`LockedResourceError`: If the lock can't be acquired and `raise_error` is set to True.
94+
"""
95+
def decorator(func: types.FunctionType) -> types.FunctionType:
96+
name = func.__name__
97+
98+
@command_wraps(func)
99+
async def wrapper(*args, **kwargs) -> Any:
100+
log.trace(f"{name}: mutually exclusive decorator called")
101+
102+
if callable(resource_id):
103+
log.trace(f"{name}: binding args to signature")
104+
bound_args = function.get_bound_args(func, args, kwargs)
105+
106+
log.trace(f"{name}: calling the given callable to get the resource ID")
107+
id_ = resource_id(bound_args)
108+
109+
if inspect.isawaitable(id_):
110+
log.trace(f"{name}: awaiting to get resource ID")
111+
id_ = await id_
112+
else:
113+
id_ = resource_id
114+
115+
log.trace(f"{name}: getting the lock object for resource {namespace!r}:{id_!r}")
116+
117+
# Get the lock for the ID. Create a lock if one doesn't exist yet.
118+
locks = __lock_dicts[namespace]
119+
lock_ = locks.setdefault(id_, asyncio.Lock())
120+
121+
# It's safe to check an asyncio.Lock is free before acquiring it because:
122+
# 1. Synchronous code like `if not lock_.locked()` does not yield execution
123+
# 2. `asyncio.Lock.acquire()` does not internally await anything if the lock is free
124+
# 3. awaits only yield execution to the event loop at actual I/O boundaries
125+
if wait or not lock_.locked():
126+
log.debug(f"{name}: acquiring lock for resource {namespace!r}:{id_!r}...")
127+
async with lock_:
128+
return await func(*args, **kwargs)
129+
else:
130+
log.info(f"{name}: aborted because resource {namespace!r}:{id_!r} is locked")
131+
if raise_error:
132+
raise LockedResourceError(str(namespace), id_)
133+
return None
134+
135+
return wrapper
136+
return decorator
137+
138+
139+
def lock_arg(
140+
namespace: Hashable,
141+
name_or_pos: function.Argument,
142+
func: Callable[[Any], _IdCallableReturn] | None = None,
143+
*,
144+
raise_error: bool = False,
145+
wait: bool = False,
146+
) -> Callable:
147+
"""
148+
Apply the `lock` decorator using the value of the arg at the given name/position as the ID.
149+
150+
See `lock` docs for more information.
151+
152+
Arguments:
153+
func: An optional callable or awaitable which will return the ID given the argument value.
154+
"""
155+
decorator_func = partial(lock, namespace, raise_error=raise_error, wait=wait)
156+
return function.get_arg_value_wrapper(decorator_func, name_or_pos, func)

0 commit comments

Comments
 (0)