Skip to content

Commit 4574eb6

Browse files
committed
Add the hook api
1 parent 6290072 commit 4574eb6

File tree

7 files changed

+174
-1
lines changed

7 files changed

+174
-1
lines changed

resources/plugins/__init__.py

Whitespace-only changes.

resources/plugins/demo.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from hooks_api import post_status, pre_status
2+
3+
4+
@pre_status
5+
def print_something_wonderful():
6+
print("This has been a very pleasant day.")
7+
8+
9+
@post_status
10+
def print_something_afterwards():
11+
print("Status has run!")

src/warnet/constants.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,12 @@
3737
NAMESPACES_FILE = "namespaces.yaml"
3838
DEFAULTS_NAMESPACE_FILE = "namespace-defaults.yaml"
3939

40+
# Plugin architecture
41+
PLUGINS_DIR = RESOURCES_DIR.joinpath("plugins")
42+
HOOK_NAME_KEY = "hook_name" # this lives as a key in object.__annotations__
43+
HOOKS_API_STEM = "hooks_api"
44+
HOOKS_API_FILE = HOOKS_API_STEM + ".py"
45+
4046
# Helm charts
4147
BITCOIN_CHART_LOCATION = str(CHARTS_DIR.joinpath("bitcoincore"))
4248
FORK_OBSERVER_CHART = str(CHARTS_DIR.joinpath("fork-observer"))

src/warnet/hooks.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
import importlib.util
2+
import inspect
3+
import os
4+
import sys
5+
from importlib.metadata import PackageNotFoundError, version
6+
from pathlib import Path
7+
from typing import Any, Callable
8+
9+
from warnet.constants import HOOK_NAME_KEY, HOOKS_API_FILE, HOOKS_API_STEM
10+
11+
hook_registry: set[str] = set()
12+
imported_modules = {}
13+
14+
15+
def api(func: Callable[..., Any]) -> Callable[..., Any]:
16+
"""
17+
Functions with this decoration will have corresponding 'pre' and 'post' functions made
18+
available to the user via the 'plugins' directory.
19+
20+
Please ensure that @api is the innermost decorator:
21+
22+
```python
23+
@click.command() # outermost
24+
@api # innermost
25+
def my_function():
26+
pass
27+
```
28+
"""
29+
if func.__name__ in hook_registry:
30+
print(
31+
f"Cannot re-use function names in the Warnet plugin API -- "
32+
f"'{func.__name__}' has already been taken."
33+
)
34+
sys.exit(1)
35+
hook_registry.add(func.__name__)
36+
37+
if not imported_modules:
38+
load_user_modules()
39+
40+
pre_hooks, post_hooks = [], []
41+
for module_name in imported_modules:
42+
pre, post = find_hooks(module_name, func.__name__)
43+
pre_hooks.extend(pre)
44+
post_hooks.extend(post)
45+
46+
def wrapped(*args, **kwargs):
47+
for hook in pre_hooks:
48+
hook()
49+
result = func(*args, **kwargs)
50+
for hook in post_hooks:
51+
hook()
52+
return result
53+
54+
# Mimic the base function; helps make `click` happy
55+
wrapped.__name__ = func.__name__
56+
wrapped.__doc__ = func.__doc__
57+
58+
return wrapped
59+
60+
61+
def create_hooks(directory: Path):
62+
# Prepare directory and file
63+
os.makedirs(directory, exist_ok=True)
64+
init_file_path = os.path.join(directory, HOOKS_API_FILE)
65+
66+
with open(init_file_path, "w") as file:
67+
file.write(f"# API Version: {get_version('warnet')}")
68+
# For each enum variant, create a corresponding decorator function
69+
for hook in hook_registry:
70+
file.write(decorator_code.format(hook=hook, HOOK_NAME_KEY=HOOK_NAME_KEY))
71+
72+
73+
decorator_code = """
74+
75+
76+
def pre_{hook}(func):
77+
\"\"\"Functions with this decoration run before `{hook}`.\"\"\"
78+
func.__annotations__['{HOOK_NAME_KEY}'] = 'pre_{hook}'
79+
return func
80+
81+
82+
def post_{hook}(func):
83+
\"\"\"Functions with this decoration run after `{hook}`.\"\"\"
84+
func.__annotations__['{HOOK_NAME_KEY}'] = 'post_{hook}'
85+
return func
86+
"""
87+
88+
89+
def load_user_modules() -> bool:
90+
was_successful_load = False
91+
user_module_path = Path.cwd() / "plugins"
92+
93+
if not user_module_path.is_dir():
94+
print("No plugins folder found in the current directory")
95+
return was_successful_load
96+
97+
# Temporarily add the current directory to sys.path for imports
98+
sys.path.insert(0, str(Path.cwd()))
99+
100+
hooks_path = user_module_path / HOOKS_API_FILE
101+
if hooks_path.is_file():
102+
hooks_spec = importlib.util.spec_from_file_location(HOOKS_API_STEM, hooks_path)
103+
hooks_module = importlib.util.module_from_spec(hooks_spec)
104+
imported_modules[HOOKS_API_STEM] = hooks_module
105+
sys.modules[HOOKS_API_STEM] = hooks_module
106+
hooks_spec.loader.exec_module(hooks_module)
107+
108+
for file in user_module_path.glob("*.py"):
109+
if file.stem not in ("__init__", HOOKS_API_STEM):
110+
module_name = f"plugins.{file.stem}"
111+
spec = importlib.util.spec_from_file_location(module_name, file)
112+
module = importlib.util.module_from_spec(spec)
113+
imported_modules[module_name] = module
114+
sys.modules[module_name] = module
115+
spec.loader.exec_module(module)
116+
was_successful_load = True
117+
118+
# Remove the added path from sys.path
119+
sys.path.pop(0)
120+
return was_successful_load
121+
122+
123+
def find_hooks(module_name: str, func_name: str):
124+
module = imported_modules.get(module_name)
125+
pre_hooks = []
126+
post_hooks = []
127+
for _, func in inspect.getmembers(module, inspect.isfunction):
128+
if func.__annotations__.get(HOOK_NAME_KEY) == f"pre_{func_name}":
129+
pre_hooks.append(func)
130+
elif func.__annotations__.get(HOOK_NAME_KEY) == f"post_{func_name}":
131+
post_hooks.append(func)
132+
return pre_hooks, post_hooks
133+
134+
135+
def get_version(package_name: str) -> str:
136+
try:
137+
return version(package_name)
138+
except PackageNotFoundError:
139+
print(f"Package not found: {package_name}")
140+
sys.exit(1)

src/warnet/network.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77
from .bitcoin import _rpc
88
from .constants import (
99
NETWORK_DIR,
10+
PLUGINS_DIR,
1011
SCENARIOS_DIR,
1112
)
13+
from .hooks import create_hooks
1214
from .k8s import get_mission
1315

1416

@@ -48,6 +50,17 @@ def copy_scenario_defaults(directory: Path):
4850
)
4951

5052

53+
def copy_plugins_defaults(directory: Path):
54+
"""Create the project structure for a warnet project's scenarios"""
55+
copy_defaults(
56+
directory,
57+
PLUGINS_DIR.name,
58+
PLUGINS_DIR,
59+
["__pycache__", "__init__"],
60+
)
61+
create_hooks(directory / PLUGINS_DIR.name)
62+
63+
5164
def is_connection_manual(peer):
5265
# newer nodes specify a "connection_type"
5366
return bool(peer.get("connection_type") == "manual" or peer.get("addnode") is True)

src/warnet/project.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
KUBECTL_DOWNLOAD_URL_STUB,
2727
)
2828
from .graph import inquirer_create_network
29-
from .network import copy_network_defaults, copy_scenario_defaults
29+
from .network import copy_network_defaults, copy_plugins_defaults, copy_scenario_defaults
3030

3131

3232
@click.command()
@@ -387,6 +387,7 @@ def create_warnet_project(directory: Path, check_empty: bool = False):
387387
try:
388388
copy_network_defaults(directory)
389389
copy_scenario_defaults(directory)
390+
copy_plugins_defaults(directory)
390391
click.echo(f"Copied network example files to {directory}/networks")
391392
click.echo(f"Created warnet project structure in {directory}")
392393
except Exception as e:

src/warnet/status.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@
88
from rich.text import Text
99
from urllib3.exceptions import MaxRetryError
1010

11+
from .hooks import api
1112
from .k8s import get_mission
1213
from .network import _connected
1314

1415

1516
@click.command()
17+
@api
1618
def status():
1719
"""Display the unified status of the Warnet network and active scenarios"""
1820
console = Console()

0 commit comments

Comments
 (0)