77from typing import Any , Callable , Optional
88
99import click
10+ import yaml
1011
1112from warnet .constants import (
1213 HOOK_NAME_KEY ,
1617 WARNET_USER_DIR_ENV_VAR ,
1718)
1819
20+
21+ class PluginError (Exception ):
22+ pass
23+
24+
1925hook_registry : set [Callable [..., Any ]] = set ()
2026imported_modules = {}
2127
2228
2329@click .group (name = "plugin" )
2430def plugin ():
31+ """Control plugins"""
2532 pass
2633
2734
@@ -33,6 +40,29 @@ def ls():
3340 click .secho ("Could not determine the plugin directory location." )
3441 click .secho ("Consider setting environment variable containing your project directory:" )
3542 click .secho (f"export { WARNET_USER_DIR_ENV_VAR } =/home/user/path/to/project/" , fg = "yellow" )
43+ sys .exit (1 )
44+
45+ for plugin , status in get_plugins_with_status (plugin_dir ):
46+ if status :
47+ click .secho (f"{ plugin .stem :<20} enabled" , fg = "green" )
48+ else :
49+ click .secho (f"{ plugin .stem :<20} disabled" , fg = "yellow" )
50+
51+
52+ @plugin .command ()
53+ @click .argument ("plugin" , type = str )
54+ @click .argument ("function" , type = str )
55+ def run (plugin : str , function : str ):
56+ module = imported_modules .get (f"plugins.{ plugin } " )
57+ if hasattr (module , function ):
58+ func = getattr (module , function )
59+ if callable (func ):
60+ result = func ()
61+ print (result )
62+ else :
63+ click .secho (f"{ function } in { module } is not callable." )
64+ else :
65+ click .secho (f"Could not find { function } in { module } " )
3666
3767
3868def api (func : Callable [..., Any ]) -> Callable [..., Any ]:
@@ -134,6 +164,11 @@ def load_user_modules() -> bool:
134164 if not plugin_dir or not plugin_dir .is_dir ():
135165 return was_successful_load
136166
167+ enabled_plugins = [plugin for plugin , enabled in get_plugins_with_status (plugin_dir ) if enabled ]
168+
169+ if not enabled_plugins :
170+ return was_successful_load
171+
137172 # Temporarily add the directory to sys.path for imports
138173 sys .path .insert (0 , str (plugin_dir ))
139174
@@ -146,15 +181,16 @@ def load_user_modules() -> bool:
146181 sys .modules [HOOKS_API_STEM ] = hooks_module
147182 hooks_spec .loader .exec_module (hooks_module )
148183
149- for file in plugin_dir .glob ("*.py" ):
150- if file .stem not in ("__init__" , HOOKS_API_STEM ):
151- module_name = f"{ PLUGINS_LABEL } .{ file .stem } "
152- spec = importlib .util .spec_from_file_location (module_name , file )
153- module = importlib .util .module_from_spec (spec )
154- imported_modules [module_name ] = module
155- sys .modules [module_name ] = module
156- spec .loader .exec_module (module )
157- was_successful_load = True
184+ for plugin_path in enabled_plugins :
185+ for file in plugin_path .glob ("*.py" ):
186+ if file .stem not in ("__init__" , HOOKS_API_STEM ):
187+ module_name = f"{ PLUGINS_LABEL } .{ file .stem } "
188+ spec = importlib .util .spec_from_file_location (module_name , file )
189+ module = importlib .util .module_from_spec (spec )
190+ imported_modules [module_name ] = module
191+ sys .modules [module_name ] = module
192+ spec .loader .exec_module (module )
193+ was_successful_load = True
158194
159195 # Remove the added path from sys.path
160196 sys .path .pop (0 )
@@ -190,3 +226,34 @@ def get_version(package_name: str) -> str:
190226 except PackageNotFoundError :
191227 print (f"Package not found: { package_name } " )
192228 sys .exit (1 )
229+
230+
231+ def open_yaml (path : Path ) -> dict :
232+ try :
233+ with open (path ) as file :
234+ return yaml .safe_load (file )
235+ except FileNotFoundError as e :
236+ raise PluginError (f"YAML file { path } not found." ) from e
237+ except yaml .YAMLError as e :
238+ raise PluginError (f"Error parsing yaml: { e } " ) from e
239+
240+
241+ def check_if_plugin_enabled (path : Path ) -> bool :
242+ enabled = None
243+ try :
244+ plugin_dict = open_yaml (path / Path ("plugin.yaml" ))
245+ enabled = plugin_dict .get ("enabled" )
246+ except PluginError as e :
247+ click .secho (e )
248+
249+ return bool (enabled )
250+
251+
252+ def get_plugins_with_status (plugin_dir : Path ) -> list [tuple [Path , bool ]]:
253+ candidates = [
254+ Path (os .path .join (plugin_dir , name ))
255+ for name in os .listdir (plugin_dir )
256+ if os .path .isdir (os .path .join (plugin_dir , name ))
257+ ]
258+ plugins = [plugin_dir for plugin_dir in candidates if any (plugin_dir .glob ("plugin.yaml" ))]
259+ return [(plugin , check_if_plugin_enabled (plugin )) for plugin in plugins ]
0 commit comments