Skip to content

Commit 18671d9

Browse files
committed
add dashboard for surrogate training and overview
1 parent da0e2f4 commit 18671d9

File tree

14 files changed

+2298
-0
lines changed

14 files changed

+2298
-0
lines changed

pyproject.toml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,16 @@ surrogate-train = [
8181
"onnxruntime>=1.16.0",
8282
"skl2onnx>=1.16.0",
8383
]
84+
# Surrogate dashboard (developer tool)
85+
dashboard = [
86+
"surfaces[ml,surrogate-train]",
87+
"streamlit>=1.28.0",
88+
"plotly>=5.0.0",
89+
"pandas>=1.3.0",
90+
]
91+
92+
[project.scripts]
93+
surfaces-dashboard = "surfaces._surrogates._dashboard:run_dashboard"
8494

8595
[project.urls]
8696
"Homepage" = "https://github.com/SimonBlanke/Surfaces"
@@ -125,6 +135,8 @@ ignore = [
125135
"src/surfaces/test_functions/cec/*" = ["F841", "E741"]
126136
# Engineering functions use standard notation (l for length)
127137
"src/surfaces/test_functions/engineering/*" = ["E741", "F841"]
138+
# Dashboard uses Streamlit patterns
139+
"src/surfaces/_surrogates/_dashboard/*" = ["F841"]
128140

129141
[tool.ruff.lint.isort]
130142
known-first-party = ["surfaces"]
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Author: Simon Blanke
2+
3+
# License: MIT License
4+
5+
"""
6+
Surrogate Dashboard - Management UI for ML test function surrogates.
7+
8+
This module provides a Streamlit-based dashboard for:
9+
- Viewing all ML test functions and their surrogate status
10+
- Training new or retraining existing surrogates
11+
- Validating surrogate accuracy
12+
- Tracking historical metrics
13+
14+
Usage:
15+
# Via module
16+
python -m surfaces._surrogates._dashboard
17+
18+
# Via CLI (after installing with dashboard extras)
19+
surfaces-dashboard
20+
21+
# Via Python
22+
from surfaces._surrogates._dashboard import run_dashboard
23+
run_dashboard()
24+
25+
Requirements:
26+
pip install surfaces[dashboard]
27+
"""
28+
29+
from pathlib import Path
30+
31+
32+
def run_dashboard():
33+
"""Launch the Streamlit dashboard."""
34+
import subprocess
35+
import sys
36+
37+
app_path = Path(__file__).parent / "app.py"
38+
39+
# Check if streamlit is installed
40+
try:
41+
import streamlit # noqa: F401
42+
except ImportError:
43+
print("Streamlit is required for the dashboard.")
44+
print("Install it with: pip install surfaces[dashboard]")
45+
sys.exit(1)
46+
47+
subprocess.run([sys.executable, "-m", "streamlit", "run", str(app_path), "--server.headless", "true"])
48+
49+
50+
__all__ = ["run_dashboard"]
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Author: Simon Blanke
2+
3+
# License: MIT License
4+
5+
"""
6+
Module entry point for the dashboard.
7+
8+
Usage:
9+
python -m surfaces._surrogates._dashboard
10+
"""
11+
12+
from . import run_dashboard
13+
14+
if __name__ == "__main__":
15+
run_dashboard()
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Author: Simon Blanke
2+
3+
# License: MIT License
4+
5+
"""Dashboard pages package."""
6+
7+
from . import details, overview, training, validation
8+
9+
__all__ = ["overview", "training", "validation", "details"]
Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
# Author: Simon Blanke
2+
3+
# License: MIT License
4+
5+
"""
6+
Details page - Deep dive into a single function.
7+
"""
8+
9+
10+
import pandas as pd
11+
import streamlit as st
12+
13+
from surfaces._surrogates._dashboard.database import (
14+
get_all_surrogates,
15+
get_surrogate,
16+
get_training_jobs,
17+
get_validation_runs,
18+
)
19+
20+
21+
def render():
22+
"""Render the details page."""
23+
st.header("Function Details")
24+
25+
# Get all functions
26+
surrogates = get_all_surrogates()
27+
all_names = [s["function_name"] for s in surrogates]
28+
29+
if not all_names:
30+
st.warning("No functions found. Run sync to populate the database.")
31+
return
32+
33+
# Function selector
34+
selected = st.selectbox(
35+
"Select Function",
36+
all_names,
37+
index=0,
38+
key="details_function_select",
39+
)
40+
41+
if not selected:
42+
return
43+
44+
# Get detailed info
45+
surrogate = get_surrogate(selected)
46+
47+
if not surrogate:
48+
st.error(f"Function {selected} not found in database.")
49+
return
50+
51+
st.divider()
52+
53+
# Status banner
54+
if surrogate["has_surrogate"]:
55+
st.success(f"Surrogate model available for {selected}")
56+
else:
57+
st.warning(f"No surrogate model for {selected}")
58+
59+
# Tabs for different info
60+
tab1, tab2, tab3, tab4 = st.tabs([
61+
"Metadata",
62+
"Parameters",
63+
"Training History",
64+
"Validation History",
65+
])
66+
67+
with tab1:
68+
render_metadata(surrogate)
69+
70+
with tab2:
71+
render_parameters(surrogate)
72+
73+
with tab3:
74+
render_training_history(selected)
75+
76+
with tab4:
77+
render_validation_history(selected)
78+
79+
80+
def render_metadata(surrogate: dict):
81+
"""Render metadata section."""
82+
st.subheader("Surrogate Metadata")
83+
84+
col1, col2 = st.columns(2)
85+
86+
with col1:
87+
st.write("**Basic Info**")
88+
st.write(f"- Function Name: `{surrogate['function_name']}`")
89+
st.write(f"- Type: {surrogate['function_type']}")
90+
st.write(f"- Has Surrogate: {'Yes' if surrogate['has_surrogate'] else 'No'}")
91+
st.write(f"- Has Validity Model: {'Yes' if surrogate['has_validity_model'] else 'No'}")
92+
93+
with col2:
94+
st.write("**Training Info**")
95+
if surrogate["has_surrogate"]:
96+
st.write(f"- Training Samples: {surrogate['n_samples'] or 'N/A'}")
97+
st.write(f"- Invalid Samples: {surrogate['n_invalid_samples'] or 0}")
98+
st.write(f"- Training R2: {surrogate['training_r2']:.4f}" if surrogate['training_r2'] else "- Training R2: N/A")
99+
st.write(f"- Training MSE: {surrogate['training_mse']:.6f}" if surrogate['training_mse'] else "- Training MSE: N/A")
100+
st.write(f"- Training Time: {surrogate['training_time_sec']:.1f}s" if surrogate['training_time_sec'] else "- Training Time: N/A")
101+
else:
102+
st.write("No training data available.")
103+
104+
st.divider()
105+
106+
col1, col2 = st.columns(2)
107+
108+
with col1:
109+
st.write("**Value Range**")
110+
if surrogate["y_range_min"] is not None:
111+
st.write(f"- Min: {surrogate['y_range_min']:.4f}")
112+
st.write(f"- Max: {surrogate['y_range_max']:.4f}")
113+
else:
114+
st.write("No range data available.")
115+
116+
with col2:
117+
st.write("**Tracking Info**")
118+
st.write(f"- Last Synced: {surrogate['last_synced_at'][:19] if surrogate['last_synced_at'] else 'Never'}")
119+
st.write(f"- Created: {surrogate['created_at'][:19] if surrogate['created_at'] else 'Unknown'}")
120+
if surrogate["onnx_file_hash"]:
121+
st.write(f"- File Hash: `{surrogate['onnx_file_hash'][:16]}...`")
122+
123+
124+
def render_parameters(surrogate: dict):
125+
"""Render parameters section."""
126+
st.subheader("Model Parameters")
127+
128+
param_names = surrogate.get("param_names", [])
129+
param_encodings = surrogate.get("param_encodings", {})
130+
131+
if not param_names:
132+
st.info("No parameter information available.")
133+
return
134+
135+
st.write(f"**Parameter Count:** {len(param_names)}")
136+
137+
# Parameter table
138+
param_data = []
139+
for name in param_names:
140+
encoding = param_encodings.get(name)
141+
if encoding:
142+
param_type = "categorical"
143+
values = list(encoding.keys())
144+
values_str = ", ".join(values[:5])
145+
if len(values) > 5:
146+
values_str += f" (+{len(values) - 5} more)"
147+
else:
148+
param_type = "numeric"
149+
values_str = "-"
150+
151+
param_data.append({
152+
"Parameter": name,
153+
"Type": param_type,
154+
"Values": values_str,
155+
})
156+
157+
df = pd.DataFrame(param_data)
158+
st.dataframe(df, use_container_width=True, hide_index=True)
159+
160+
# Raw encodings
161+
if param_encodings:
162+
with st.expander("View Raw Encodings"):
163+
st.json(param_encodings)
164+
165+
166+
def render_training_history(function_name: str):
167+
"""Render training history section."""
168+
st.subheader("Training History")
169+
170+
jobs = get_training_jobs(function_name=function_name, limit=20)
171+
172+
if not jobs:
173+
st.info("No training history for this function.")
174+
return
175+
176+
df_data = []
177+
for job in jobs:
178+
duration = "-"
179+
if job["started_at"] and job["completed_at"]:
180+
# Simple duration calculation
181+
try:
182+
from datetime import datetime
183+
start = datetime.fromisoformat(job["started_at"])
184+
end = datetime.fromisoformat(job["completed_at"])
185+
dur_sec = (end - start).total_seconds()
186+
duration = f"{dur_sec:.1f}s"
187+
except Exception:
188+
pass
189+
190+
df_data.append({
191+
"Started": job["started_at"][:19] if job["started_at"] else "-",
192+
"Duration": duration,
193+
"Status": job["status"],
194+
"Triggered By": job["triggered_by"],
195+
"Error": job["error_message"][:30] + "..." if job["error_message"] and len(job["error_message"]) > 30 else (job["error_message"] or "-"),
196+
})
197+
198+
df = pd.DataFrame(df_data)
199+
200+
def highlight_status(val):
201+
if val == "completed":
202+
return "color: #28a745"
203+
elif val == "failed":
204+
return "color: #dc3545"
205+
elif val == "running":
206+
return "color: #17a2b8"
207+
return ""
208+
209+
styled_df = df.style.applymap(highlight_status, subset=["Status"])
210+
st.dataframe(styled_df, use_container_width=True, hide_index=True)
211+
212+
213+
def render_validation_history(function_name: str):
214+
"""Render validation history section."""
215+
st.subheader("Validation History")
216+
217+
runs = get_validation_runs(function_name=function_name, limit=20)
218+
219+
if not runs:
220+
st.info("No validation history for this function.")
221+
return
222+
223+
# Summary chart
224+
if len(runs) > 1:
225+
st.write("**R2 Score Trend**")
226+
227+
# Prepare data for chart
228+
chart_data = []
229+
for run in reversed(runs): # Oldest first for chart
230+
if run["r2_score"] is not None:
231+
chart_data.append({
232+
"Date": run["validated_at"][:10] if run["validated_at"] else "",
233+
"R2": run["r2_score"],
234+
})
235+
236+
if chart_data:
237+
chart_df = pd.DataFrame(chart_data)
238+
st.line_chart(chart_df.set_index("Date"))
239+
240+
# Table
241+
st.write("**Validation Runs**")
242+
243+
df_data = []
244+
for run in runs:
245+
df_data.append({
246+
"Date": run["validated_at"][:19] if run["validated_at"] else "-",
247+
"Type": run["validation_type"],
248+
"Samples": str(run["n_samples"]) if run["n_samples"] else "-",
249+
"R2": f"{run['r2_score']:.4f}" if run["r2_score"] else "-",
250+
"MAE": f"{run['mae']:.4f}" if run["mae"] else "-",
251+
"RMSE": f"{run['rmse']:.4f}" if run["rmse"] else "-",
252+
"Max Error": f"{run['max_error']:.4f}" if run["max_error"] else "-",
253+
"Speedup": f"{run['speedup_factor']:.0f}x" if run["speedup_factor"] else "-",
254+
})
255+
256+
df = pd.DataFrame(df_data)
257+
258+
def highlight_r2(val):
259+
try:
260+
r2 = float(val)
261+
if r2 >= 0.95:
262+
return "color: #28a745"
263+
elif r2 >= 0.90:
264+
return "color: #ffc107"
265+
else:
266+
return "color: #dc3545"
267+
except (ValueError, TypeError):
268+
return ""
269+
270+
styled_df = df.style.applymap(highlight_r2, subset=["R2"])
271+
st.dataframe(styled_df, use_container_width=True, hide_index=True)

0 commit comments

Comments
 (0)