|
26 | 26 |
|
27 | 27 | import asyncio |
28 | 28 | import copy |
| 29 | +import functools |
29 | 30 | import inspect |
30 | 31 | from collections.abc import Callable, Coroutine, Generator |
31 | 32 | from types import MappingProxyType, UnionType |
32 | | -from typing import TYPE_CHECKING, Any, Concatenate, Generic, ParamSpec, TypeAlias, TypeVar, Union, Unpack, overload |
| 33 | +from typing import TYPE_CHECKING, Any, Concatenate, Generic, Literal, ParamSpec, TypeAlias, TypeVar, Union, Unpack, overload |
33 | 34 |
|
| 35 | +import twitchio |
34 | 36 | from twitchio.utils import MISSING |
35 | 37 |
|
36 | 38 | from .exceptions import * |
|
71 | 73 | VT = TypeVar("VT") |
72 | 74 |
|
73 | 75 |
|
| 76 | +def unwrap_function(function: Callable[..., Any], /) -> Callable[..., Any]: |
| 77 | + partial = functools.partial |
| 78 | + |
| 79 | + while True: |
| 80 | + if hasattr(function, "__wrapped__"): |
| 81 | + function = function.__wrapped__ # type: ignore |
| 82 | + elif isinstance(function, partial): |
| 83 | + function = function.func |
| 84 | + else: |
| 85 | + return function |
| 86 | + |
| 87 | + |
| 88 | +def get_signature_parameters( |
| 89 | + function: Callable[..., Any], |
| 90 | + globalns: dict[str, Any], |
| 91 | + /, |
| 92 | + *, |
| 93 | + skip_parameters: int | None = None, |
| 94 | +) -> dict[str, inspect.Parameter]: |
| 95 | + signature = inspect.Signature.from_callable(function) |
| 96 | + params: dict[str, inspect.Parameter] = {} |
| 97 | + |
| 98 | + cache: dict[str, Any] = {} |
| 99 | + eval_annotation = twitchio.utils.evaluate_annotation |
| 100 | + required_params = twitchio.utils.is_inside_class(function) + 1 if skip_parameters is None else skip_parameters |
| 101 | + |
| 102 | + if len(signature.parameters) < required_params: |
| 103 | + raise TypeError(f"Command signature requires at least {required_params - 1} parameter(s)") |
| 104 | + |
| 105 | + iterator = iter(signature.parameters.items()) |
| 106 | + for _ in range(0, required_params): |
| 107 | + next(iterator) |
| 108 | + |
| 109 | + for name, parameter in iterator: |
| 110 | + annotation = parameter.annotation |
| 111 | + |
| 112 | + if annotation is None: |
| 113 | + params[name] = parameter.replace(annotation=type(None)) |
| 114 | + continue |
| 115 | + |
| 116 | + annotation = eval_annotation(annotation, globalns, globalns, cache) |
| 117 | + params[name] = parameter.replace(annotation=annotation) |
| 118 | + |
| 119 | + return params |
| 120 | + |
| 121 | + |
74 | 122 | class CommandErrorPayload: |
75 | 123 | """Payload received in the :func:`~twitchio.event_command_error` event. |
76 | 124 |
|
@@ -107,7 +155,7 @@ def __init__( |
107 | 155 | **kwargs: Unpack[CommandOptions], |
108 | 156 | ) -> None: |
109 | 157 | self._name: str = name |
110 | | - self._callback = callback |
| 158 | + self.callback = callback |
111 | 159 | self._aliases: list[str] = kwargs.get("aliases", []) |
112 | 160 | self._guards: list[Callable[..., bool] | Callable[..., CoroC]] = getattr(self._callback, "__command_guards__", []) |
113 | 161 |
|
@@ -178,24 +226,78 @@ def callback(self) -> Callable[Concatenate[Component_T, Context, P], Coro] | Cal |
178 | 226 | """ |
179 | 227 | return self._callback |
180 | 228 |
|
| 229 | + @callback.setter |
| 230 | + def callback( |
| 231 | + self, func: Callable[Concatenate[Component_T, Context, P], Coro] | Callable[Concatenate[Context, P], Coro] |
| 232 | + ) -> None: |
| 233 | + self._callback = func |
| 234 | + unwrap = unwrap_function(func) |
| 235 | + self.module: str = unwrap.__module__ |
| 236 | + |
| 237 | + try: |
| 238 | + globalns = unwrap.__globals__ |
| 239 | + except AttributeError: |
| 240 | + globalns = {} |
| 241 | + |
| 242 | + self._params: dict[str, inspect.Parameter] = get_signature_parameters(func, globalns) |
| 243 | + |
| 244 | + def _convert_literal_type( |
| 245 | + self, context: Context, param: inspect.Parameter, args: tuple[Any, ...], *, raw: str | None |
| 246 | + ) -> Any: |
| 247 | + name: str = param.name |
| 248 | + result: Any = MISSING |
| 249 | + |
| 250 | + for arg in reversed(args): |
| 251 | + type_: type = type(arg) |
| 252 | + base = context.bot._base_converter._DEFAULTS.get(type_) |
| 253 | + |
| 254 | + if base: |
| 255 | + try: |
| 256 | + result = base(raw) |
| 257 | + except Exception: |
| 258 | + continue |
| 259 | + |
| 260 | + break |
| 261 | + |
| 262 | + if result not in args: |
| 263 | + pretty: str = " | ".join(str(a) for a in args) |
| 264 | + raise BadArgument(f'Failed to convert Literal, expected any [{pretty}], got "{raw}".', name=name, value=raw) |
| 265 | + |
| 266 | + return result |
| 267 | + |
181 | 268 | async def _do_conversion(self, context: Context, param: inspect.Parameter, *, annotation: Any, raw: str | None) -> Any: |
182 | 269 | name: str = param.name |
183 | 270 |
|
184 | 271 | if isinstance(annotation, UnionType) or getattr(annotation, "__origin__", None) is Union: |
185 | 272 | converters = list(annotation.__args__) |
186 | | - converters.remove(type(None)) |
| 273 | + |
| 274 | + try: |
| 275 | + converters.remove(type(None)) |
| 276 | + except ValueError: |
| 277 | + pass |
187 | 278 |
|
188 | 279 | result: Any = MISSING |
189 | 280 |
|
190 | | - for c in converters: |
| 281 | + for c in reversed(converters): |
191 | 282 | try: |
192 | 283 | result = await self._do_conversion(context, param=param, annotation=c, raw=raw) |
193 | 284 | except Exception: |
194 | 285 | continue |
195 | 286 |
|
196 | 287 | if result is MISSING: |
197 | 288 | raise BadArgument( |
198 | | - f'Failed to convert argument "{name}" with any converter from Union: {converters}. No default value was provided.', |
| 289 | + f'Failed to convert argument "{name}" with any converter from Union: {converters}.', |
| 290 | + name=name, |
| 291 | + value=raw, |
| 292 | + ) |
| 293 | + |
| 294 | + return result |
| 295 | + |
| 296 | + if getattr(annotation, "__origin__", None) is Literal: |
| 297 | + result = self._convert_literal_type(context, param, annotation.__args__, raw=raw) |
| 298 | + if result is MISSING: |
| 299 | + raise BadArgument( |
| 300 | + f"Failed to convert Literal, no converter found for types in {annotation.__args__}", |
199 | 301 | name=name, |
200 | 302 | value=raw, |
201 | 303 | ) |
@@ -230,11 +332,7 @@ async def _do_conversion(self, context: Context, param: inspect.Parameter, *, an |
230 | 332 |
|
231 | 333 | async def _parse_arguments(self, context: Context) -> ...: |
232 | 334 | context._view.skip_ws() |
233 | | - signature: inspect.Signature = inspect.signature(self._callback) |
234 | | - |
235 | | - # We expect context always and self with commands in components... |
236 | | - skip: int = 2 if self._injected else 1 |
237 | | - params: list[inspect.Parameter] = list(signature.parameters.values())[skip:] |
| 335 | + params: list[inspect.Parameter] = list(self._params.values()) |
238 | 336 |
|
239 | 337 | args: list[Any] = [] |
240 | 338 | kwargs = {} |
|
0 commit comments