Skip to content

Commit 4fb893c

Browse files
authored
Singledispatch entrypoints (#38)
I was playing with how we could set up functions in a modular fashion without having to add a function in the corresponding class. cattrs mentions the use of single dispatch functions for their structure and unstructure functions. We can do the same for functions like plot(), where the plot function should be seen as a module, and not as part of the class itself. I have combined this functionality with the entry_points function to show that a user or plugin package could in theory also add functionality. For example, a user comes with a specific entry point for plotting a specific package, other than the default way that we normally do.
1 parent 40cb53f commit 4fb893c

File tree

6 files changed

+124
-40
lines changed

6 files changed

+124
-40
lines changed

flopy4/singledispatch/__init__.py

Whitespace-only changes.

flopy4/singledispatch/plot.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from functools import singledispatch
2+
from typing import Any
3+
4+
5+
@singledispatch
6+
def plot(obj, **kwargs) -> Any:
7+
raise NotImplementedError(
8+
"plot method not implemented for type {}".format(type(obj))
9+
)

flopy4/singledispatch/plot_int.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from typing import Any
2+
3+
from flopy4.singledispatch.plot import plot
4+
5+
6+
@plot.register
7+
def _(v: int, **kwargs) -> Any:
8+
print(f"Plotting a model with kwargs: {kwargs}")
9+
return v

pixi.lock

Lines changed: 40 additions & 40 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,9 @@ ignore = [
112112
"E741", # ambiguous variable name
113113
]
114114

115+
[project.entry-points.flopy4]
116+
plot = "flopy4.singledispatch.plot_int"
117+
115118
[tool.pixi.project]
116119
channels = ["conda-forge"]
117120
platforms = ["win-64", "linux-64", "osx-64"]
@@ -145,3 +148,4 @@ test = { cmd = "pytest -v -n auto" }
145148

146149
[tool.pixi.feature.lint.tasks]
147150
lint = { cmd = "ruff check ." }
151+

test/test_singledispatch.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import ast
2+
import inspect
3+
import subprocess
4+
import sys
5+
from importlib.metadata import entry_points
6+
7+
import pytest
8+
9+
from flopy4.singledispatch.plot import plot
10+
11+
12+
def get_function_body(func):
13+
source = inspect.getsource(func)
14+
parsed = ast.parse(source)
15+
for node in ast.walk(parsed):
16+
if isinstance(node, ast.FunctionDef):
17+
return ast.get_source_segment(source, node.body[0])
18+
raise ValueError("Function body not found")
19+
20+
21+
def run_test_in_subprocess(test_func):
22+
def wrapper():
23+
test_func_source = get_function_body(test_func)
24+
test_code = f"""
25+
import pytest
26+
from importlib.metadata import entry_points
27+
from flopy4.singledispatch.plot import plot
28+
29+
{test_func_source}
30+
31+
"""
32+
result = subprocess.run(
33+
[sys.executable, "-c", test_code], capture_output=True, text=True
34+
)
35+
if result.returncode != 0:
36+
print(result.stdout)
37+
print(result.stderr)
38+
assert result.returncode == 0, f"Test failed: {test_func.__name__}"
39+
40+
return wrapper
41+
42+
43+
@run_test_in_subprocess
44+
def test_register_singledispatch_with_entrypoints():
45+
eps = entry_points(group="flopy4", name="plot")
46+
for ep in eps:
47+
ep.load()
48+
49+
# should not throw an error, because plot_int was loaded via entry points
50+
return_val = plot(5)
51+
assert return_val == 5
52+
with pytest.raises(NotImplementedError):
53+
plot("five")
54+
55+
56+
@run_test_in_subprocess
57+
def test_register_singledispatch_without_entrypoints():
58+
# should throw an error, because plot_int was not loaded via entry points
59+
with pytest.raises(NotImplementedError):
60+
plot(5)
61+
with pytest.raises(NotImplementedError):
62+
plot("five")

0 commit comments

Comments
 (0)