Skip to content

Commit 11035ec

Browse files
committed
paillat suggestions, in test
1 parent f6d6206 commit 11035ec

File tree

1 file changed

+60
-30
lines changed

1 file changed

+60
-30
lines changed

discord/utils.py

Lines changed: 60 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,8 @@ def get(iterable: Iterable[T], **attrs: Any) -> T | None:
600600
bound="VoiceChannel | TextChannel | ForumChannel | StageChannel | CategoryChannel | Thread | Member | User | Guild | Role | GuildEmoji | AppEmoji",
601601
)
602602
_D = TypeVar("_D")
603+
_Getter = Callable[[Any, int], Any]
604+
_Fetcher = Callable[[Any, int], Awaitable[Any]]
603605

604606

605607
# TODO: In version 3.0, remove the 'attr' and 'id' arguments.
@@ -681,20 +683,11 @@ async def get_or_fetch(
681683
:exc:`Forbidden`
682684
You do not have permission to fetch the object.
683685
"""
684-
from discord import AppEmoji, Client, Guild, Member, Role, User, abc, emoji
686+
from discord import Client, Guild, Member, Role, User
685687

686688
if object_id is None:
687689
return default if default is not MISSING else None
688690

689-
string_to_type = {
690-
"channel": abc.GuildChannel,
691-
"member": Member,
692-
"user": User,
693-
"guild": Guild,
694-
"emoji": emoji._EmojiTag,
695-
"appemoji": AppEmoji,
696-
"role": Role,
697-
}
698691
# Temporary backward compatibility for 'attr' and 'id'.
699692
# This entire if block should be removed in version 3.0.
700693
if attr is not MISSING or id is not MISSING or isinstance(object_type, str):
@@ -709,7 +702,7 @@ async def get_or_fetch(
709702
deprecated_id = id if id is not MISSING else object_id
710703

711704
if isinstance(deprecated_attr, str):
712-
mapped_type = string_to_type.get(deprecated_attr.lower())
705+
mapped_type = _get_string_to_type_map().get(deprecated_attr.lower())
713706
if mapped_type is None:
714707
raise InvalidArgument(
715708
f"Unknown type string '{deprecated_attr}' used. Please use a valid class like `discord.Member` instead."
@@ -738,7 +731,47 @@ async def get_or_fetch(
738731
elif isinstance(obj, Guild) and object_type is Guild:
739732
raise InvalidArgument("Guild cannot get_or_fetch Guild. Use Client instead.")
740733

741-
getter_fetcher_map = {
734+
try:
735+
getter, fetcher = _get_getter_fetcher_map()[object_type]
736+
except KeyError:
737+
raise InvalidArgument(
738+
f"Class {object_type.__name__} cannot be used with discord.{type(obj).__name__}.get_or_fetch()"
739+
)
740+
741+
result = getter(obj, object_id)
742+
if result is not None:
743+
return result
744+
745+
try:
746+
return await fetcher(obj, object_id)
747+
except (HTTPException, ValueError):
748+
if default is not MISSING:
749+
return default
750+
raise
751+
752+
753+
@functools.lru_cache(maxsize=1)
754+
def _get_string_to_type_map() -> dict[str, type]:
755+
"""Return a cached map of lowercase strings -> discord types."""
756+
from discord import Guild, Member, Role, User, abc, emoji
757+
758+
return {
759+
"channel": abc.GuildChannel,
760+
"member": Member,
761+
"user": User,
762+
"guild": Guild,
763+
"emoji": emoji._EmojiTag,
764+
"appemoji": AppEmoji,
765+
"role": Role,
766+
}
767+
768+
769+
@functools.lru_cache(maxsize=1)
770+
def _get_getter_fetcher_map() -> dict[type, tuple[_Getter, _Fetcher]]:
771+
"""Return a cached map of type names -> (getter, fetcher) functions."""
772+
from discord import Guild, Member, Role, User, abc, emoji
773+
774+
base_map: dict[type, tuple[_Getter, _Fetcher]] = {
742775
Member: (
743776
lambda obj, oid: obj.get_member(oid),
744777
lambda obj, oid: obj.fetch_member(oid),
@@ -764,26 +797,23 @@ async def get_or_fetch(
764797
lambda obj, oid: obj.fetch_channel(oid),
765798
),
766799
}
767-
try:
768-
base_type = next(
769-
base for base in getter_fetcher_map if issubclass(object_type, base)
770-
)
771-
getter, fetcher = getter_fetcher_map[base_type]
772-
except KeyError:
773-
raise InvalidArgument(
774-
f"Class {object_type.__name__} cannot be used with discord.{type(obj).__name__}.get_or_fetch()"
775-
)
776800

777-
result = getter(obj, object_id)
778-
if result is not None:
779-
return result
801+
expanded: dict[type, tuple[_Getter, _Fetcher]] = {}
802+
for base, funcs in base_map.items():
803+
expanded[base] = funcs
804+
for subclass in _all_subclasses(base):
805+
if subclass not in expanded:
806+
expanded[subclass] = funcs
780807

781-
try:
782-
return await fetcher(obj, object_id)
783-
except (HTTPException, ValueError):
784-
if default is not MISSING:
785-
return default
786-
raise
808+
return expanded
809+
810+
811+
def _all_subclasses(cls: type) -> set[type]:
812+
"""Recursively collect all subclasses of a class."""
813+
subs = set(cls.__subclasses__())
814+
for sub in cls.__subclasses__():
815+
subs |= _all_subclasses(sub)
816+
return subs
787817

788818

789819
def _unique(iterable: Iterable[T]) -> list[T]:

0 commit comments

Comments
 (0)