Skip to content

Commit 5a42fff

Browse files
Rework extension loading (#1423)
Co-authored-by: Lala Sabathil <[email protected]>
1 parent b3137a7 commit 5a42fff

File tree

1 file changed

+157
-11
lines changed

1 file changed

+157
-11
lines changed

discord/cog.py

Lines changed: 157 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626

2727
import importlib
2828
import inspect
29+
import os
30+
import pathlib
2931
import sys
3032
import types
3133
from typing import (
@@ -40,6 +42,7 @@
4042
Tuple,
4143
Type,
4244
TypeVar,
45+
Union,
4346
)
4447

4548
import discord.utils
@@ -739,7 +742,14 @@ def _resolve_name(self, name: str, package: Optional[str]) -> str:
739742
except ImportError:
740743
raise errors.ExtensionNotFound(name)
741744

742-
def load_extension(self, name: str, *, package: Optional[str] = None) -> None:
745+
def load_extension(
746+
self,
747+
name: str,
748+
*,
749+
package: Optional[str] = None,
750+
recursive: bool = False,
751+
store: bool = True,
752+
) -> Optional[Union[Dict[str, Union[Exception, bool]], List[str]]]:
743753
"""Loads an extension.
744754
745755
An extension is a python module that contains commands, cogs, or
@@ -749,21 +759,41 @@ def load_extension(self, name: str, *, package: Optional[str] = None) -> None:
749759
the entry point on what to do when the extension is loaded. This entry
750760
point must have a single argument, the ``bot``.
751761
762+
The extension passed can either be the direct name of a file within
763+
the current working directory or a folder that contains multiple extensions.
764+
752765
Parameters
753-
------------
766+
-----------
754767
name: :class:`str`
755-
The extension name to load. It must be dot separated like
756-
regular Python imports if accessing a sub-module. e.g.
768+
The extension or folder name to load. It must be dot separated
769+
like regular Python imports if accessing a sub-module. e.g.
757770
``foo.test`` if you want to import ``foo/test.py``.
758771
package: Optional[:class:`str`]
759772
The package name to resolve relative imports with.
760-
This is required when loading an extension using a relative path, e.g ``.foo.test``.
773+
This is required when loading an extension using a relative
774+
path, e.g ``.foo.test``.
761775
Defaults to ``None``.
762776
763777
.. versionadded:: 1.7
778+
recursive: Optional[:class:`bool`]
779+
If subdirectories under the given head directory should be
780+
recursively loaded.
781+
Defaults to ``False``.
782+
783+
.. versionadded:: 2.0
784+
store: Optional[:class:`bool`]
785+
If exceptions should be stored or raised. If set to ``True``,
786+
all exceptions encountered will be stored in a returned dictionary
787+
as a load status. If set to ``False``, if any exceptions are
788+
encountered they will be raised and the bot will be closed.
789+
If no exceptions are encountered, a list of loaded
790+
extension names will be returned.
791+
Defaults to ``True``.
792+
793+
.. versionadded:: 2.0
764794
765795
Raises
766-
--------
796+
-------
767797
ExtensionNotFound
768798
The extension could not be imported.
769799
This is also raised if the name of the extension could not
@@ -774,17 +804,133 @@ def load_extension(self, name: str, *, package: Optional[str] = None) -> None:
774804
The extension does not have a setup function.
775805
ExtensionFailed
776806
The extension or its setup function had an execution error.
807+
808+
Returns
809+
--------
810+
Optional[Union[Dict[:class:`str`, Union[:exc:`errors.ExtensionError`, :class:`bool`]], List[:class:`str`]]]
811+
If the store parameter is set to ``True``, a dictionary will be returned that
812+
contains keys to represent the loaded extension names. The values bound to
813+
each key can either be an exception that occurred when loading that extension
814+
or a ``True`` boolean representing a successful load. If the store parameter
815+
is set to ``False``, either a list containing a list of loaded extensions or
816+
nothing due to an encountered exception.
777817
"""
778818

779819
name = self._resolve_name(name, package)
820+
780821
if name in self.__extensions:
781-
raise errors.ExtensionAlreadyLoaded(name)
822+
exc = errors.ExtensionAlreadyLoaded(name)
823+
final_out = {name: exc} if store else exc
824+
# This indicates that there is neither an extension nor folder here
825+
elif (spec := importlib.util.find_spec(name)) is None:
826+
exc = errors.ExtensionNotFound(name)
827+
final_out = {name: exc} if store else exc
828+
# This indicates we've found an extension file to load, and we need to store any exceptions
829+
elif spec.has_location and store:
830+
try:
831+
self._load_from_module_spec(spec, name)
832+
except Exception as exc:
833+
final_out = {name: exc}
834+
else:
835+
final_out = {name: True}
836+
# This indicates we've found an extension file to load, and any encountered exceptions can be raised
837+
elif spec.has_location:
838+
self._load_from_module_spec(spec, name)
839+
final_out = [name]
840+
# This indicates we've been given a folder because the ModuleSpec exists but is not a file
841+
else:
842+
# Split the directory path and join it to get an os-native Path object
843+
path = pathlib.Path(os.path.join(*name.split(".")))
844+
glob = path.rglob if recursive else path.glob
845+
final_out = {} if store else []
846+
847+
# Glob all files with a pattern to gather all .py files that don't start with _
848+
for ext_file in glob("[!_]*.py"):
849+
# Gets all parts leading to the directory minus the file name
850+
parts = list(ext_file.parts[:-1])
851+
# Gets the file name without the extension
852+
parts.append(ext_file.stem)
853+
loaded = self.load_extension(".".join(parts))
854+
final_out.update(loaded) if store else final_out.extend(loaded)
855+
856+
if isinstance(final_out, Exception):
857+
raise final_out
858+
else:
859+
return final_out
782860

783-
spec = importlib.util.find_spec(name)
784-
if spec is None:
785-
raise errors.ExtensionNotFound(name)
861+
def load_extensions(
862+
self,
863+
*names: str,
864+
package: Optional[str] = None,
865+
recursive: bool = False,
866+
store: bool = True,
867+
) -> Optional[Union[Dict[str, Union[Exception, bool]], List[str]]]:
868+
"""Loads multiple extensions at once.
869+
870+
This method simplifies the process of loading multiple
871+
extensions by handling the looping of ``load_extension``.
872+
873+
Parameters
874+
-----------
875+
names: :class:`str`
876+
The extension or folder names to load. It must be dot separated
877+
like regular Python imports if accessing a sub-module. e.g.
878+
``foo.test`` if you want to import ``foo/test.py``.
879+
package: Optional[:class:`str`]
880+
The package name to resolve relative imports with.
881+
This is required when loading an extension using a relative
882+
path, e.g ``.foo.test``.
883+
Defaults to ``None``.
884+
885+
.. versionadded:: 1.7
886+
recursive: Optional[:class:`bool`]
887+
If subdirectories under the given head directory should be
888+
recursively loaded.
889+
Defaults to ``False``.
890+
891+
.. versionadded:: 2.0
892+
store: Optional[:class:`bool`]
893+
If exceptions should be stored or raised. If set to ``True``,
894+
all exceptions encountered will be stored in a returned dictionary
895+
as a load status. If set to ``False``, if any exceptions are
896+
encountered they will be raised and the bot will be closed.
897+
If no exceptions are encountered, a list of loaded
898+
extension names will be returned.
899+
Defaults to ``True``.
900+
901+
.. versionadded:: 2.0
902+
903+
Raises
904+
--------
905+
ExtensionNotFound
906+
A given extension could not be imported.
907+
This is also raised if the name of the extension could not
908+
be resolved using the provided ``package`` parameter.
909+
ExtensionAlreadyLoaded
910+
A given extension is already loaded.
911+
NoEntryPointError
912+
A given extension does not have a setup function.
913+
ExtensionFailed
914+
A given extension or its setup function had an execution error.
915+
916+
Returns
917+
--------
918+
Optional[Union[Dict[:class:`str`, Union[:exc:`errors.ExtensionError`, :class:`bool`]], List[:class:`str`]]]
919+
If the store parameter is set to ``True``, a dictionary will be returned that
920+
contains keys to represent the loaded extension names. The values bound to
921+
each key can either be an exception that occurred when loading that extension
922+
or a ``True`` boolean representing a successful load. If the store parameter
923+
is set to ``False``, either a list containing names of loaded extensions or
924+
nothing due to an encountered exception.
925+
"""
926+
927+
loaded_extensions = {} if store else []
928+
929+
for ext_path in names:
930+
loaded = self.load_extension(ext_path, package=package, recursive=recursive, store=store)
931+
loaded_extensions.update(loaded) if store else loaded_extensions.extend(loaded)
786932

787-
self._load_from_module_spec(spec, name)
933+
return loaded_extensions
788934

789935
def unload_extension(self, name: str, *, package: Optional[str] = None) -> None:
790936
"""Unloads an extension.

0 commit comments

Comments
 (0)