Skip to content

Commit 580b4f5

Browse files
committed
implement 'StackingEnsembleFunction' test-function
1 parent 1a5e3da commit 580b4f5

File tree

1 file changed

+155
-0
lines changed
  • src/surfaces/test_functions/machine_learning/ensemble_optimization/tabular/test_functions

1 file changed

+155
-0
lines changed
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
"""Stacking Ensemble test function."""
2+
3+
from typing import Any, Callable, Dict, List, Optional, Union
4+
5+
from sklearn.ensemble import (
6+
GradientBoostingClassifier,
7+
RandomForestClassifier,
8+
StackingClassifier,
9+
)
10+
from sklearn.linear_model import LogisticRegression
11+
from sklearn.model_selection import cross_val_score
12+
from sklearn.svm import SVC
13+
from sklearn.tree import DecisionTreeClassifier
14+
15+
from surfaces.modifiers import BaseModifier
16+
from surfaces.test_functions.machine_learning.hyperparameter_optimization.tabular.classification.datasets import (
17+
DATASETS,
18+
)
19+
20+
from .._base_tabular_ensemble import BaseTabularEnsemble
21+
22+
23+
class StackingEnsembleFunction(BaseTabularEnsemble):
24+
"""Stacking Ensemble test function.
25+
26+
Optimizes a stacking ensemble by selecting base learners and the
27+
meta-learner (final estimator). Stacking combines predictions from
28+
multiple models using a meta-model to learn the optimal combination.
29+
30+
Parameters
31+
----------
32+
dataset : str, default="iris"
33+
Dataset to use for evaluation. One of: "digits", "iris", "wine", "breast_cancer".
34+
cv : int, default=5
35+
Number of cross-validation folds.
36+
objective : str, default="maximize"
37+
Either "minimize" or "maximize".
38+
modifiers : list of BaseModifier, optional
39+
List of modifiers to apply to function evaluations.
40+
41+
Examples
42+
--------
43+
>>> from surfaces.test_functions import StackingEnsembleFunction
44+
>>> func = StackingEnsembleFunction(dataset="iris", cv=5)
45+
>>> func.search_space
46+
{'use_dt': [True, False], 'use_rf': [True, False], ...}
47+
>>> result = func({"use_dt": True, "use_rf": True, "use_gb": True,
48+
... "use_svm": False, "final_estimator": "lr"})
49+
"""
50+
51+
name = "Stacking Ensemble"
52+
_name_ = "stacking_ensemble"
53+
__name__ = "StackingEnsembleFunction"
54+
55+
available_datasets = ["digits", "iris", "wine", "breast_cancer"]
56+
available_cv = [2, 3, 5, 10]
57+
58+
para_names = ["use_dt", "use_rf", "use_gb", "use_svm", "final_estimator"]
59+
use_dt_default = [True, False]
60+
use_rf_default = [True, False]
61+
use_gb_default = [True, False]
62+
use_svm_default = [True, False]
63+
final_estimator_default = ["lr", "rf", "gb"]
64+
65+
def __init__(
66+
self,
67+
dataset: str = "iris",
68+
cv: int = 5,
69+
objective: str = "maximize",
70+
modifiers: Optional[List[BaseModifier]] = None,
71+
memory: bool = False,
72+
collect_data: bool = True,
73+
callbacks: Optional[Union[Callable, List[Callable]]] = None,
74+
catch_errors: Optional[Dict[type, float]] = None,
75+
use_surrogate: bool = False,
76+
):
77+
if dataset not in self.available_datasets:
78+
raise ValueError(f"Unknown dataset '{dataset}'. Available: {self.available_datasets}")
79+
80+
if cv not in self.available_cv:
81+
raise ValueError(f"Invalid cv={cv}. Available: {self.available_cv}")
82+
83+
self.dataset = dataset
84+
self.cv = cv
85+
self._dataset_loader = DATASETS[dataset]
86+
87+
super().__init__(
88+
objective=objective,
89+
modifiers=modifiers,
90+
memory=memory,
91+
collect_data=collect_data,
92+
callbacks=callbacks,
93+
catch_errors=catch_errors,
94+
use_surrogate=use_surrogate,
95+
)
96+
97+
@property
98+
def search_space(self) -> Dict[str, Any]:
99+
"""Search space for stacking ensemble optimization."""
100+
return {
101+
"use_dt": self.use_dt_default,
102+
"use_rf": self.use_rf_default,
103+
"use_gb": self.use_gb_default,
104+
"use_svm": self.use_svm_default,
105+
"final_estimator": self.final_estimator_default,
106+
}
107+
108+
def _create_objective_function(self) -> None:
109+
"""Create objective function for stacking ensemble."""
110+
X, y = self._dataset_loader()
111+
cv = self.cv
112+
113+
def objective_function(params: Dict[str, Any]) -> float:
114+
# Build base estimators
115+
estimators = []
116+
117+
if params["use_dt"]:
118+
estimators.append(("dt", DecisionTreeClassifier(random_state=42)))
119+
120+
if params["use_rf"]:
121+
estimators.append(("rf", RandomForestClassifier(n_estimators=50, random_state=42)))
122+
123+
if params["use_gb"]:
124+
estimators.append(
125+
("gb", GradientBoostingClassifier(n_estimators=50, random_state=42))
126+
)
127+
128+
if params["use_svm"]:
129+
estimators.append(("svm", SVC(probability=True, random_state=42)))
130+
131+
# Need at least 2 base models for stacking
132+
if len(estimators) < 2:
133+
return 0.0
134+
135+
# Select final estimator (meta-learner)
136+
final_est_type = params["final_estimator"]
137+
if final_est_type == "lr":
138+
final_estimator = LogisticRegression(max_iter=1000, random_state=42)
139+
elif final_est_type == "rf":
140+
final_estimator = RandomForestClassifier(n_estimators=50, random_state=42)
141+
elif final_est_type == "gb":
142+
final_estimator = GradientBoostingClassifier(n_estimators=50, random_state=42)
143+
else:
144+
raise ValueError(f"Unknown final_estimator: {final_est_type}")
145+
146+
# Create stacking classifier
147+
ensemble = StackingClassifier(
148+
estimators=estimators, final_estimator=final_estimator, cv=3
149+
)
150+
151+
# Evaluate
152+
scores = cross_val_score(ensemble, X, y, cv=cv, scoring="accuracy")
153+
return scores.mean()
154+
155+
self.pure_objective_function = objective_function

0 commit comments

Comments
 (0)