Skip to content

Commit 830c2b2

Browse files
committed
move scripts
1 parent 2d12611 commit 830c2b2

File tree

3 files changed

+140
-0
lines changed

3 files changed

+140
-0
lines changed
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
#!/usr/bin/env python3
2+
"""Generate surface plots for the README."""
3+
4+
import numpy as np
5+
import plotly.graph_objects as go
6+
from PIL import Image
7+
import re
8+
9+
from surfaces.test_functions.algebraic import (
10+
AckleyFunction,
11+
HimmelblausFunction,
12+
DropWaveFunction,
13+
EggholderFunction,
14+
CrossInTrayFunction,
15+
RastriginFunction,
16+
)
17+
18+
# Test functions to visualize (tighter bounds for some)
19+
FUNCTIONS = [
20+
("ackley", AckleyFunction(), (-3, 3)),
21+
("himmelblau", HimmelblausFunction(), (-5, 5)),
22+
("drop_wave", DropWaveFunction(), (-3, 3)),
23+
("eggholder", EggholderFunction(), (-512, 512)),
24+
("cross_in_tray", CrossInTrayFunction(), (-5, 5)),
25+
("rastrigin", RastriginFunction(n_dim=2), (-5.12, 5.12)),
26+
]
27+
28+
OUTPUT_DIR = "../source/_static"
29+
30+
31+
def generate_surface_plot(name, func, bounds, resolution=150):
32+
"""Generate a surface plot SVG for a test function."""
33+
min_val, max_val = bounds
34+
35+
# Create mesh grid
36+
x = np.linspace(min_val, max_val, resolution)
37+
y = np.linspace(min_val, max_val, resolution)
38+
X, Y = np.meshgrid(x, y)
39+
40+
# Evaluate function
41+
points = np.column_stack([X.ravel(), Y.ravel()])
42+
Z = func._batch_objective(points).reshape(X.shape)
43+
44+
# Create figure with Jet_r colorscale (red = low/optimal, blue = high)
45+
fig = go.Figure(
46+
data=[
47+
go.Surface(
48+
x=X,
49+
y=Y,
50+
z=Z,
51+
colorscale="Jet_r",
52+
showscale=False,
53+
lighting={
54+
"ambient": 0.6,
55+
"diffuse": 0.8,
56+
"specular": 0.2,
57+
"roughness": 0.5,
58+
},
59+
)
60+
]
61+
)
62+
63+
# Layout
64+
fig.update_layout(
65+
scene={
66+
"xaxis": {"visible": False},
67+
"yaxis": {"visible": False},
68+
"zaxis": {"visible": False},
69+
"bgcolor": "rgba(0,0,0,0)",
70+
"camera": {
71+
"eye": {"x": 1.3, "y": 1.3, "z": 0.6},
72+
},
73+
"aspectmode": "manual",
74+
"aspectratio": {"x": 1, "y": 1, "z": 0.5},
75+
"domain": {"x": [0, 1], "y": [0, 1]},
76+
},
77+
margin={"l": 0, "r": 0, "t": 0, "b": 0, "pad": 0},
78+
paper_bgcolor="rgba(0,0,0,0)",
79+
width=900,
80+
height=600,
81+
autosize=False,
82+
)
83+
84+
# Save temp PNG for bbox calculation
85+
temp_png = f"/tmp/{name}_temp.png"
86+
fig.write_image(temp_png, scale=2)
87+
88+
# Get crop bounds from PNG
89+
img = Image.open(temp_png).convert("RGBA")
90+
bbox = img.getbbox()
91+
92+
if bbox:
93+
padding = 20
94+
left = max(0, bbox[0] - padding)
95+
top = max(0, bbox[1] - padding)
96+
right = min(img.width, bbox[2] + padding)
97+
bottom = min(img.height, bbox[3] + padding)
98+
else:
99+
left, top, right, bottom = 0, 0, img.width, img.height
100+
101+
# Save and crop SVG
102+
temp_svg = f"/tmp/{name}_temp.svg"
103+
fig.write_image(temp_svg)
104+
105+
with open(temp_svg, "r") as f:
106+
svg_content = f.read()
107+
108+
# Calculate crop for SVG (original was 1800x1200 at scale=2)
109+
orig_width, orig_height = 1800, 1200
110+
svg_width, svg_height = 900, 600
111+
112+
new_x = (left / orig_width) * svg_width
113+
new_y = (top / orig_height) * svg_height
114+
new_w = ((right - left) / orig_width) * svg_width
115+
new_h = ((bottom - top) / orig_height) * svg_height
116+
117+
svg_content = re.sub(r'width="[^"]*"', f'width="{new_w:.0f}"', svg_content, count=1)
118+
svg_content = re.sub(
119+
r'height="[^"]*"', f'height="{new_h:.0f}"', svg_content, count=1
120+
)
121+
svg_content = re.sub(
122+
r'viewBox="[^"]*"',
123+
f'viewBox="{new_x:.1f} {new_y:.1f} {new_w:.1f} {new_h:.1f}"',
124+
svg_content,
125+
count=1,
126+
)
127+
128+
output_path = f"{OUTPUT_DIR}/{name}_surface.svg"
129+
with open(output_path, "w") as f:
130+
f.write(svg_content)
131+
132+
print(f"Saved: {output_path}")
133+
134+
135+
if __name__ == "__main__":
136+
for name, func, bounds in FUNCTIONS:
137+
print(f"Generating {name}...")
138+
generate_surface_plot(name, func, bounds)
139+
140+
print("\nDone! Generated 6 SVG files.")

0 commit comments

Comments
 (0)