Skip to content

Commit f6eed6c

Browse files
committed
add ml-test-functions and datasets for tabular ml
1 parent c879d34 commit f6eed6c

File tree

17 files changed

+928
-37
lines changed

17 files changed

+928
-37
lines changed

src/surfaces/_surrogates/_ml_registry.py

Lines changed: 74 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -68,39 +68,97 @@ def _ensure_registered():
6868
return # Already registered
6969

7070
from surfaces.test_functions import (
71+
# Classification
72+
DecisionTreeClassifierFunction,
73+
GradientBoostingClassifierFunction,
7174
KNeighborsClassifierFunction,
72-
KNeighborsRegressorFunction,
75+
RandomForestClassifierFunction,
76+
SVMClassifierFunction,
77+
# Regression
78+
DecisionTreeRegressorFunction,
7379
GradientBoostingRegressorFunction,
80+
KNeighborsRegressorFunction,
81+
RandomForestRegressorFunction,
82+
SVMRegressorFunction,
7483
)
7584

85+
# Dataset grids
86+
classification_datasets = ["digits", "iris", "wine", "breast_cancer", "covtype"]
87+
regression_datasets = ["diabetes", "california", "friedman1", "friedman2", "linear"]
88+
cv_options = [2, 3, 5, 10]
89+
90+
# =========================================================================
7691
# Classification functions
92+
# =========================================================================
93+
register_ml_function(
94+
name="decision_tree_classifier",
95+
function_class=DecisionTreeClassifierFunction,
96+
fixed_params={"dataset": classification_datasets, "cv": cv_options},
97+
hyperparams=["max_depth", "min_samples_split", "min_samples_leaf"],
98+
)
99+
100+
register_ml_function(
101+
name="gradient_boosting_classifier",
102+
function_class=GradientBoostingClassifierFunction,
103+
fixed_params={"dataset": classification_datasets, "cv": cv_options},
104+
hyperparams=["n_estimators", "max_depth", "learning_rate"],
105+
)
106+
77107
register_ml_function(
78108
name="k_neighbors_classifier",
79109
function_class=KNeighborsClassifierFunction,
80-
fixed_params={
81-
"dataset": ["digits", "iris", "wine"],
82-
"cv": [2, 3, 5, 10],
83-
},
110+
fixed_params={"dataset": classification_datasets, "cv": cv_options},
84111
hyperparams=["n_neighbors", "algorithm"],
85112
)
86113

114+
register_ml_function(
115+
name="random_forest_classifier",
116+
function_class=RandomForestClassifierFunction,
117+
fixed_params={"dataset": classification_datasets, "cv": cv_options},
118+
hyperparams=["n_estimators", "max_depth", "min_samples_split"],
119+
)
120+
121+
register_ml_function(
122+
name="svm_classifier",
123+
function_class=SVMClassifierFunction,
124+
fixed_params={"dataset": classification_datasets, "cv": cv_options},
125+
hyperparams=["C", "kernel", "gamma"],
126+
)
127+
128+
# =========================================================================
87129
# Regression functions
130+
# =========================================================================
88131
register_ml_function(
89-
name="k_neighbors_regressor",
90-
function_class=KNeighborsRegressorFunction,
91-
fixed_params={
92-
"dataset": ["diabetes", "california"],
93-
"cv": [2, 3, 5, 10],
94-
},
95-
hyperparams=["n_neighbors", "algorithm"],
132+
name="decision_tree_regressor",
133+
function_class=DecisionTreeRegressorFunction,
134+
fixed_params={"dataset": regression_datasets, "cv": cv_options},
135+
hyperparams=["max_depth", "min_samples_split", "min_samples_leaf"],
96136
)
97137

98138
register_ml_function(
99139
name="gradient_boosting_regressor",
100140
function_class=GradientBoostingRegressorFunction,
101-
fixed_params={
102-
"dataset": ["diabetes", "california"],
103-
"cv": [2, 3, 5, 10],
104-
},
141+
fixed_params={"dataset": regression_datasets, "cv": cv_options},
105142
hyperparams=["n_estimators", "max_depth"],
106143
)
144+
145+
register_ml_function(
146+
name="k_neighbors_regressor",
147+
function_class=KNeighborsRegressorFunction,
148+
fixed_params={"dataset": regression_datasets, "cv": cv_options},
149+
hyperparams=["n_neighbors", "algorithm"],
150+
)
151+
152+
register_ml_function(
153+
name="random_forest_regressor",
154+
function_class=RandomForestRegressorFunction,
155+
fixed_params={"dataset": regression_datasets, "cv": cv_options},
156+
hyperparams=["n_estimators", "max_depth", "min_samples_split"],
157+
)
158+
159+
register_ml_function(
160+
name="svm_regressor",
161+
function_class=SVMRegressorFunction,
162+
fixed_params={"dataset": regression_datasets, "cv": cv_options},
163+
hyperparams=["C", "kernel", "gamma"],
164+
)

src/surfaces/test_functions/machine_learning/__init__.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,21 +21,48 @@ def _check_sklearn():
2121

2222
if _HAS_SKLEARN:
2323
from .tabular import (
24-
GradientBoostingRegressorFunction,
24+
# Classification
25+
DecisionTreeClassifierFunction,
26+
GradientBoostingClassifierFunction,
2527
KNeighborsClassifierFunction,
28+
RandomForestClassifierFunction,
29+
SVMClassifierFunction,
30+
# Regression
31+
DecisionTreeRegressorFunction,
32+
GradientBoostingRegressorFunction,
2633
KNeighborsRegressorFunction,
34+
RandomForestRegressorFunction,
35+
SVMRegressorFunction,
2736
)
2837

2938
__all__ = [
39+
# Classification
40+
"DecisionTreeClassifierFunction",
41+
"GradientBoostingClassifierFunction",
3042
"KNeighborsClassifierFunction",
31-
"KNeighborsRegressorFunction",
43+
"RandomForestClassifierFunction",
44+
"SVMClassifierFunction",
45+
# Regression
46+
"DecisionTreeRegressorFunction",
3247
"GradientBoostingRegressorFunction",
48+
"KNeighborsRegressorFunction",
49+
"RandomForestRegressorFunction",
50+
"SVMRegressorFunction",
3351
]
3452

3553
machine_learning_functions = [
54+
# Classification
55+
DecisionTreeClassifierFunction,
56+
GradientBoostingClassifierFunction,
3657
KNeighborsClassifierFunction,
58+
RandomForestClassifierFunction,
59+
SVMClassifierFunction,
60+
# Regression
61+
DecisionTreeRegressorFunction,
3762
GradientBoostingRegressorFunction,
3863
KNeighborsRegressorFunction,
64+
RandomForestRegressorFunction,
65+
SVMRegressorFunction,
3966
]
4067
else:
4168
__all__ = []

src/surfaces/test_functions/machine_learning/tabular/__init__.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,32 @@
22
33
# License: MIT License
44

5-
from .classification import KNeighborsClassifierFunction
6-
from .regression import GradientBoostingRegressorFunction, KNeighborsRegressorFunction
5+
from .classification import (
6+
DecisionTreeClassifierFunction,
7+
GradientBoostingClassifierFunction,
8+
KNeighborsClassifierFunction,
9+
RandomForestClassifierFunction,
10+
SVMClassifierFunction,
11+
)
12+
from .regression import (
13+
DecisionTreeRegressorFunction,
14+
GradientBoostingRegressorFunction,
15+
KNeighborsRegressorFunction,
16+
RandomForestRegressorFunction,
17+
SVMRegressorFunction,
18+
)
719

820
__all__ = [
21+
# Classification
22+
"DecisionTreeClassifierFunction",
23+
"GradientBoostingClassifierFunction",
924
"KNeighborsClassifierFunction",
10-
"KNeighborsRegressorFunction",
25+
"RandomForestClassifierFunction",
26+
"SVMClassifierFunction",
27+
# Regression
28+
"DecisionTreeRegressorFunction",
1129
"GradientBoostingRegressorFunction",
30+
"KNeighborsRegressorFunction",
31+
"RandomForestRegressorFunction",
32+
"SVMRegressorFunction",
1233
]

src/surfaces/test_functions/machine_learning/tabular/classification/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,17 @@
33
# License: MIT License
44

55
from .test_functions import (
6+
DecisionTreeClassifierFunction,
7+
GradientBoostingClassifierFunction,
68
KNeighborsClassifierFunction,
9+
RandomForestClassifierFunction,
10+
SVMClassifierFunction,
711
)
812

913
__all__ = [
14+
"DecisionTreeClassifierFunction",
15+
"GradientBoostingClassifierFunction",
1016
"KNeighborsClassifierFunction",
17+
"RandomForestClassifierFunction",
18+
"SVMClassifierFunction",
1119
]

src/surfaces/test_functions/machine_learning/tabular/classification/datasets.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,63 @@
22
33
# License: MIT License
44

5-
from sklearn.datasets import load_digits, load_iris, load_wine
5+
"""Classification datasets for ML test functions."""
66

7+
from sklearn.datasets import (
8+
load_digits,
9+
load_iris,
10+
load_wine,
11+
load_breast_cancer,
12+
fetch_covtype,
13+
)
14+
15+
# Pre-load small datasets for fast access
716
digits_dataset = load_digits()
8-
wine_dataset = load_wine()
917
iris_dataset = load_iris()
18+
wine_dataset = load_wine()
19+
breast_cancer_dataset = load_breast_cancer()
20+
_covtype_dataset = None # Lazy load (larger download)
1021

1122

1223
def digits_data():
    """Return the digits dataset as a ``(data, target)`` pair.

    Reads the module-level ``digits_dataset`` bunch that is loaded when this
    module is imported (1797 samples, 64 features, 10 classes).
    """
    features = digits_dataset.data
    labels = digits_dataset.target
    return features, labels
1426

1527

28+
def iris_data():
    """Return the iris dataset as a ``(data, target)`` pair.

    Reads the module-level ``iris_dataset`` bunch that is loaded when this
    module is imported (150 samples, 4 features, 3 classes).
    """
    features = iris_dataset.data
    labels = iris_dataset.target
    return features, labels
31+
32+
1633
def wine_data():
    """Return the wine dataset as a ``(data, target)`` pair.

    Reads the module-level ``wine_dataset`` bunch that is loaded when this
    module is imported (178 samples, 13 features, 3 classes).
    """
    features = wine_dataset.data
    labels = wine_dataset.target
    return features, labels
1836

1937

20-
def iris_data():
21-
return iris_dataset.data, iris_dataset.target
38+
def breast_cancer_data():
    """Return the breast cancer dataset as a ``(data, target)`` pair.

    Reads the module-level ``breast_cancer_dataset`` bunch that is loaded when
    this module is imported (569 samples, 30 features, 2 classes).
    """
    features = breast_cancer_dataset.data
    labels = breast_cancer_dataset.target
    return features, labels
41+
42+
43+
def covtype_data():
    """Return a 10% subsample of the covertype dataset as ``(data, target)``.

    The full dataset has 581012 samples, 54 features and 7 classes; only one
    tenth of the rows is returned to keep training time reasonable.

    Note: the first call invokes ``fetch_covtype`` (which may trigger a large
    download) and stores the result in the module-level ``_covtype_dataset``
    so that later calls reuse it.
    """
    global _covtype_dataset
    if _covtype_dataset is None:
        # fetch_covtype does not shuffle by default, so a plain head slice
        # would depend on the row order of the source file. Shuffle once with
        # a fixed seed so the subsample is random yet reproducible.
        _covtype_dataset = fetch_covtype(shuffle=True, random_state=0)
    # Keep the leading tenth of the (now shuffled) rows.
    n_samples = len(_covtype_dataset.target) // 10
    return _covtype_dataset.data[:n_samples], _covtype_dataset.target[:n_samples]
55+
56+
57+
# Maps each dataset name to the zero-argument loader that returns its
# (data, target) pair.
DATASETS = dict(
    digits=digits_data,
    iris=iris_data,
    wine=wine_data,
    breast_cancer=breast_cancer_data,
    covtype=covtype_data,
)

src/surfaces/test_functions/machine_learning/tabular/classification/test_functions/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,16 @@
33
# License: MIT License
44

55

6+
from .decision_tree_classifier import DecisionTreeClassifierFunction
7+
from .gradient_boosting_classifier import GradientBoostingClassifierFunction
68
from .k_neighbors_classifier import KNeighborsClassifierFunction
9+
from .random_forest_classifier import RandomForestClassifierFunction
10+
from .svm_classifier import SVMClassifierFunction
711

812
__all__ = [
13+
"DecisionTreeClassifierFunction",
14+
"GradientBoostingClassifierFunction",
915
"KNeighborsClassifierFunction",
16+
"RandomForestClassifierFunction",
17+
"SVMClassifierFunction",
1018
]

0 commit comments

Comments
 (0)