Skip to content

Commit e71635b

Browse files
committed
fix scipts
1 parent de9bc2e commit e71635b

File tree

4 files changed

+77
-37
lines changed

4 files changed

+77
-37
lines changed

docs/scripts/generate_all_plots.py

Lines changed: 62 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,33 @@
1616
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
1717

1818
from surfaces._visualize import matplotlib_heatmap, matplotlib_surface
19-
from surfaces.test_functions.mathematical.test_functions_2d import *
20-
from surfaces.test_functions.mathematical.test_functions_nd import *
19+
from surfaces.test_functions.algebraic import (
20+
# 2D functions
21+
AckleyFunction,
22+
BealeFunction,
23+
BoothFunction,
24+
BukinFunctionN6,
25+
CrossInTrayFunction,
26+
DropWaveFunction,
27+
EasomFunction,
28+
EggholderFunction,
29+
GoldsteinPriceFunction,
30+
GriewankFunction,
31+
HimmelblausFunction,
32+
HölderTableFunction,
33+
LangermannFunction,
34+
LeviFunctionN13,
35+
MatyasFunction,
36+
McCormickFunction,
37+
# ND functions
38+
RastriginFunction,
39+
RosenbrockFunction,
40+
SchafferFunctionN2,
41+
SimionescuFunction,
42+
SphereFunction,
43+
StyblinskiTangFunction,
44+
ThreeHumpCamelFunction,
45+
)
2146

2247
# Ensure output directories exist
2348
script_dir = os.path.dirname(os.path.abspath(__file__))
@@ -30,120 +55,135 @@
3055
FUNCTION_CONFIGS = {
3156
# 2D Functions
3257
"AckleyFunction": {
58+
"class": AckleyFunction,
3359
"search_space": {"x0": np.arange(-5, 5, 0.1), "x1": np.arange(-5, 5, 0.1)},
3460
"norm": None,
3561
},
3662
"BealeFunction": {
63+
"class": BealeFunction,
3764
"search_space": {"x0": np.arange(-4.5, 4.5, 0.1), "x1": np.arange(-4.5, 4.5, 0.1)},
3865
"norm": "color_log",
3966
},
4067
"BoothFunction": {
68+
"class": BoothFunction,
4169
"search_space": {"x0": np.arange(-10, 10, 0.2), "x1": np.arange(-10, 10, 0.2)},
4270
"norm": "color_log",
4371
},
4472
"BukinFunctionN6": {
73+
"class": BukinFunctionN6,
4574
"search_space": {"x0": np.arange(-15, -5, 0.2), "x1": np.arange(-3, 3, 0.1)},
4675
"norm": "color_log",
4776
},
4877
"CrossInTrayFunction": {
78+
"class": CrossInTrayFunction,
4979
"search_space": {"x0": np.arange(-10, 10, 0.2), "x1": np.arange(-10, 10, 0.2)},
5080
"norm": None,
5181
},
5282
"DropWaveFunction": {
83+
"class": DropWaveFunction,
5384
"search_space": {"x0": np.arange(-5.2, 5.2, 0.1), "x1": np.arange(-5.2, 5.2, 0.1)},
5485
"norm": None,
5586
},
5687
"EasomFunction": {
88+
"class": EasomFunction,
5789
"search_space": {"x0": np.arange(-5, 5, 0.1), "x1": np.arange(-5, 5, 0.1)},
5890
"norm": None,
5991
},
6092
"EggholderFunction": {
93+
"class": EggholderFunction,
6194
"search_space": {"x0": np.arange(-512, 512, 10), "x1": np.arange(-512, 512, 10)},
6295
"norm": None,
6396
},
6497
"GoldsteinPriceFunction": {
98+
"class": GoldsteinPriceFunction,
6599
"search_space": {"x0": np.arange(-2, 2, 0.05), "x1": np.arange(-2, 2, 0.05)},
66100
"norm": "color_log",
67101
},
68102
"HimmelblausFunction": {
103+
"class": HimmelblausFunction,
69104
"search_space": {"x0": np.arange(-5, 5, 0.1), "x1": np.arange(-5, 5, 0.1)},
70105
"norm": "color_log",
71106
},
72107
"HölderTableFunction": {
108+
"class": HölderTableFunction,
73109
"search_space": {"x0": np.arange(-10, 10, 0.2), "x1": np.arange(-10, 10, 0.2)},
74110
"norm": None,
75111
},
76112
"LangermannFunction": {
113+
"class": LangermannFunction,
77114
"search_space": {"x0": np.arange(0, 10, 0.1), "x1": np.arange(0, 10, 0.1)},
78115
"norm": None,
79116
},
80117
"LeviFunctionN13": {
118+
"class": LeviFunctionN13,
81119
"search_space": {"x0": np.arange(-10, 10, 0.2), "x1": np.arange(-10, 10, 0.2)},
82120
"norm": "color_log",
83121
},
84122
"MatyasFunction": {
123+
"class": MatyasFunction,
85124
"search_space": {"x0": np.arange(-10, 10, 0.2), "x1": np.arange(-10, 10, 0.2)},
86125
"norm": None,
87126
},
88127
"McCormickFunction": {
128+
"class": McCormickFunction,
89129
"search_space": {"x0": np.arange(-1.5, 4, 0.1), "x1": np.arange(-3, 4, 0.1)},
90130
"norm": None,
91131
},
92132
"SchafferFunctionN2": {
133+
"class": SchafferFunctionN2,
93134
"search_space": {"x0": np.arange(-100, 100, 2), "x1": np.arange(-100, 100, 2)},
94135
"norm": None,
95136
},
96137
"SimionescuFunction": {
138+
"class": SimionescuFunction,
97139
"search_space": {"x0": np.arange(-1.25, 1.25, 0.05), "x1": np.arange(-1.25, 1.25, 0.05)},
98140
"norm": None,
99141
},
100142
"ThreeHumpCamelFunction": {
143+
"class": ThreeHumpCamelFunction,
101144
"search_space": {"x0": np.arange(-5, 5, 0.1), "x1": np.arange(-5, 5, 0.1)},
102145
"norm": None,
103146
},
104147
# ND Functions (using 2D slices)
105148
"SphereFunction": {
149+
"class": SphereFunction,
106150
"search_space": {"x0": np.arange(-5, 5, 0.1), "x1": np.arange(-5, 5, 0.1)},
107151
"norm": None,
152+
"is_nd": True,
108153
},
109154
"RastriginFunction": {
155+
"class": RastriginFunction,
110156
"search_space": {"x0": np.arange(-5, 5, 0.1), "x1": np.arange(-5, 5, 0.1)},
111157
"norm": None,
158+
"is_nd": True,
112159
},
113160
"RosenbrockFunction": {
161+
"class": RosenbrockFunction,
114162
"search_space": {"x0": np.arange(-5, 5, 0.1), "x1": np.arange(-5, 5, 0.1)},
115163
"norm": "color_log",
164+
"is_nd": True,
116165
},
117166
"StyblinskiTangFunction": {
167+
"class": StyblinskiTangFunction,
118168
"search_space": {"x0": np.arange(-5, 5, 0.1), "x1": np.arange(-5, 5, 0.1)},
119169
"norm": None,
170+
"is_nd": True,
120171
},
121172
"GriewankFunction": {
173+
"class": GriewankFunction,
122174
"search_space": {"x0": np.arange(-10, 10, 0.2), "x1": np.arange(-10, 10, 0.2)},
123175
"norm": None,
176+
"is_nd": True,
124177
},
125178
}
126179

127180

128-
def get_function_instance(class_name):
129-
"""Get an instance of the function class by name."""
130-
try:
131-
func_class = globals()[class_name]
132-
if class_name in [
133-
"SphereFunction",
134-
"RastriginFunction",
135-
"RosenbrockFunction",
136-
"StyblinskiTangFunction",
137-
"GriewankFunction",
138-
]:
139-
# ND functions need n_dim parameter
140-
return func_class(n_dim=2, metric="loss")
141-
else:
142-
# 2D functions
143-
return func_class(metric="loss")
144-
except Exception as e:
145-
print(f"Error creating {class_name}: {e}")
146-
return None
181+
def get_function_instance(config):
182+
"""Get an instance of the function class from config."""
183+
func_class = config["class"]
184+
if config.get("is_nd", False):
185+
return func_class(n_dim=2, metric="loss")
186+
return func_class(metric="loss")
147187

148188

149189
def generate_plots():
@@ -155,9 +195,7 @@ def generate_plots():
155195

156196
try:
157197
# Get function instance
158-
func_instance = get_function_instance(func_name)
159-
if func_instance is None:
160-
continue
198+
func_instance = get_function_instance(config)
161199

162200
search_space = config["search_space"]
163201
norm = config["norm"]

docs/scripts/generate_ml_plots.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,17 @@
1717
# Add src to path to import surfaces
1818
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
1919

20-
from surfaces._visualize import (
20+
from surfaces._visualize import ( # noqa: E402
2121
plotly_dataset_hyperparameter_analysis,
2222
plotly_ml_hyperparameter_heatmap,
2323
)
24-
from surfaces.test_functions.machine_learning.tabular.classification.test_functions import *
25-
from surfaces.test_functions.machine_learning.tabular.regression.test_functions import *
24+
from surfaces.test_functions.machine_learning.tabular.classification.test_functions import ( # noqa: E402
25+
KNeighborsClassifierFunction,
26+
)
27+
from surfaces.test_functions.machine_learning.tabular.regression.test_functions import ( # noqa: E402
28+
GradientBoostingRegressorFunction,
29+
KNeighborsRegressorFunction,
30+
)
2631

2732
# Ensure output directories exist
2833
script_dir = os.path.dirname(os.path.abspath(__file__))

docs/scripts/generate_surface_plot.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
#!/usr/bin/env python3
22
"""Generate surface plots for the README."""
33

4+
import re
5+
46
import numpy as np
57
import plotly.graph_objects as go
68
from PIL import Image
7-
import re
89

910
from surfaces.test_functions.algebraic import (
1011
AckleyFunction,
11-
HimmelblausFunction,
12+
CrossInTrayFunction,
1213
DropWaveFunction,
1314
EggholderFunction,
14-
CrossInTrayFunction,
15+
HimmelblausFunction,
1516
RastriginFunction,
1617
)
1718

@@ -115,9 +116,7 @@ def generate_surface_plot(name, func, bounds, resolution=150):
115116
new_h = ((bottom - top) / orig_height) * svg_height
116117

117118
svg_content = re.sub(r'width="[^"]*"', f'width="{new_w:.0f}"', svg_content, count=1)
118-
svg_content = re.sub(
119-
r'height="[^"]*"', f'height="{new_h:.0f}"', svg_content, count=1
120-
)
119+
svg_content = re.sub(r'height="[^"]*"', f'height="{new_h:.0f}"', svg_content, count=1)
121120
svg_content = re.sub(
122121
r'viewBox="[^"]*"',
123122
f'viewBox="{new_x:.1f} {new_y:.1f} {new_w:.1f} {new_h:.1f}"',

tests/integration/test_readme_examples.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def extract_python_blocks(readme_path: Path) -> list[tuple[str, str]]:
1818
content = readme_path.read_text()
1919

2020
# Pattern: <summary><b>Name</b></summary> or ### Name followed by ```python
21-
pattern = r'(?:<summary><b>([^<]+)</b></summary>|### ([^\n]+))\s*\n+```python\n(.*?)```'
21+
pattern = r"(?:<summary><b>([^<]+)</b></summary>|### ([^\n]+))\s*\n+```python\n(.*?)```"
2222

2323
examples = []
2424
for match in re.finditer(pattern, content, re.DOTALL):
@@ -58,9 +58,7 @@ def test_readme_example(name: str, code: str):
5858
)
5959

6060
assert result.returncode == 0, (
61-
f"Example '{name}' failed:\n"
62-
f"stdout: {result.stdout}\n"
63-
f"stderr: {result.stderr}"
61+
f"Example '{name}' failed:\n" f"stdout: {result.stdout}\n" f"stderr: {result.stderr}"
6462
)
6563
finally:
6664
temp_path.unlink()

0 commit comments

Comments
 (0)