Skip to content

Commit a6a570d

Browse files
committed
add models, trainer and registry for surrogates
1 parent 28c5182 commit a6a570d

File tree

7 files changed

+592
-15
lines changed

7 files changed

+592
-15
lines changed

src/surfaces/_surrogates/__init__.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@
99
- SurrogateLoader: Load and run pre-trained ONNX surrogate models
1010
- SurrogateTrainer: Train new surrogate models (for maintainers)
1111
- SurrogateValidator: Validate surrogate accuracy against real function
12+
13+
Developer API for ML surrogates:
14+
- train_ml_surrogate: Train surrogate for single ML function
15+
- train_all_ml_surrogates: Train all registered ML surrogates
16+
- train_missing_ml_surrogates: Train only missing surrogates
17+
- list_ml_surrogates: List registered functions and status
1218
"""
1319

1420
from ._surrogate_loader import (
@@ -23,12 +29,28 @@
2329
from ._surrogate_validator import (
2430
SurrogateValidator,
2531
)
32+
from ._ml_surrogate_trainer import (
33+
MLSurrogateTrainer,
34+
train_ml_surrogate,
35+
train_all_ml_surrogates,
36+
train_missing_ml_surrogates,
37+
list_ml_surrogates,
38+
)
2639

2740
__all__ = [
41+
# Loader
2842
"SurrogateLoader",
29-
"SurrogateTrainer",
30-
"SurrogateValidator",
3143
"load_surrogate",
3244
"get_surrogate_path",
45+
# Generic trainer
46+
"SurrogateTrainer",
3347
"train_surrogate_for_function",
48+
# Validator
49+
"SurrogateValidator",
50+
# ML-specific trainer (developer API)
51+
"MLSurrogateTrainer",
52+
"train_ml_surrogate",
53+
"train_all_ml_surrogates",
54+
"train_missing_ml_surrogates",
55+
"list_ml_surrogates",
3456
]
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Author: Simon Blanke
2+
3+
# License: MIT License
4+
5+
"""Registry of ML functions that support surrogate training.
6+
7+
This module defines which ML functions can have surrogates trained,
8+
along with their fixed parameter grids (dataset, cv combinations).
9+
"""
10+
11+
from typing import Any, Dict, List, Type
12+
13+
# Registry: function_name -> config
14+
ML_SURROGATE_REGISTRY: Dict[str, Dict[str, Any]] = {}
15+
16+
17+
def register_ml_function(
18+
name: str,
19+
function_class: Type,
20+
fixed_params: Dict[str, List],
21+
hyperparams: List[str],
22+
):
23+
"""Register an ML function for surrogate training.
24+
25+
Parameters
26+
----------
27+
name : str
28+
Unique identifier (e.g., "k_neighbors_classifier").
29+
function_class : Type
30+
The function class (e.g., KNeighborsClassifierFunction).
31+
fixed_params : dict
32+
Grid of fixed parameters to iterate over during training.
33+
Example: {"dataset": ["iris", "digits"], "cv": [2, 3, 5, 10]}
34+
hyperparams : list
35+
Names of hyperparameters in the search space.
36+
"""
37+
ML_SURROGATE_REGISTRY[name] = {
38+
"class": function_class,
39+
"fixed_params": fixed_params,
40+
"hyperparams": hyperparams,
41+
}
42+
43+
44+
def get_registered_functions() -> List[str]:
45+
"""Get list of registered function names."""
46+
_ensure_registered()
47+
return list(ML_SURROGATE_REGISTRY.keys())
48+
49+
50+
def get_function_config(name: str) -> Dict[str, Any]:
51+
"""Get configuration for a registered function."""
52+
_ensure_registered()
53+
if name not in ML_SURROGATE_REGISTRY:
54+
raise ValueError(
55+
f"Unknown function '{name}'. "
56+
f"Available: {get_registered_functions()}"
57+
)
58+
return ML_SURROGATE_REGISTRY[name]
59+
60+
61+
# ============================================================================
62+
# Register ML functions (lazy to avoid circular imports)
63+
# ============================================================================
64+
65+
def _ensure_registered():
66+
"""Register all ML functions lazily on first access."""
67+
if ML_SURROGATE_REGISTRY:
68+
return # Already registered
69+
70+
from surfaces.test_functions import (
71+
KNeighborsClassifierFunction,
72+
KNeighborsRegressorFunction,
73+
GradientBoostingRegressorFunction,
74+
)
75+
76+
# Classification functions
77+
register_ml_function(
78+
name="k_neighbors_classifier",
79+
function_class=KNeighborsClassifierFunction,
80+
fixed_params={
81+
"dataset": ["digits", "iris", "wine"],
82+
"cv": [2, 3, 5, 10],
83+
},
84+
hyperparams=["n_neighbors", "algorithm"],
85+
)
86+
87+
# Regression functions
88+
register_ml_function(
89+
name="k_neighbors_regressor",
90+
function_class=KNeighborsRegressorFunction,
91+
fixed_params={
92+
"dataset": ["diabetes", "california"],
93+
"cv": [2, 3, 5, 10],
94+
},
95+
hyperparams=["n_neighbors", "algorithm"],
96+
)
97+
98+
register_ml_function(
99+
name="gradient_boosting_regressor",
100+
function_class=GradientBoostingRegressorFunction,
101+
fixed_params={
102+
"dataset": ["diabetes", "california"],
103+
"cv": [2, 3, 5, 10],
104+
},
105+
hyperparams=["n_estimators", "max_depth"],
106+
)

0 commit comments

Comments
 (0)