2233# License: MIT License
44
5+ """K-Nearest Neighbors Classifier test function with surrogate support."""
6+
57import numpy as np
68from sklearn .model_selection import cross_val_score
79from sklearn .neighbors import KNeighborsClassifier
810
911from .._base_classification import BaseClassification
1012from ..datasets import digits_data , iris_data , wine_data
1113
14+ # Dataset registry: maps string names to loader functions
15+ DATASETS = {
16+ "digits" : digits_data ,
17+ "iris" : iris_data ,
18+ "wine" : wine_data ,
19+ }
20+
1221
1322class KNeighborsClassifierFunction (BaseClassification ):
1423 """K-Nearest Neighbors Classifier test function.
@@ -18,73 +27,130 @@ class KNeighborsClassifierFunction(BaseClassification):
1827
1928 Parameters
2029 ----------
21- metric : str, default="accuracy"
22- Scoring metric for cross-validation.
30+ dataset : str, default="digits"
31+ Dataset to use for evaluation. One of: "digits", "iris", "wine".
32+ This is a fixed parameter (like a coefficient), not part of the search space.
33+ cv : int, default=5
34+ Number of cross-validation folds. One of: 2, 3, 5, 10.
35+ This is a fixed parameter, not part of the search space.
36+ use_surrogate : bool, default=False
37+ If True, use a pre-trained surrogate model for fast evaluation (~1ms).
38+ Falls back to real evaluation if no surrogate is available.
39+ objective : str, default="maximize"
40+ Either "minimize" or "maximize".
2341 sleep : float, default=0
2442 Artificial delay in seconds added to each evaluation.
2543
2644 Attributes
2745 ----------
28- para_names : list
29- Names of the hyperparameters: n_neighbors, algorithm, cv, dataset.
30- n_neighbors_default : list
31- Default values for n_neighbors parameter (3 to 150, step 5).
32- algorithm_default : list
33- Default algorithm options: auto, ball_tree, kd_tree, brute.
34- cv_default : list
35- Default cross-validation fold options: 2, 3, 4, 5, 8, 10.
36- dataset_default : list
37- Default datasets (digits, wine, iris).
46+ available_datasets : list
47+ Available dataset names: ["digits", "iris", "wine"].
48+ available_cv : list
49+ Available CV fold options: [2, 3, 5, 10].
3850
3951 Examples
4052 --------
53+ Basic usage with real evaluation:
54+
4155 >>> from surfaces.test_functions import KNeighborsClassifierFunction
42- >>> func = KNeighborsClassifierFunction()
43- >>> search_space = func.search_space
44- >>> list(search_space.keys())
45- ['n_neighbors', 'algorithm', 'cv', 'dataset']
56+ >>> func = KNeighborsClassifierFunction(dataset="iris", cv=5)
57+ >>> func.search_space
58+ {'n_neighbors': [3, 8, 13, ...], 'algorithm': ['auto', 'ball_tree', ...]}
59+ >>> result = func({"n_neighbors": 5, "algorithm": "auto"})
60+
61+ Fast evaluation with surrogate (requires surfaces[surrogates]):
62+
63+ >>> func = KNeighborsClassifierFunction(dataset="iris", cv=5, use_surrogate=True)
64+ >>> result = func({"n_neighbors": 5, "algorithm": "auto"}) # ~1ms
4665 """
4766
4867 name = "KNeighbors Classifier Function"
4968 _name_ = "k_neighbors_classifier"
5069 __name__ = "KNeighborsClassifierFunction"
5170
52- para_names = ["n_neighbors" , "algorithm" , "cv" , "dataset" ]
71+ # Available options (for validation and documentation)
72+ available_datasets = list (DATASETS .keys ())
73+ available_cv = [2 , 3 , 5 , 10 ]
5374
75+ # Search space parameters (only actual hyperparameters)
76+ para_names = ["n_neighbors" , "algorithm" ]
5477 n_neighbors_default = list (np .arange (3 , 150 , 5 ))
5578 algorithm_default = ["auto" , "ball_tree" , "kd_tree" , "brute" ]
56- cv_default = [2 , 3 , 4 , 5 , 8 , 10 ]
57- dataset_default = [digits_data , wine_data , iris_data ]
58-
59- def __init__ (self , * args , ** kwargs ):
60- super ().__init__ (* args , ** kwargs )
6179
62- def _search_space (
80+ def __init__ (
6381 self ,
64- n_neighbors : list = None ,
65- algorithm : list = None ,
66- cv : list = None ,
67- dataset : list = None ,
82+ dataset : str = "digits" ,
83+ cv : int = 5 ,
84+ objective : str = "maximize" ,
85+ sleep : float = 0 ,
86+ memory : bool = False ,
87+ collect_data : bool = True ,
88+ callbacks = None ,
89+ catch_errors = None ,
90+ use_surrogate : bool = False ,
6891 ):
69- search_space : dict = {}
92+ # Validate dataset
93+ if dataset not in DATASETS :
94+ raise ValueError (
95+ f"Unknown dataset '{ dataset } '. "
96+ f"Available: { self .available_datasets } "
97+ )
98+
99+ # Validate cv
100+ if cv not in self .available_cv :
101+ raise ValueError (
102+ f"Invalid cv={ cv } . Available: { self .available_cv } "
103+ )
70104
71- search_space ["n_neighbors" ] = (
72- self .n_neighbors_default if n_neighbors is None else n_neighbors
105+ # Store fixed parameters (like coefficients in math functions)
106+ self .dataset = dataset
107+ self .cv = cv
108+
109+ # Keep a reference to the loader; the data itself is loaded in _create_objective_function
110+ self ._dataset_loader = DATASETS [dataset ]
111+
112+ super ().__init__ (
113+ objective = objective ,
114+ sleep = sleep ,
115+ memory = memory ,
116+ collect_data = collect_data ,
117+ callbacks = callbacks ,
118+ catch_errors = catch_errors ,
119+ use_surrogate = use_surrogate ,
73120 )
74- search_space ["algorithm" ] = self .algorithm_default if algorithm is None else algorithm
75- search_space ["cv" ] = self .cv_default if cv is None else cv
76- search_space ["dataset" ] = self .dataset_default if dataset is None else dataset
77121
78- return search_space
122+ @property
123+ def search_space (self ):
124+ """Search space containing only hyperparameters (not dataset/cv)."""
125+ return {
126+ "n_neighbors" : self .n_neighbors_default ,
127+ "algorithm" : self .algorithm_default ,
128+ }
79129
80130 def _create_objective_function (self ):
131+ """Create objective function with fixed dataset and cv."""
132+ # Load dataset once
133+ X , y = self ._dataset_loader ()
134+ cv = self .cv
135+
81136 def k_neighbors_classifier (params ):
82137 knc = KNeighborsClassifier (
83138 n_neighbors = params ["n_neighbors" ],
84139 algorithm = params ["algorithm" ],
85140 )
86- X , y = params ["dataset" ]()
87- scores = cross_val_score (knc , X , y , cv = params ["cv" ], scoring = "accuracy" )
141+ scores = cross_val_score (knc , X , y , cv = cv , scoring = "accuracy" )
88142 return scores .mean ()
89143
90144 self .pure_objective_function = k_neighbors_classifier
145+
146+ def _get_surrogate_params (self , params ):
147+ """Add fixed parameters (dataset, cv) to params for surrogate prediction.
148+
149+ The surrogate model was trained on all (hyperparameter, dataset, cv) combinations,
150+ so the fixed parameters must be included when querying it.
151+ """
152+ return {
153+ ** params ,
154+ "dataset" : self .dataset ,
155+ "cv" : self .cv ,
156+ }
0 commit comments