@@ -600,6 +600,8 @@ def get(iterable: Iterable[T], **attrs: Any) -> T | None:
600600 bound = "VoiceChannel | TextChannel | ForumChannel | StageChannel | CategoryChannel | Thread | Member | User | Guild | Role | GuildEmoji | AppEmoji" ,
601601)
602602_D = TypeVar ("_D" )
603+ _Getter = Callable [[Any , int ], Any ]
604+ _Fetcher = Callable [[Any , int ], Awaitable [Any ]]
603605
604606
605607# TODO: In version 3.0, remove the 'attr' and 'id' arguments.
@@ -681,20 +683,11 @@ async def get_or_fetch(
681683 :exc:`Forbidden`
682684 You do not have permission to fetch the object.
683685 """
684- from discord import AppEmoji , Client , Guild , Member , Role , User , abc , emoji
686+ from discord import Client , Guild , Member , Role , User
685687
686688 if object_id is None :
687689 return default if default is not MISSING else None
688690
689- string_to_type = {
690- "channel" : abc .GuildChannel ,
691- "member" : Member ,
692- "user" : User ,
693- "guild" : Guild ,
694- "emoji" : emoji ._EmojiTag ,
695- "appemoji" : AppEmoji ,
696- "role" : Role ,
697- }
698691 # Temporary backward compatibility for 'attr' and 'id'.
699692 # This entire if block should be removed in version 3.0.
700693 if attr is not MISSING or id is not MISSING or isinstance (object_type , str ):
@@ -709,7 +702,7 @@ async def get_or_fetch(
709702 deprecated_id = id if id is not MISSING else object_id
710703
711704 if isinstance (deprecated_attr , str ):
712- mapped_type = string_to_type .get (deprecated_attr .lower ())
705+ mapped_type = _get_string_to_type_map () .get (deprecated_attr .lower ())
713706 if mapped_type is None :
714707 raise InvalidArgument (
715708 f"Unknown type string '{ deprecated_attr } ' used. Please use a valid class like `discord.Member` instead."
@@ -738,7 +731,47 @@ async def get_or_fetch(
738731 elif isinstance (obj , Guild ) and object_type is Guild :
739732 raise InvalidArgument ("Guild cannot get_or_fetch Guild. Use Client instead." )
740733
741- getter_fetcher_map = {
734+ try :
735+ getter , fetcher = _get_getter_fetcher_map ()[object_type ]
736+ except KeyError :
737+ raise InvalidArgument (
738+ f"Class { object_type .__name__ } cannot be used with discord.{ type (obj ).__name__ } .get_or_fetch()"
739+ )
740+
741+ result = getter (obj , object_id )
742+ if result is not None :
743+ return result
744+
745+ try :
746+ return await fetcher (obj , object_id )
747+ except (HTTPException , ValueError ):
748+ if default is not MISSING :
749+ return default
750+ raise
751+
752+
753+ @functools .lru_cache (maxsize = 1 )
754+ def _get_string_to_type_map () -> dict [str , type ]:
755+ """Return a cached map of lowercase strings -> discord types."""
756+ from discord import Guild , Member , Role , User , abc , emoji
757+
758+ return {
759+ "channel" : abc .GuildChannel ,
760+ "member" : Member ,
761+ "user" : User ,
762+ "guild" : Guild ,
763+ "emoji" : emoji ._EmojiTag ,
764+ "appemoji" : AppEmoji ,
765+ "role" : Role ,
766+ }
767+
768+
769+ @functools .lru_cache (maxsize = 1 )
770+ def _get_getter_fetcher_map () -> dict [type , tuple [_Getter , _Fetcher ]]:
771+ """Return a cached map of type names -> (getter, fetcher) functions."""
772+ from discord import Guild , Member , Role , User , abc , emoji
773+
774+ base_map : dict [type , tuple [_Getter , _Fetcher ]] = {
742775 Member : (
743776 lambda obj , oid : obj .get_member (oid ),
744777 lambda obj , oid : obj .fetch_member (oid ),
@@ -764,26 +797,23 @@ async def get_or_fetch(
764797 lambda obj , oid : obj .fetch_channel (oid ),
765798 ),
766799 }
767- try :
768- base_type = next (
769- base for base in getter_fetcher_map if issubclass (object_type , base )
770- )
771- getter , fetcher = getter_fetcher_map [base_type ]
772- except KeyError :
773- raise InvalidArgument (
774- f"Class { object_type .__name__ } cannot be used with discord.{ type (obj ).__name__ } .get_or_fetch()"
775- )
776800
777- result = getter (obj , object_id )
778- if result is not None :
779- return result
801+ expanded : dict [type , tuple [_Getter , _Fetcher ]] = {}
802+ for base , funcs in base_map .items ():
803+ expanded [base ] = funcs
804+ for subclass in _all_subclasses (base ):
805+ if subclass not in expanded :
806+ expanded [subclass ] = funcs
780807
781- try :
782- return await fetcher (obj , object_id )
783- except (HTTPException , ValueError ):
784- if default is not MISSING :
785- return default
786- raise
808+ return expanded
809+
810+
811+ def _all_subclasses (cls : type ) -> set [type ]:
812+ """Recursively collect all subclasses of a class."""
813+ subs = set (cls .__subclasses__ ())
814+ for sub in cls .__subclasses__ ():
815+ subs |= _all_subclasses (sub )
816+ return subs
787817
788818
789819def _unique (iterable : Iterable [T ]) -> list [T ]:
0 commit comments