Skip to content

Commit 23823a6

Browse files
committed
update hooks commands
1 parent 0f408ec commit 23823a6

File tree

3 files changed

+54
-10
lines changed

3 files changed

+54
-10
lines changed

src/warnet/constants.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,12 @@
3838
DEFAULTS_NAMESPACE_FILE = "namespace-defaults.yaml"
3939

4040
# Plugin architecture
41-
PLUGINS_DIR = RESOURCES_DIR.joinpath("plugins")
41+
PLUGINS_LABEL = "plugins"
42+
PLUGINS_DIR = RESOURCES_DIR.joinpath(PLUGINS_LABEL)
4243
HOOK_NAME_KEY = "hook_name" # this lives as a key in object.__annotations__
4344
HOOKS_API_STEM = "hooks_api"
4445
HOOKS_API_FILE = HOOKS_API_STEM + ".py"
46+
WARNET_USER_DIR_ENV_VAR = "WARNET_USER_DIR"
4547

4648
# Helm charts
4749
BITCOIN_CHART_LOCATION = str(CHARTS_DIR.joinpath("bitcoincore"))

src/warnet/hooks.py

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,37 @@
44
import sys
55
from importlib.metadata import PackageNotFoundError, version
66
from pathlib import Path
7-
from typing import Any, Callable
7+
from typing import Any, Callable, Optional
88

9-
from warnet.constants import HOOK_NAME_KEY, HOOKS_API_FILE, HOOKS_API_STEM
9+
import click
10+
11+
from warnet.constants import (
12+
HOOK_NAME_KEY,
13+
HOOKS_API_FILE,
14+
HOOKS_API_STEM,
15+
PLUGINS_LABEL,
16+
WARNET_USER_DIR_ENV_VAR,
17+
)
1018

1119
hook_registry: set[Callable[..., Any]] = set()
1220
imported_modules = {}
1321

1422

23+
@click.group(name="plugin")
24+
def plugin():
25+
pass
26+
27+
28+
@plugin.command()
29+
def ls():
30+
plugin_dir = get_plugin_directory()
31+
32+
if not plugin_dir:
33+
click.secho("Could not determine the plugin directory location.")
34+
click.secho("Consider setting environment variable containing your project directory:")
35+
click.secho(f"export {WARNET_USER_DIR_ENV_VAR}=/home/user/path/to/project/", fg="yellow")
36+
37+
1538
def api(func: Callable[..., Any]) -> Callable[..., Any]:
1639
"""
1740
Functions with this decoration will have corresponding 'pre' and 'post' functions made
@@ -73,6 +96,9 @@ def create_hooks(directory: Path):
7396
)
7497
)
7598

99+
click.secho("\nConsider setting environment variable containing your project directory:")
100+
click.secho(f"export {WARNET_USER_DIR_ENV_VAR}={directory.parent}\n", fg="yellow")
101+
76102

77103
decorator_code = """
78104
@@ -102,25 +128,27 @@ def post_{hook}(func):
102128

103129
def load_user_modules() -> bool:
104130
was_successful_load = False
105-
user_module_path = Path.cwd() / "plugins"
106131

107-
if not user_module_path.is_dir():
132+
plugin_dir = get_plugin_directory()
133+
134+
if not plugin_dir or not plugin_dir.is_dir():
108135
return was_successful_load
109136

110-
# Temporarily add the current directory to sys.path for imports
111-
sys.path.insert(0, str(Path.cwd()))
137+
# Temporarily add the directory to sys.path for imports
138+
sys.path.insert(0, str(plugin_dir))
139+
140+
hooks_path = plugin_dir / HOOKS_API_FILE
112141

113-
hooks_path = user_module_path / HOOKS_API_FILE
114142
if hooks_path.is_file():
115143
hooks_spec = importlib.util.spec_from_file_location(HOOKS_API_STEM, hooks_path)
116144
hooks_module = importlib.util.module_from_spec(hooks_spec)
117145
imported_modules[HOOKS_API_STEM] = hooks_module
118146
sys.modules[HOOKS_API_STEM] = hooks_module
119147
hooks_spec.loader.exec_module(hooks_module)
120148

121-
for file in user_module_path.glob("*.py"):
149+
for file in plugin_dir.glob("*.py"):
122150
if file.stem not in ("__init__", HOOKS_API_STEM):
123-
module_name = f"plugins.{file.stem}"
151+
module_name = f"{PLUGINS_LABEL}.{file.stem}"
124152
spec = importlib.util.spec_from_file_location(module_name, file)
125153
module = importlib.util.module_from_spec(spec)
126154
imported_modules[module_name] = module
@@ -145,6 +173,17 @@ def find_hooks(module_name: str, func_name: str):
145173
return pre_hooks, post_hooks
146174

147175

176+
def get_plugin_directory() -> Optional[Path]:
177+
user_dir = os.getenv(WARNET_USER_DIR_ENV_VAR)
178+
179+
plugin_dir = Path(user_dir) / PLUGINS_LABEL if user_dir else Path.cwd() / PLUGINS_LABEL
180+
181+
if plugin_dir and plugin_dir.is_dir():
182+
return plugin_dir
183+
else:
184+
return None
185+
186+
148187
def get_version(package_name: str) -> str:
149188
try:
150189
return version(package_name)

src/warnet/main.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from .dashboard import dashboard
77
from .deploy import deploy
88
from .graph import create, graph, import_network
9+
from .graph import create, graph
10+
from .hooks import plugin
911
from .image import image
1012
from .ln import ln
1113
from .project import init, new, setup
@@ -37,6 +39,7 @@ def cli():
3739
cli.add_command(status)
3840
cli.add_command(stop)
3941
cli.add_command(create)
42+
cli.add_command(plugin)
4043

4144

4245
if __name__ == "__main__":

0 commit comments

Comments
 (0)