@@ -600,6 +600,8 @@ def get(iterable: Iterable[T], **attrs: Any) -> T | None:
600
600
bound = "VoiceChannel | TextChannel | ForumChannel | StageChannel | CategoryChannel | Thread | Member | User | Guild | Role | GuildEmoji | AppEmoji" ,
601
601
)
602
602
_D = TypeVar ("_D" )
603
+ _Getter = Callable [[Any , int ], Any ]
604
+ _Fetcher = Callable [[Any , int ], Awaitable [Any ]]
603
605
604
606
605
607
# TODO: In version 3.0, remove the 'attr' and 'id' arguments.
@@ -681,20 +683,11 @@ async def get_or_fetch(
681
683
:exc:`Forbidden`
682
684
You do not have permission to fetch the object.
683
685
"""
684
- from discord import AppEmoji , Client , Guild , Member , Role , User , abc , emoji
686
+ from discord import Client , Guild , Member , Role , User
685
687
686
688
if object_id is None :
687
689
return default if default is not MISSING else None
688
690
689
- string_to_type = {
690
- "channel" : abc .GuildChannel ,
691
- "member" : Member ,
692
- "user" : User ,
693
- "guild" : Guild ,
694
- "emoji" : emoji ._EmojiTag ,
695
- "appemoji" : AppEmoji ,
696
- "role" : Role ,
697
- }
698
691
# Temporary backward compatibility for 'attr' and 'id'.
699
692
# This entire if block should be removed in version 3.0.
700
693
if attr is not MISSING or id is not MISSING or isinstance (object_type , str ):
@@ -709,7 +702,7 @@ async def get_or_fetch(
709
702
deprecated_id = id if id is not MISSING else object_id
710
703
711
704
if isinstance (deprecated_attr , str ):
712
- mapped_type = string_to_type .get (deprecated_attr .lower ())
705
+ mapped_type = _get_string_to_type_map () .get (deprecated_attr .lower ())
713
706
if mapped_type is None :
714
707
raise InvalidArgument (
715
708
f"Unknown type string '{ deprecated_attr } ' used. Please use a valid class like `discord.Member` instead."
@@ -738,7 +731,47 @@ async def get_or_fetch(
738
731
elif isinstance (obj , Guild ) and object_type is Guild :
739
732
raise InvalidArgument ("Guild cannot get_or_fetch Guild. Use Client instead." )
740
733
741
- getter_fetcher_map = {
734
+ try :
735
+ getter , fetcher = _get_getter_fetcher_map ()[object_type ]
736
+ except KeyError :
737
+ raise InvalidArgument (
738
+ f"Class { object_type .__name__ } cannot be used with discord.{ type (obj ).__name__ } .get_or_fetch()"
739
+ )
740
+
741
+ result = getter (obj , object_id )
742
+ if result is not None :
743
+ return result
744
+
745
+ try :
746
+ return await fetcher (obj , object_id )
747
+ except (HTTPException , ValueError ):
748
+ if default is not MISSING :
749
+ return default
750
+ raise
751
+
752
+
753
+ @functools .lru_cache (maxsize = 1 )
754
+ def _get_string_to_type_map () -> dict [str , type ]:
755
+ """Return a cached map of lowercase strings -> discord types."""
756
+ from discord import Guild , Member , Role , User , abc , emoji
757
+
758
+ return {
759
+ "channel" : abc .GuildChannel ,
760
+ "member" : Member ,
761
+ "user" : User ,
762
+ "guild" : Guild ,
763
+ "emoji" : emoji ._EmojiTag ,
764
+ "appemoji" : AppEmoji ,
765
+ "role" : Role ,
766
+ }
767
+
768
+
769
+ @functools .lru_cache (maxsize = 1 )
770
+ def _get_getter_fetcher_map () -> dict [type , tuple [_Getter , _Fetcher ]]:
771
+ """Return a cached map of type names -> (getter, fetcher) functions."""
772
+ from discord import Guild , Member , Role , User , abc , emoji
773
+
774
+ base_map : dict [type , tuple [_Getter , _Fetcher ]] = {
742
775
Member : (
743
776
lambda obj , oid : obj .get_member (oid ),
744
777
lambda obj , oid : obj .fetch_member (oid ),
@@ -764,26 +797,23 @@ async def get_or_fetch(
764
797
lambda obj , oid : obj .fetch_channel (oid ),
765
798
),
766
799
}
767
- try :
768
- base_type = next (
769
- base for base in getter_fetcher_map if issubclass (object_type , base )
770
- )
771
- getter , fetcher = getter_fetcher_map [base_type ]
772
- except KeyError :
773
- raise InvalidArgument (
774
- f"Class { object_type .__name__ } cannot be used with discord.{ type (obj ).__name__ } .get_or_fetch()"
775
- )
776
800
777
- result = getter (obj , object_id )
778
- if result is not None :
779
- return result
801
+ expanded : dict [type , tuple [_Getter , _Fetcher ]] = {}
802
+ for base , funcs in base_map .items ():
803
+ expanded [base ] = funcs
804
+ for subclass in _all_subclasses (base ):
805
+ if subclass not in expanded :
806
+ expanded [subclass ] = funcs
780
807
781
- try :
782
- return await fetcher (obj , object_id )
783
- except (HTTPException , ValueError ):
784
- if default is not MISSING :
785
- return default
786
- raise
808
+ return expanded
809
+
810
+
811
+ def _all_subclasses (cls : type ) -> set [type ]:
812
+ """Recursively collect all subclasses of a class."""
813
+ subs = set (cls .__subclasses__ ())
814
+ for sub in cls .__subclasses__ ():
815
+ subs |= _all_subclasses (sub )
816
+ return subs
787
817
788
818
789
819
def _unique (iterable : Iterable [T ]) -> list [T ]:
0 commit comments