Skip to content

Commit c879d34

Browse files
committed
change API and usage of ml-test-functions
1 parent a6a570d commit c879d34

File tree

5 files changed

+327
-115
lines changed

5 files changed

+327
-115
lines changed

src/surfaces/test_functions/machine_learning/_base_machine_learning.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,9 @@ def __init__(
5656
use_surrogate: bool = False,
5757
**kwargs,
5858
):
59-
super().__init__(objective, sleep, memory, collect_data, callbacks, catch_errors)
59+
super().__init__(
60+
objective, sleep, memory, collect_data, callbacks, catch_errors
61+
)
6062
self.use_surrogate = use_surrogate
6163
self._surrogate = None
6264

@@ -77,6 +79,24 @@ def _load_surrogate(self) -> None:
7779
)
7880
self.use_surrogate = False
7981

82+
def _get_surrogate_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
83+
"""Get parameters for surrogate prediction.
84+
85+
Override in subclasses to add fixed parameters (like dataset, cv)
86+
that are not in the search space but needed by the surrogate.
87+
88+
Parameters
89+
----------
90+
params : dict
91+
Search parameters from the optimizer.
92+
93+
Returns
94+
-------
95+
dict
96+
Full parameters for surrogate prediction.
97+
"""
98+
return params
99+
80100
def _evaluate(self, params: Dict[str, Any]) -> float:
81101
"""Evaluate with timing and objective transformation.
82102
@@ -86,7 +106,9 @@ def _evaluate(self, params: Dict[str, Any]) -> float:
86106
time.sleep(self.sleep)
87107

88108
if self.use_surrogate and self._surrogate is not None:
89-
raw_value = self._surrogate.predict(params)
109+
# Use _get_surrogate_params to include fixed params (dataset, cv)
110+
surrogate_params = self._get_surrogate_params(params)
111+
raw_value = self._surrogate.predict(surrogate_params)
90112
else:
91113
raw_value = self.pure_objective_function(params)
92114

src/surfaces/test_functions/machine_learning/tabular/classification/test_functions/k_neighbors_classifier.py

Lines changed: 102 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,22 @@
22
33
# License: MIT License
44

5+
"""K-Nearest Neighbors Classifier test function with surrogate support."""
6+
57
import numpy as np
68
from sklearn.model_selection import cross_val_score
79
from sklearn.neighbors import KNeighborsClassifier
810

911
from .._base_classification import BaseClassification
1012
from ..datasets import digits_data, iris_data, wine_data
1113

14+
# Dataset registry: maps string names to loader functions
15+
DATASETS = {
16+
"digits": digits_data,
17+
"iris": iris_data,
18+
"wine": wine_data,
19+
}
20+
1221

1322
class KNeighborsClassifierFunction(BaseClassification):
1423
"""K-Nearest Neighbors Classifier test function.
@@ -18,73 +27,130 @@ class KNeighborsClassifierFunction(BaseClassification):
1827
1928
Parameters
2029
----------
21-
metric : str, default="accuracy"
22-
Scoring metric for cross-validation.
30+
dataset : str, default="digits"
31+
Dataset to use for evaluation. One of: "digits", "iris", "wine".
32+
This is a fixed parameter (like a coefficient), not part of the search space.
33+
cv : int, default=5
34+
Number of cross-validation folds.
35+
This is a fixed parameter, not part of the search space.
36+
use_surrogate : bool, default=False
37+
If True, use pre-trained surrogate model for fast evaluation (~1ms).
38+
Falls back to real evaluation if no surrogate is available.
39+
objective : str, default="maximize"
40+
Either "minimize" or "maximize".
2341
sleep : float, default=0
2442
Artificial delay in seconds added to each evaluation.
2543
2644
Attributes
2745
----------
28-
para_names : list
29-
Names of the hyperparameters: n_neighbors, algorithm, cv, dataset.
30-
n_neighbors_default : list
31-
Default values for n_neighbors parameter (3 to 150, step 5).
32-
algorithm_default : list
33-
Default algorithm options: auto, ball_tree, kd_tree, brute.
34-
cv_default : list
35-
Default cross-validation fold options: 2, 3, 4, 5, 8, 10.
36-
dataset_default : list
37-
Default datasets (digits, wine, iris).
46+
available_datasets : list
47+
Available dataset names: ["digits", "iris", "wine"].
48+
available_cv : list
49+
Available CV fold options: [2, 3, 5, 10].
3850
3951
Examples
4052
--------
53+
Basic usage with real evaluation:
54+
4155
>>> from surfaces.test_functions import KNeighborsClassifierFunction
42-
>>> func = KNeighborsClassifierFunction()
43-
>>> search_space = func.search_space
44-
>>> list(search_space.keys())
45-
['n_neighbors', 'algorithm', 'cv', 'dataset']
56+
>>> func = KNeighborsClassifierFunction(dataset="iris", cv=5)
57+
>>> func.search_space
58+
{'n_neighbors': [3, 8, 13, ...], 'algorithm': ['auto', 'ball_tree', ...]}
59+
>>> result = func({"n_neighbors": 5, "algorithm": "auto"})
60+
61+
Fast evaluation with surrogate (requires surfaces[surrogates]):
62+
63+
>>> func = KNeighborsClassifierFunction(dataset="iris", cv=5, use_surrogate=True)
64+
>>> result = func({"n_neighbors": 5, "algorithm": "auto"}) # ~1ms
4665
"""
4766

4867
name = "KNeighbors Classifier Function"
4968
_name_ = "k_neighbors_classifier"
5069
__name__ = "KNeighborsClassifierFunction"
5170

52-
para_names = ["n_neighbors", "algorithm", "cv", "dataset"]
71+
# Available options (for validation and documentation)
72+
available_datasets = list(DATASETS.keys())
73+
available_cv = [2, 3, 5, 10]
5374

75+
# Search space parameters (only actual hyperparameters)
76+
para_names = ["n_neighbors", "algorithm"]
5477
n_neighbors_default = list(np.arange(3, 150, 5))
5578
algorithm_default = ["auto", "ball_tree", "kd_tree", "brute"]
56-
cv_default = [2, 3, 4, 5, 8, 10]
57-
dataset_default = [digits_data, wine_data, iris_data]
58-
59-
def __init__(self, *args, **kwargs):
60-
super().__init__(*args, **kwargs)
6179

62-
def _search_space(
80+
def __init__(
6381
self,
64-
n_neighbors: list = None,
65-
algorithm: list = None,
66-
cv: list = None,
67-
dataset: list = None,
82+
dataset: str = "digits",
83+
cv: int = 5,
84+
objective: str = "maximize",
85+
sleep: float = 0,
86+
memory: bool = False,
87+
collect_data: bool = True,
88+
callbacks=None,
89+
catch_errors=None,
90+
use_surrogate: bool = False,
6891
):
69-
search_space: dict = {}
92+
# Validate dataset
93+
if dataset not in DATASETS:
94+
raise ValueError(
95+
f"Unknown dataset '{dataset}'. "
96+
f"Available: {self.available_datasets}"
97+
)
98+
99+
# Validate cv
100+
if cv not in self.available_cv:
101+
raise ValueError(
102+
f"Invalid cv={cv}. Available: {self.available_cv}"
103+
)
70104

71-
search_space["n_neighbors"] = (
72-
self.n_neighbors_default if n_neighbors is None else n_neighbors
105+
# Store fixed parameters (like coefficients in math functions)
106+
self.dataset = dataset
107+
self.cv = cv
108+
109+
# Resolve the loader for the chosen dataset (the data itself is loaded in _create_objective_function)
110+
self._dataset_loader = DATASETS[dataset]
111+
112+
super().__init__(
113+
objective=objective,
114+
sleep=sleep,
115+
memory=memory,
116+
collect_data=collect_data,
117+
callbacks=callbacks,
118+
catch_errors=catch_errors,
119+
use_surrogate=use_surrogate,
73120
)
74-
search_space["algorithm"] = self.algorithm_default if algorithm is None else algorithm
75-
search_space["cv"] = self.cv_default if cv is None else cv
76-
search_space["dataset"] = self.dataset_default if dataset is None else dataset
77121

78-
return search_space
122+
@property
def search_space(self):
    """Search space containing only hyperparameters (not dataset/cv).

    Returns
    -------
    dict
        Maps ``"n_neighbors"`` and ``"algorithm"`` to lists of candidate
        values. The lists are shallow copies of the class-level defaults,
        so mutating the returned search space cannot alter the defaults
        shared by every instance of the class.
    """
    return {
        # list(...) copies: the *_default attributes live on the class.
        "n_neighbors": list(self.n_neighbors_default),
        "algorithm": list(self.algorithm_default),
    }
79129

80130
def _create_objective_function(self):
131+
"""Create objective function with fixed dataset and cv."""
132+
# Load dataset once
133+
X, y = self._dataset_loader()
134+
cv = self.cv
135+
81136
def k_neighbors_classifier(params):
82137
knc = KNeighborsClassifier(
83138
n_neighbors=params["n_neighbors"],
84139
algorithm=params["algorithm"],
85140
)
86-
X, y = params["dataset"]()
87-
scores = cross_val_score(knc, X, y, cv=params["cv"], scoring="accuracy")
141+
scores = cross_val_score(knc, X, y, cv=cv, scoring="accuracy")
88142
return scores.mean()
89143

90144
self.pure_objective_function = k_neighbors_classifier
145+
146+
def _get_surrogate_params(self, params):
147+
"""Add fixed parameters (dataset, cv) to params for surrogate prediction.
148+
149+
The surrogate model was trained on all (HP, dataset, cv) combinations,
150+
so we need to include the fixed parameters when querying it.
151+
"""
152+
return {
153+
**params,
154+
"dataset": self.dataset,
155+
"cv": self.cv,
156+
}

src/surfaces/test_functions/machine_learning/tabular/regression/datasets.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,30 @@
22
33
# License: MIT License
44

5-
from sklearn.datasets import load_diabetes
5+
"""Regression datasets for ML test functions."""
66

7+
from sklearn.datasets import load_diabetes, fetch_california_housing
8+
9+
# Pre-load the diabetes dataset for fast access (California housing is loaded lazily)
710
diabetes_dataset = load_diabetes()
11+
_california_dataset = None # Lazy load (larger download)
812

913

1014
def diabetes_data():
    """Return the diabetes dataset (442 samples, 10 features) as ``(X, y)``.

    Reads from the module-level ``diabetes_dataset`` object, which is
    loaded once when this module is imported.
    """
    features = diabetes_dataset.data
    targets = diabetes_dataset.target
    return features, targets
17+
18+
19+
def california_data():
    """Return the California housing dataset (20640 samples, 8 features) as ``(X, y)``.

    The dataset is fetched on the first call and cached in the
    module-level ``_california_dataset``; later calls reuse the cached
    object instead of fetching again.
    """
    global _california_dataset
    cached = _california_dataset
    if cached is None:
        # First use: fetch once and remember it for subsequent calls.
        cached = _california_dataset = fetch_california_housing()
    return cached.data, cached.target
25+
26+
27+
# Registry for easy access
28+
DATASETS = {
29+
"diabetes": diabetes_data,
30+
"california": california_data,
31+
}

0 commit comments

Comments
 (0)