Skip to content

Commit 182a255

Browse files
committed
refactor vizualize module into separate files
1 parent 8b62756 commit 182a255

File tree

10 files changed

+1004
-906
lines changed

10 files changed

+1004
-906
lines changed

src/surfaces/visualize/__init__.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,13 @@
5959
PlotCompatibilityError,
6060
VisualizationError,
6161
)
62-
from ._plots import (
63-
auto_plot,
64-
plot_contour,
65-
plot_convergence,
66-
plot_fitness_distribution,
67-
plot_latex,
68-
plot_multi_slice,
69-
plot_surface,
70-
)
62+
from ._auto import auto_plot
63+
from ._contour import plot_contour
64+
from ._convergence import plot_convergence
65+
from ._distribution import plot_fitness_distribution
66+
from ._latex import plot_latex
67+
from ._slices import plot_multi_slice
68+
from ._surface import plot_surface
7169

7270
__all__ = [
7371
# Discovery functions

src/surfaces/visualize/_auto.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Author: Simon Blanke
2+
3+
# License: MIT License
4+
5+
"""Auto-selection of best visualization for a test function."""
6+
7+
from __future__ import annotations
8+
9+
from typing import TYPE_CHECKING, Dict, List, Optional, Union
10+
11+
if TYPE_CHECKING:
12+
import plotly.graph_objects as go
13+
14+
from ..test_functions._base_test_function import BaseTestFunction
15+
16+
from ._compatibility import _get_function_dimensions
17+
from ._convergence import plot_convergence
18+
from ._slices import plot_multi_slice
19+
from ._surface import plot_surface
20+
from ._utils import check_plotly
21+
22+
23+
def auto_plot(
24+
func: "BaseTestFunction",
25+
history: Optional[Union[List[float], Dict[str, List[float]]]] = None,
26+
resolution: int = 50,
27+
**kwargs,
28+
) -> "go.Figure":
29+
"""Automatically select and create the best visualization for a function.
30+
31+
Selection logic:
32+
- 2D functions: surface plot
33+
- N-D functions (N > 2): multi_slice plot
34+
- 1D functions: multi_slice plot (single panel)
35+
- If history provided: convergence plot
36+
37+
Args:
38+
func: A test function of any dimension.
39+
history: Optional optimization history. If provided, creates convergence plot.
40+
resolution: Resolution for grid-based plots.
41+
**kwargs: Additional arguments passed to the selected plot function.
42+
43+
Returns:
44+
Plotly Figure object.
45+
46+
Examples:
47+
>>> from surfaces.test_functions import SphereFunction
48+
>>> from surfaces.visualize import auto_plot
49+
>>> func = SphereFunction(n_dim=2)
50+
>>> fig = auto_plot(func) # Returns surface plot
51+
>>> fig.show()
52+
>>> func5d = SphereFunction(n_dim=5)
53+
>>> fig = auto_plot(func5d) # Returns multi_slice plot
54+
>>> fig.show()
55+
"""
56+
check_plotly()
57+
58+
# If history is provided, show convergence
59+
if history is not None:
60+
return plot_convergence(func, history, **kwargs)
61+
62+
n_dim = _get_function_dimensions(func)
63+
64+
# Select best plot for this function
65+
if n_dim == 2:
66+
return plot_surface(func, resolution=resolution, **kwargs)
67+
else:
68+
return plot_multi_slice(func, resolution=resolution, **kwargs)

src/surfaces/visualize/_contour.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Author: Simon Blanke
2+
3+
# License: MIT License
4+
5+
"""2D contour plot for 2D test functions."""
6+
7+
from __future__ import annotations
8+
9+
from typing import TYPE_CHECKING, Dict, Optional, Tuple
10+
11+
if TYPE_CHECKING:
12+
import plotly.graph_objects as go
13+
14+
from ..test_functions._base_test_function import BaseTestFunction
15+
16+
from ._utils import (
17+
DEFAULT_COLORSCALE,
18+
check_plotly,
19+
create_search_space_grid,
20+
evaluate_grid_2d,
21+
go,
22+
validate_plot,
23+
)
24+
25+
26+
def plot_contour(
27+
func: "BaseTestFunction",
28+
resolution: int = 50,
29+
bounds: Optional[Dict[str, Tuple[float, float]]] = None,
30+
title: Optional[str] = None,
31+
width: int = 700,
32+
height: int = 600,
33+
colorscale: Optional[str] = None,
34+
n_contours: int = 20,
35+
show_labels: bool = True,
36+
) -> "go.Figure":
37+
"""Create a 2D contour plot of a 2D objective function.
38+
39+
Args:
40+
func: A 2-dimensional test function.
41+
resolution: Number of points per dimension (default: 50).
42+
bounds: Optional custom bounds as {'x0': (min, max), 'x1': (min, max)}.
43+
title: Plot title. Defaults to function name.
44+
width: Plot width in pixels.
45+
height: Plot height in pixels.
46+
colorscale: Plotly colorscale name.
47+
n_contours: Number of contour levels.
48+
show_labels: Whether to show contour value labels.
49+
50+
Returns:
51+
Plotly Figure object.
52+
53+
Raises:
54+
PlotCompatibilityError: If function is not 2-dimensional.
55+
56+
Examples:
57+
>>> from surfaces.test_functions import RosenbrockFunction
58+
>>> from surfaces.visualize import plot_contour
59+
>>> func = RosenbrockFunction(n_dim=2)
60+
>>> fig = plot_contour(func)
61+
>>> fig.show()
62+
"""
63+
check_plotly()
64+
validate_plot(func, "contour")
65+
66+
# Create grid
67+
search_space = create_search_space_grid(func, resolution, bounds)
68+
param_names = list(search_space.keys())[:2]
69+
x_name, y_name = param_names[0], param_names[1]
70+
71+
x_values = search_space[x_name]
72+
y_values = search_space[y_name]
73+
74+
# Evaluate
75+
z_values = evaluate_grid_2d(func, x_values, y_values, x_name, y_name)
76+
77+
# Create figure
78+
fig = go.Figure(
79+
data=go.Contour(
80+
x=x_values,
81+
y=y_values,
82+
z=z_values,
83+
colorscale=colorscale or DEFAULT_COLORSCALE,
84+
contours=dict(
85+
showlabels=show_labels,
86+
labelfont=dict(size=10, color="white"),
87+
),
88+
ncontours=n_contours,
89+
)
90+
)
91+
92+
func_name = getattr(func, "name", type(func).__name__)
93+
fig.update_layout(
94+
title=title or f"{func_name} - Contour Plot",
95+
xaxis_title=x_name,
96+
yaxis_title=y_name,
97+
width=width,
98+
height=height,
99+
)
100+
101+
return fig
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# Author: Simon Blanke
2+
3+
# License: MIT License
4+
5+
"""Convergence plot for optimization history."""
6+
7+
from __future__ import annotations
8+
9+
from typing import TYPE_CHECKING, Dict, List, Optional, Union
10+
11+
import numpy as np
12+
13+
if TYPE_CHECKING:
14+
import plotly.graph_objects as go
15+
16+
from ..test_functions._base_test_function import BaseTestFunction
17+
18+
from ._errors import MissingDataError
19+
from ._utils import check_plotly, go
20+
21+
22+
def plot_convergence(
23+
func: "BaseTestFunction",
24+
history: Union[List[float], Dict[str, List[float]], np.ndarray],
25+
title: Optional[str] = None,
26+
width: int = 800,
27+
height: int = 500,
28+
log_scale: bool = False,
29+
show_best: bool = True,
30+
) -> "go.Figure":
31+
"""Create a convergence plot showing optimization progress.
32+
33+
Args:
34+
func: The test function (used for title and context).
35+
history: Objective values per evaluation. Can be:
36+
- List of values from a single run
37+
- Dict mapping run names to lists of values
38+
- 2D array where each row is a run
39+
title: Plot title. Defaults to function name.
40+
width: Plot width in pixels.
41+
height: Plot height in pixels.
42+
log_scale: Whether to use log scale for y-axis.
43+
show_best: Whether to show best-so-far instead of raw values.
44+
45+
Returns:
46+
Plotly Figure object.
47+
48+
Raises:
49+
MissingDataError: If history is empty.
50+
51+
Examples:
52+
>>> from surfaces.test_functions import SphereFunction
53+
>>> from surfaces.visualize import plot_convergence
54+
>>> func = SphereFunction(n_dim=2)
55+
>>> # Simulated optimization history
56+
>>> history = [10.0, 8.0, 5.0, 3.0, 1.0, 0.5, 0.1]
57+
>>> fig = plot_convergence(func, history)
58+
>>> fig.show()
59+
"""
60+
check_plotly()
61+
62+
if history is None or (hasattr(history, "__len__") and len(history) == 0):
63+
raise MissingDataError("convergence", "optimization history")
64+
65+
# Normalize history to dict format
66+
if isinstance(history, (list, np.ndarray)):
67+
if isinstance(history, np.ndarray) and history.ndim == 2:
68+
# Multiple runs as 2D array
69+
history = {f"Run {i+1}": list(row) for i, row in enumerate(history)}
70+
else:
71+
# Single run
72+
history = {"Optimization": list(history)}
73+
74+
fig = go.Figure()
75+
76+
for run_name, values in history.items():
77+
if show_best:
78+
# Cumulative minimum (best so far)
79+
best_so_far = np.minimum.accumulate(values)
80+
y_values = best_so_far
81+
y_label = "Best Objective Value"
82+
else:
83+
y_values = values
84+
y_label = "Objective Value"
85+
86+
evaluations = list(range(1, len(values) + 1))
87+
88+
fig.add_trace(
89+
go.Scatter(
90+
x=evaluations,
91+
y=y_values,
92+
mode="lines",
93+
name=run_name,
94+
line=dict(width=2),
95+
)
96+
)
97+
98+
func_name = getattr(func, "name", type(func).__name__)
99+
fig.update_layout(
100+
title=title or f"{func_name} - Convergence",
101+
xaxis_title="Evaluation",
102+
yaxis_title=y_label,
103+
width=width,
104+
height=height,
105+
)
106+
107+
if log_scale:
108+
fig.update_yaxes(type="log")
109+
110+
return fig

0 commit comments

Comments
 (0)