Skip to content

Commit ead00c4

Browse files
committed
pre-commit
1 parent 8a3650c commit ead00c4

File tree

100 files changed

+1250
-695
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

100 files changed

+1250
-695
lines changed

src/surfaces/_surrogates/__init__.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@
1717
- list_ml_surrogates: List registered functions and status
1818
"""
1919

20+
from ._ml_surrogate_trainer import (
21+
MLSurrogateTrainer,
22+
list_ml_surrogates,
23+
train_all_ml_surrogates,
24+
train_missing_ml_surrogates,
25+
train_ml_surrogate,
26+
)
2027
from ._surrogate_loader import (
2128
SurrogateLoader,
2229
get_surrogate_path,
@@ -29,13 +36,6 @@
2936
from ._surrogate_validator import (
3037
SurrogateValidator,
3138
)
32-
from ._ml_surrogate_trainer import (
33-
MLSurrogateTrainer,
34-
train_ml_surrogate,
35-
train_all_ml_surrogates,
36-
train_missing_ml_surrogates,
37-
list_ml_surrogates,
38-
)
3939

4040
__all__ = [
4141
# Loader

src/surfaces/_surrogates/_dashboard/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@ def run_dashboard():
4444
print("Install it with: pip install surfaces[dashboard]")
4545
sys.exit(1)
4646

47-
subprocess.run([sys.executable, "-m", "streamlit", "run", str(app_path), "--server.headless", "true"])
47+
subprocess.run(
48+
[sys.executable, "-m", "streamlit", "run", str(app_path), "--server.headless", "true"]
49+
)
4850

4951

5052
__all__ = ["run_dashboard"]

src/surfaces/_surrogates/_dashboard/_pages/details.py

Lines changed: 66 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
Details page - Deep dive into a single function.
77
"""
88

9-
109
import pandas as pd
1110
import streamlit as st
1211

@@ -57,12 +56,14 @@ def render():
5756
st.warning(f"No surrogate model for {selected}")
5857

5958
# Tabs for different info
60-
tab1, tab2, tab3, tab4 = st.tabs([
61-
"Metadata",
62-
"Parameters",
63-
"Training History",
64-
"Validation History",
65-
])
59+
tab1, tab2, tab3, tab4 = st.tabs(
60+
[
61+
"Metadata",
62+
"Parameters",
63+
"Training History",
64+
"Validation History",
65+
]
66+
)
6667

6768
with tab1:
6869
render_metadata(surrogate)
@@ -95,9 +96,21 @@ def render_metadata(surrogate: dict):
9596
if surrogate["has_surrogate"]:
9697
st.write(f"- Training Samples: {surrogate['n_samples'] or 'N/A'}")
9798
st.write(f"- Invalid Samples: {surrogate['n_invalid_samples'] or 0}")
98-
st.write(f"- Training R2: {surrogate['training_r2']:.4f}" if surrogate['training_r2'] else "- Training R2: N/A")
99-
st.write(f"- Training MSE: {surrogate['training_mse']:.6f}" if surrogate['training_mse'] else "- Training MSE: N/A")
100-
st.write(f"- Training Time: {surrogate['training_time_sec']:.1f}s" if surrogate['training_time_sec'] else "- Training Time: N/A")
99+
st.write(
100+
f"- Training R2: {surrogate['training_r2']:.4f}"
101+
if surrogate["training_r2"]
102+
else "- Training R2: N/A"
103+
)
104+
st.write(
105+
f"- Training MSE: {surrogate['training_mse']:.6f}"
106+
if surrogate["training_mse"]
107+
else "- Training MSE: N/A"
108+
)
109+
st.write(
110+
f"- Training Time: {surrogate['training_time_sec']:.1f}s"
111+
if surrogate["training_time_sec"]
112+
else "- Training Time: N/A"
113+
)
101114
else:
102115
st.write("No training data available.")
103116

@@ -115,8 +128,12 @@ def render_metadata(surrogate: dict):
115128

116129
with col2:
117130
st.write("**Tracking Info**")
118-
st.write(f"- Last Synced: {surrogate['last_synced_at'][:19] if surrogate['last_synced_at'] else 'Never'}")
119-
st.write(f"- Created: {surrogate['created_at'][:19] if surrogate['created_at'] else 'Unknown'}")
131+
st.write(
132+
f"- Last Synced: {surrogate['last_synced_at'][:19] if surrogate['last_synced_at'] else 'Never'}"
133+
)
134+
st.write(
135+
f"- Created: {surrogate['created_at'][:19] if surrogate['created_at'] else 'Unknown'}"
136+
)
120137
if surrogate["onnx_file_hash"]:
121138
st.write(f"- File Hash: `{surrogate['onnx_file_hash'][:16]}...`")
122139

@@ -148,11 +165,13 @@ def render_parameters(surrogate: dict):
148165
param_type = "numeric"
149166
values_str = "-"
150167

151-
param_data.append({
152-
"Parameter": name,
153-
"Type": param_type,
154-
"Values": values_str,
155-
})
168+
param_data.append(
169+
{
170+
"Parameter": name,
171+
"Type": param_type,
172+
"Values": values_str,
173+
}
174+
)
156175

157176
df = pd.DataFrame(param_data)
158177
st.dataframe(df, use_container_width=True, hide_index=True)
@@ -180,20 +199,25 @@ def render_training_history(function_name: str):
180199
# Simple duration calculation
181200
try:
182201
from datetime import datetime
202+
183203
start = datetime.fromisoformat(job["started_at"])
184204
end = datetime.fromisoformat(job["completed_at"])
185205
dur_sec = (end - start).total_seconds()
186206
duration = f"{dur_sec:.1f}s"
187207
except Exception:
188208
pass
189209

190-
df_data.append({
191-
"Started": job["started_at"][:19] if job["started_at"] else "-",
192-
"Duration": duration,
193-
"Status": job["status"],
194-
"Triggered By": job["triggered_by"],
195-
"Error": job["error_message"][:30] + "..." if job["error_message"] and len(job["error_message"]) > 30 else (job["error_message"] or "-"),
196-
})
210+
df_data.append(
211+
{
212+
"Started": job["started_at"][:19] if job["started_at"] else "-",
213+
"Duration": duration,
214+
"Status": job["status"],
215+
"Triggered By": job["triggered_by"],
216+
"Error": job["error_message"][:30] + "..."
217+
if job["error_message"] and len(job["error_message"]) > 30
218+
else (job["error_message"] or "-"),
219+
}
220+
)
197221

198222
df = pd.DataFrame(df_data)
199223

@@ -228,10 +252,12 @@ def render_validation_history(function_name: str):
228252
chart_data = []
229253
for run in reversed(runs): # Oldest first for chart
230254
if run["r2_score"] is not None:
231-
chart_data.append({
232-
"Date": run["validated_at"][:10] if run["validated_at"] else "",
233-
"R2": run["r2_score"],
234-
})
255+
chart_data.append(
256+
{
257+
"Date": run["validated_at"][:10] if run["validated_at"] else "",
258+
"R2": run["r2_score"],
259+
}
260+
)
235261

236262
if chart_data:
237263
chart_df = pd.DataFrame(chart_data)
@@ -242,16 +268,18 @@ def render_validation_history(function_name: str):
242268

243269
df_data = []
244270
for run in runs:
245-
df_data.append({
246-
"Date": run["validated_at"][:19] if run["validated_at"] else "-",
247-
"Type": run["validation_type"],
248-
"Samples": str(run["n_samples"]) if run["n_samples"] else "-",
249-
"R2": f"{run['r2_score']:.4f}" if run["r2_score"] else "-",
250-
"MAE": f"{run['mae']:.4f}" if run["mae"] else "-",
251-
"RMSE": f"{run['rmse']:.4f}" if run["rmse"] else "-",
252-
"Max Error": f"{run['max_error']:.4f}" if run["max_error"] else "-",
253-
"Speedup": f"{run['speedup_factor']:.0f}x" if run["speedup_factor"] else "-",
254-
})
271+
df_data.append(
272+
{
273+
"Date": run["validated_at"][:19] if run["validated_at"] else "-",
274+
"Type": run["validation_type"],
275+
"Samples": str(run["n_samples"]) if run["n_samples"] else "-",
276+
"R2": f"{run['r2_score']:.4f}" if run["r2_score"] else "-",
277+
"MAE": f"{run['mae']:.4f}" if run["mae"] else "-",
278+
"RMSE": f"{run['rmse']:.4f}" if run["rmse"] else "-",
279+
"Max Error": f"{run['max_error']:.4f}" if run["max_error"] else "-",
280+
"Speedup": f"{run['speedup_factor']:.0f}x" if run["speedup_factor"] else "-",
281+
}
282+
)
255283

256284
df = pd.DataFrame(df_data)
257285

src/surfaces/_surrogates/_dashboard/_pages/overview.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,13 @@ def render():
9191
"Function Name": lambda x: x["function_name"],
9292
"R2 Score": lambda x: x["latest_r2"] or 0,
9393
"Samples": lambda x: x["n_samples"] or 0,
94-
"Status": lambda x: ["Good", "Needs Attention", "Not Validated", "Missing"].index(x["status"]),
94+
"Status": lambda x: ["Good", "Needs Attention", "Not Validated", "Missing"].index(
95+
x["status"]
96+
),
9597
}
96-
filtered = sorted(filtered, key=sort_key_map[sort_by], reverse=(sort_by == "R2 Score" or sort_by == "Samples"))
98+
filtered = sorted(
99+
filtered, key=sort_key_map[sort_by], reverse=(sort_by == "R2 Score" or sort_by == "Samples")
100+
)
97101

98102
# Summary metrics
99103
st.divider()
@@ -117,15 +121,17 @@ def render():
117121
# Build dataframe for display
118122
df_data = []
119123
for row in filtered:
120-
df_data.append({
121-
"Function": row["function_name"],
122-
"Type": row["function_type"],
123-
"Has Surrogate": "Yes" if row["has_surrogate"] else "No",
124-
"Samples": str(row["n_samples"]) if row["n_samples"] else "-",
125-
"Training R2": f"{row['training_r2']:.4f}" if row["training_r2"] else "-",
126-
"Validation R2": f"{row['latest_r2']:.4f}" if row["latest_r2"] else "-",
127-
"Status": row["status"],
128-
})
124+
df_data.append(
125+
{
126+
"Function": row["function_name"],
127+
"Type": row["function_type"],
128+
"Has Surrogate": "Yes" if row["has_surrogate"] else "No",
129+
"Samples": str(row["n_samples"]) if row["n_samples"] else "-",
130+
"Training R2": f"{row['training_r2']:.4f}" if row["training_r2"] else "-",
131+
"Validation R2": f"{row['latest_r2']:.4f}" if row["latest_r2"] else "-",
132+
"Status": row["status"],
133+
}
134+
)
129135

130136
df = pd.DataFrame(df_data)
131137

@@ -136,10 +142,7 @@ def highlight_status(val):
136142

137143
# Display table
138144
if len(df) > 0:
139-
styled_df = df.style.applymap(
140-
highlight_status,
141-
subset=["Status"]
142-
)
145+
styled_df = df.style.applymap(highlight_status, subset=["Status"])
143146
st.dataframe(
144147
styled_df,
145148
use_container_width=True,

src/surfaces/_surrogates/_dashboard/_pages/training.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,9 @@ def progress_callback(msg: str):
122122
)
123123

124124
with col2:
125-
train_single = st.button("Train", use_container_width=True, type="primary", key="training_train_btn")
125+
train_single = st.button(
126+
"Train", use_container_width=True, type="primary", key="training_train_btn"
127+
)
126128

127129
if train_single and selected_function:
128130
st.divider()
@@ -162,14 +164,18 @@ def progress_callback(msg: str):
162164
if jobs:
163165
df_data = []
164166
for job in jobs:
165-
df_data.append({
166-
"Function": job["function_name"],
167-
"Started": job["started_at"][:19] if job["started_at"] else "-",
168-
"Completed": job["completed_at"][:19] if job["completed_at"] else "-",
169-
"Status": job["status"],
170-
"Triggered By": job["triggered_by"],
171-
"Error": job["error_message"][:50] + "..." if job["error_message"] and len(job["error_message"]) > 50 else (job["error_message"] or "-"),
172-
})
167+
df_data.append(
168+
{
169+
"Function": job["function_name"],
170+
"Started": job["started_at"][:19] if job["started_at"] else "-",
171+
"Completed": job["completed_at"][:19] if job["completed_at"] else "-",
172+
"Status": job["status"],
173+
"Triggered By": job["triggered_by"],
174+
"Error": job["error_message"][:50] + "..."
175+
if job["error_message"] and len(job["error_message"]) > 50
176+
else (job["error_message"] or "-"),
177+
}
178+
)
173179

174180
df = pd.DataFrame(df_data)
175181

src/surfaces/_surrogates/_dashboard/_pages/validation.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,7 @@ def progress_callback(msg: str):
8787
st.text("\n".join(logs[-20:]))
8888

8989
with st.spinner("Validating all surrogates..."):
90-
results = validate_all(
91-
validation_type, n_samples, random_seed, progress_callback
92-
)
90+
results = validate_all(validation_type, n_samples, random_seed, progress_callback)
9391

9492
success_count = sum(1 for r in results if r["success"])
9593
fail_count = len(results) - success_count
@@ -118,7 +116,9 @@ def progress_callback(msg: str):
118116
)
119117

120118
with col2:
121-
validate_single = st.button("Validate", use_container_width=True, type="primary", key="validation_validate_btn")
119+
validate_single = st.button(
120+
"Validate", use_container_width=True, type="primary", key="validation_validate_btn"
121+
)
122122

123123
if validate_single and selected_function:
124124
st.divider()
@@ -167,10 +167,12 @@ def progress_callback(msg: str):
167167
y_surr = data["y_surrogate"]
168168

169169
# Create scatter plot data
170-
scatter_df = pd.DataFrame({
171-
"Actual": y_real,
172-
"Predicted": y_surr,
173-
})
170+
scatter_df = pd.DataFrame(
171+
{
172+
"Actual": y_real,
173+
"Predicted": y_surr,
174+
}
175+
)
174176

175177
# Plot
176178
try:
@@ -188,8 +190,10 @@ def progress_callback(msg: str):
188190
max_val = max(y_real.max(), y_surr.max())
189191
fig.add_shape(
190192
type="line",
191-
x0=min_val, y0=min_val,
192-
x1=max_val, y1=max_val,
193+
x0=min_val,
194+
y0=min_val,
195+
x1=max_val,
196+
y1=max_val,
193197
line=dict(color="red", dash="dash"),
194198
)
195199

@@ -238,15 +242,17 @@ def progress_callback(msg: str):
238242
if runs:
239243
df_data = []
240244
for run in runs:
241-
df_data.append({
242-
"Function": run["function_name"],
243-
"Type": run["validation_type"],
244-
"Samples": str(run["n_samples"]) if run["n_samples"] else "-",
245-
"R2": f"{run['r2_score']:.4f}" if run["r2_score"] else "-",
246-
"MAE": f"{run['mae']:.4f}" if run["mae"] else "-",
247-
"Speedup": f"{run['speedup_factor']:.0f}x" if run["speedup_factor"] else "-",
248-
"Date": run["validated_at"][:19] if run["validated_at"] else "-",
249-
})
245+
df_data.append(
246+
{
247+
"Function": run["function_name"],
248+
"Type": run["validation_type"],
249+
"Samples": str(run["n_samples"]) if run["n_samples"] else "-",
250+
"R2": f"{run['r2_score']:.4f}" if run["r2_score"] else "-",
251+
"MAE": f"{run['mae']:.4f}" if run["mae"] else "-",
252+
"Speedup": f"{run['speedup_factor']:.0f}x" if run["speedup_factor"] else "-",
253+
"Date": run["validated_at"][:19] if run["validated_at"] else "-",
254+
}
255+
)
250256

251257
df = pd.DataFrame(df_data)
252258

0 commit comments

Comments
 (0)