Skip to content

Commit 71da029

Browse files
author
Eviee Py
committed
Fix argument parsing when annotations are strings. Fix annotation order
1 parent e787c0f commit 71da029

File tree

1 file changed

+108
-10
lines changed

1 file changed

+108
-10
lines changed

twitchio/ext/commands/core.py

Lines changed: 108 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,13 @@
2626

2727
import asyncio
2828
import copy
29+
import functools
2930
import inspect
3031
from collections.abc import Callable, Coroutine, Generator
3132
from types import MappingProxyType, UnionType
32-
from typing import TYPE_CHECKING, Any, Concatenate, Generic, ParamSpec, TypeAlias, TypeVar, Union, Unpack, overload
33+
from typing import TYPE_CHECKING, Any, Concatenate, Generic, Literal, ParamSpec, TypeAlias, TypeVar, Union, Unpack, overload
3334

35+
import twitchio
3436
from twitchio.utils import MISSING
3537

3638
from .exceptions import *
@@ -71,6 +73,52 @@
7173
VT = TypeVar("VT")
7274

7375

76+
def unwrap_function(function: Callable[..., Any], /) -> Callable[..., Any]:
77+
partial = functools.partial
78+
79+
while True:
80+
if hasattr(function, "__wrapped__"):
81+
function = function.__wrapped__ # type: ignore
82+
elif isinstance(function, partial):
83+
function = function.func
84+
else:
85+
return function
86+
87+
88+
def get_signature_parameters(
89+
function: Callable[..., Any],
90+
globalns: dict[str, Any],
91+
/,
92+
*,
93+
skip_parameters: int | None = None,
94+
) -> dict[str, inspect.Parameter]:
95+
signature = inspect.Signature.from_callable(function)
96+
params: dict[str, inspect.Parameter] = {}
97+
98+
cache: dict[str, Any] = {}
99+
eval_annotation = twitchio.utils.evaluate_annotation
100+
required_params = twitchio.utils.is_inside_class(function) + 1 if skip_parameters is None else skip_parameters
101+
102+
if len(signature.parameters) < required_params:
103+
raise TypeError(f"Command signature requires at least {required_params - 1} parameter(s)")
104+
105+
iterator = iter(signature.parameters.items())
106+
for _ in range(0, required_params):
107+
next(iterator)
108+
109+
for name, parameter in iterator:
110+
annotation = parameter.annotation
111+
112+
if annotation is None:
113+
params[name] = parameter.replace(annotation=type(None))
114+
continue
115+
116+
annotation = eval_annotation(annotation, globalns, globalns, cache)
117+
params[name] = parameter.replace(annotation=annotation)
118+
119+
return params
120+
121+
74122
class CommandErrorPayload:
75123
"""Payload received in the :func:`~twitchio.event_command_error` event.
76124
@@ -107,7 +155,7 @@ def __init__(
107155
**kwargs: Unpack[CommandOptions],
108156
) -> None:
109157
self._name: str = name
110-
self._callback = callback
158+
self.callback = callback
111159
self._aliases: list[str] = kwargs.get("aliases", [])
112160
self._guards: list[Callable[..., bool] | Callable[..., CoroC]] = getattr(self._callback, "__command_guards__", [])
113161

@@ -178,24 +226,78 @@ def callback(self) -> Callable[Concatenate[Component_T, Context, P], Coro] | Cal
178226
"""
179227
return self._callback
180228

229+
@callback.setter
230+
def callback(
231+
self, func: Callable[Concatenate[Component_T, Context, P], Coro] | Callable[Concatenate[Context, P], Coro]
232+
) -> None:
233+
self._callback = func
234+
unwrap = unwrap_function(func)
235+
self.module: str = unwrap.__module__
236+
237+
try:
238+
globalns = unwrap.__globals__
239+
except AttributeError:
240+
globalns = {}
241+
242+
self._params: dict[str, inspect.Parameter] = get_signature_parameters(func, globalns)
243+
244+
def _convert_literal_type(
245+
self, context: Context, param: inspect.Parameter, args: tuple[Any, ...], *, raw: str | None
246+
) -> Any:
247+
name: str = param.name
248+
result: Any = MISSING
249+
250+
for arg in reversed(args):
251+
type_: type = type(arg)
252+
base = context.bot._base_converter._DEFAULTS.get(type_)
253+
254+
if base:
255+
try:
256+
result = base(raw)
257+
except Exception:
258+
continue
259+
260+
break
261+
262+
if result not in args:
263+
pretty: str = " | ".join(str(a) for a in args)
264+
raise BadArgument(f'Failed to convert Literal, expected any [{pretty}], got "{raw}".', name=name, value=raw)
265+
266+
return result
267+
181268
async def _do_conversion(self, context: Context, param: inspect.Parameter, *, annotation: Any, raw: str | None) -> Any:
182269
name: str = param.name
183270

184271
if isinstance(annotation, UnionType) or getattr(annotation, "__origin__", None) is Union:
185272
converters = list(annotation.__args__)
186-
converters.remove(type(None))
273+
274+
try:
275+
converters.remove(type(None))
276+
except ValueError:
277+
pass
187278

188279
result: Any = MISSING
189280

190-
for c in converters:
281+
for c in reversed(converters):
191282
try:
192283
result = await self._do_conversion(context, param=param, annotation=c, raw=raw)
193284
except Exception:
194285
continue
195286

196287
if result is MISSING:
197288
raise BadArgument(
198-
f'Failed to convert argument "{name}" with any converter from Union: {converters}. No default value was provided.',
289+
f'Failed to convert argument "{name}" with any converter from Union: {converters}.',
290+
name=name,
291+
value=raw,
292+
)
293+
294+
return result
295+
296+
if getattr(annotation, "__origin__", None) is Literal:
297+
result = self._convert_literal_type(context, param, annotation.__args__, raw=raw)
298+
if result is MISSING:
299+
raise BadArgument(
300+
f"Failed to convert Literal, no converter found for types in {annotation.__args__}",
199301
name=name,
200302
value=raw,
201303
)
@@ -230,11 +332,7 @@ async def _do_conversion(self, context: Context, param: inspect.Parameter, *, an
230332

231333
async def _parse_arguments(self, context: Context) -> ...:
232334
context._view.skip_ws()
233-
signature: inspect.Signature = inspect.signature(self._callback)
234-
235-
# We expect context always and self with commands in components...
236-
skip: int = 2 if self._injected else 1
237-
params: list[inspect.Parameter] = list(signature.parameters.values())[skip:]
335+
params: list[inspect.Parameter] = list(self._params.values())
238336

239337
args: list[Any] = []
240338
kwargs = {}

0 commit comments

Comments
 (0)