Skip to content

Commit 84ff9df

Browse files
committed
:refactor: reorganize utility function imports and move evaluate_annotation to private module
1 parent 0c6270a commit 84ff9df

File tree

4 files changed

+109
-110
lines changed

4 files changed

+109
-110
lines changed

discord/ext/commands/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
)
4444

4545
import discord
46+
from discord.utils.private import evaluate_annotation
4647
from discord import utils
4748
from discord.utils import Undefined
4849

@@ -138,7 +139,6 @@ def get_signature_parameters(function: Callable[..., Any], globalns: dict[str, A
138139
signature = inspect.signature(function)
139140
params = {}
140141
cache: dict[str, Any] = {}
141-
eval_annotation = discord.utils.evaluate_annotation
142142
for name, parameter in signature.parameters.items():
143143
annotation = parameter.annotation
144144
if annotation is parameter.empty:
@@ -148,7 +148,7 @@ def get_signature_parameters(function: Callable[..., Any], globalns: dict[str, A
148148
params[name] = parameter.replace(annotation=type(None))
149149
continue
150150

151-
annotation = eval_annotation(annotation, globalns, globalns, cache)
151+
annotation = evaluate_annotation(annotation, globalns, globalns, cache)
152152
if annotation is Greedy:
153153
raise TypeError("Unparameterized Greedy[...] is disallowed in signature.")
154154

discord/ext/commands/flags.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@
3232
from typing import TYPE_CHECKING, Any, Iterator, Literal, Pattern, TypeVar, Union
3333

3434
from discord import utils
35-
from discord.utils import MISSING, Undefined, maybe_coroutine, resolve_annotation
35+
from discord.utils import MISSING, Undefined, maybe_coroutine
36+
from ...utils.private import resolve_annotation
3637

3738
from .converter import run_converters
3839
from .errors import (

discord/utils/__init__.py

Lines changed: 6 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,7 @@
3131
import datetime
3232
import json
3333
import re
34-
import sys
35-
import types
3634
from bisect import bisect_left
37-
from enum import Enum, auto
3835
from inspect import isawaitable as _isawaitable
3936
from inspect import signature as _signature
4037
from operator import attrgetter
@@ -44,7 +41,6 @@
4441
AsyncIterator,
4542
Callable,
4643
Coroutine,
47-
ForwardRef,
4844
Generic,
4945
Iterable,
5046
Iterator,
@@ -111,12 +107,12 @@ def __get__(self, instance, owner):
111107
if TYPE_CHECKING:
112108
from typing_extensions import ParamSpec
113109

114-
from .abc import Snowflake
115-
from .commands.context import AutocompleteContext
116-
from .commands.options import OptionChoice
117-
from .invite import Invite
118-
from .permissions import Permissions
119-
from .template import Template
110+
from ..abc import Snowflake
111+
from ..commands.context import AutocompleteContext
112+
from ..commands.options import OptionChoice
113+
from ..invite import Invite
114+
from ..permissions import Permissions
115+
from ..template import Template
120116

121117
class _RequestLike(Protocol):
122118
headers: Mapping[str, Any]
@@ -728,102 +724,6 @@ def as_chunks(iterator: _Iter[T], max_size: int) -> _Iter[list[T]]:
728724
return _chunk(iterator, max_size)
729725

730726

731-
PY_310 = sys.version_info >= (3, 10)
732-
733-
734-
def flatten_literal_params(parameters: Iterable[Any]) -> tuple[Any, ...]:
735-
params = []
736-
literal_cls = type(Literal[0])
737-
for p in parameters:
738-
if isinstance(p, literal_cls):
739-
params.extend(p.__args__)
740-
else:
741-
params.append(p)
742-
return tuple(params)
743-
744-
745-
def normalise_optional_params(parameters: Iterable[Any]) -> tuple[Any, ...]:
746-
none_cls = type(None)
747-
return tuple(p for p in parameters if p is not none_cls) + (none_cls,)
748-
749-
750-
def evaluate_annotation(
751-
tp: Any,
752-
globals: dict[str, Any],
753-
locals: dict[str, Any],
754-
cache: dict[str, Any],
755-
*,
756-
implicit_str: bool = True,
757-
):
758-
if isinstance(tp, ForwardRef):
759-
tp = tp.__forward_arg__
760-
# ForwardRefs always evaluate their internals
761-
implicit_str = True
762-
763-
if implicit_str and isinstance(tp, str):
764-
if tp in cache:
765-
return cache[tp]
766-
evaluated = eval(tp, globals, locals)
767-
cache[tp] = evaluated
768-
return evaluate_annotation(evaluated, globals, locals, cache)
769-
770-
if hasattr(tp, "__args__"):
771-
implicit_str = True
772-
is_literal = False
773-
args = tp.__args__
774-
if not hasattr(tp, "__origin__"):
775-
if PY_310 and tp.__class__ is types.UnionType: # type: ignore
776-
converted = Union[args] # type: ignore
777-
return evaluate_annotation(converted, globals, locals, cache)
778-
779-
return tp
780-
if tp.__origin__ is Union:
781-
try:
782-
if args.index(type(None)) != len(args) - 1:
783-
args = normalise_optional_params(tp.__args__)
784-
except ValueError:
785-
pass
786-
if tp.__origin__ is Literal:
787-
if not PY_310:
788-
args = flatten_literal_params(tp.__args__)
789-
implicit_str = False
790-
is_literal = True
791-
792-
evaluated_args = tuple(
793-
evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) for arg in args
794-
)
795-
796-
if is_literal and not all(isinstance(x, (str, int, bool, type(None))) for x in evaluated_args):
797-
raise TypeError("Literal arguments must be of type str, int, bool, or NoneType.")
798-
799-
if evaluated_args == args:
800-
return tp
801-
802-
try:
803-
return tp.copy_with(evaluated_args)
804-
except AttributeError:
805-
return tp.__origin__[evaluated_args]
806-
807-
return tp
808-
809-
810-
def resolve_annotation(
811-
annotation: Any,
812-
globalns: dict[str, Any],
813-
localns: dict[str, Any] | None,
814-
cache: dict[str, Any] | None,
815-
) -> Any:
816-
if annotation is None:
817-
return type(None)
818-
if isinstance(annotation, str):
819-
annotation = ForwardRef(annotation)
820-
821-
locals = globalns if localns is None else localns
822-
if cache is None:
823-
cache = {}
824-
return evaluate_annotation(annotation, globalns, locals, cache)
825-
826-
827727
TimestampStyle = Literal["f", "F", "d", "D", "t", "T", "R"]
828728

829729

discord/utils/private.py

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
import datetime
44
import functools
55
import re
6+
import sys
7+
import types
68
import unicodedata
79
import warnings
810
from base64 import b64encode
9-
from typing import TYPE_CHECKING, Any, overload, Callable, TypeVar, ParamSpec
11+
from typing import TYPE_CHECKING, Any, overload, Callable, TypeVar, ParamSpec, Iterable, Literal, ForwardRef, Union
1012

1113
from ..errors import InvalidArgument
1214

@@ -260,3 +262,99 @@ def decorated(*args: P.args, **kwargs: P.kwargs) -> T:
260262
return decorated
261263

262264
return actual_decorator
265+
266+
267+
PY_310 = sys.version_info >= (3, 10)
268+
269+
270+
def flatten_literal_params(parameters: Iterable[Any]) -> tuple[Any, ...]:
271+
params = []
272+
literal_cls = type(Literal[0])
273+
for p in parameters:
274+
if isinstance(p, literal_cls):
275+
params.extend(p.__args__)
276+
else:
277+
params.append(p)
278+
return tuple(params)
279+
280+
281+
def normalise_optional_params(parameters: Iterable[Any]) -> tuple[Any, ...]:
282+
none_cls = type(None)
283+
return tuple(p for p in parameters if p is not none_cls) + (none_cls,)
284+
285+
286+
def evaluate_annotation(
287+
tp: Any,
288+
globals: dict[str, Any],
289+
locals: dict[str, Any],
290+
cache: dict[str, Any],
291+
*,
292+
implicit_str: bool = True,
293+
):
294+
if isinstance(tp, ForwardRef):
295+
tp = tp.__forward_arg__
296+
# ForwardRefs always evaluate their internals
297+
implicit_str = True
298+
299+
if implicit_str and isinstance(tp, str):
300+
if tp in cache:
301+
return cache[tp]
302+
evaluated = eval(tp, globals, locals)
303+
cache[tp] = evaluated
304+
return evaluate_annotation(evaluated, globals, locals, cache)
305+
306+
if hasattr(tp, "__args__"):
307+
implicit_str = True
308+
is_literal = False
309+
args = tp.__args__
310+
if not hasattr(tp, "__origin__"):
311+
if PY_310 and tp.__class__ is types.UnionType: # type: ignore
312+
converted = Union[args] # type: ignore
313+
return evaluate_annotation(converted, globals, locals, cache)
314+
315+
return tp
316+
if tp.__origin__ is Union:
317+
try:
318+
if args.index(type(None)) != len(args) - 1:
319+
args = normalise_optional_params(tp.__args__)
320+
except ValueError:
321+
pass
322+
if tp.__origin__ is Literal:
323+
if not PY_310:
324+
args = flatten_literal_params(tp.__args__)
325+
implicit_str = False
326+
is_literal = True
327+
328+
evaluated_args = tuple(
329+
evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) for arg in args
330+
)
331+
332+
if is_literal and not all(isinstance(x, (str, int, bool, type(None))) for x in evaluated_args):
333+
raise TypeError("Literal arguments must be of type str, int, bool, or NoneType.")
334+
335+
if evaluated_args == args:
336+
return tp
337+
338+
try:
339+
return tp.copy_with(evaluated_args)
340+
except AttributeError:
341+
return tp.__origin__[evaluated_args]
342+
343+
return tp
344+
345+
346+
def resolve_annotation(
347+
annotation: Any,
348+
globalns: dict[str, Any],
349+
localns: dict[str, Any] | None,
350+
cache: dict[str, Any] | None,
351+
) -> Any:
352+
if annotation is None:
353+
return type(None)
354+
if isinstance(annotation, str):
355+
annotation = ForwardRef(annotation)
356+
357+
locals = globalns if localns is None else localns
358+
if cache is None:
359+
cache = {}
360+
return evaluate_annotation(annotation, globalns, locals, cache)

0 commit comments

Comments
 (0)