@@ -100,7 +100,7 @@ def __init__(
100100
101101 if extra_methods :
102102 for method_metadata in extra_methods :
103- assert method_metadata .method not in methods
103+ assert method_metadata .method not in methods , f" { method_metadata . method } already in methods..."
104104 methods .append (method_metadata .method )
105105 method_metadata_lst .append (method_metadata )
106106
@@ -205,6 +205,54 @@ def generate_repo_holdout(self, method: str) -> Path:
205205 repo .to_dir (path_processed )
206206 return path_processed
207207
208+ # FIXME: This is a hacky approach, refactor
209+ def generate_hpo_trajectories (
210+ self ,
211+ methods : list [str | MethodMetadata ],
212+ n_configs : list [int | None ] | str = "auto" ,
213+ seeds : int | list [int ] = 20 ,
214+ n_iterations : int = 40 ,
215+ default_method : str = None ,
216+ always_include_default : bool = True ,
217+ fit_order : Literal ["original" , "random" ] = "random" ,
218+ time_limit : float | None = None ,
219+ backend : Literal ["ray" , "native" ] = "ray" ,
220+ repo : EvaluationRepository | None = None ,
221+ folds : list [int ] | None = None ,
222+ ta_name : str = None ,
223+ ta_suite : str = None ,
224+ display_name : str = None ,
225+ ) -> pd .DataFrame :
226+ methods : list [MethodMetadata ] = [self .method_metadata (m ) if isinstance (m , str ) else m for m in methods ]
227+ if repo is None :
228+ repo = self .load_repo (methods = methods )
229+ if folds is not None :
230+ repo = repo .subset (folds = folds )
231+ if not default_method :
232+ default_method = methods [0 ]
233+ else :
234+ for method in methods :
235+ if method .method == default_method :
236+ default_method = method
237+ break
238+ hpo_trajectory = default_method .generate_hpo_trajectories (
239+ n_configs = n_configs ,
240+ repo = repo ,
241+ seeds = seeds ,
242+ n_iterations = n_iterations ,
243+ always_include_default = always_include_default ,
244+ fit_order = fit_order ,
245+ time_limit = time_limit ,
246+ backend = backend ,
247+ config_type = repo .config_types (),
248+ cache = False ,
249+ )
250+
251+ hpo_trajectory ["ta_name" ] = ta_name
252+ hpo_trajectory ["ta_suite" ] = ta_suite
253+ hpo_trajectory ["display_name" ] = display_name
254+ return hpo_trajectory
255+
208256 def combine_hpo (
209257 self ,
210258 methods : list [str ],
@@ -213,6 +261,11 @@ def combine_hpo(
213261 ta_suite : str ,
214262 method_default : str | None = None ,
215263 repo : EvaluationRepository | None = None ,
264+ n_configs : int | None = None ,
265+ time_limit : float | None = None ,
266+ fit_order : Literal ["original" , "random" ] = "original" ,
267+ default_always_first : bool = True ,
268+ seed : int = 0 ,
216269 ) -> pd .DataFrame :
217270 """
218271 Perform HPO across multiple methods
@@ -237,16 +290,31 @@ def combine_hpo(
237290 else :
238291 default = None
239292
293+ if default_always_first and config_default :
294+ fixed_configs = [config_default ]
295+ else :
296+ fixed_configs = None
297+
240298 tuned = self .run_hpo (
241299 method = methods ,
242300 repo = repo ,
243301 n_iterations = 1 ,
302+ n_configs = n_configs ,
303+ time_limit = time_limit ,
304+ fit_order = fit_order ,
305+ seed = seed ,
306+ fixed_configs = fixed_configs ,
244307 )
245308
246309 tuned_ens = self .run_hpo (
247310 method = methods ,
248311 repo = repo ,
249312 n_iterations = 40 ,
313+ n_configs = n_configs ,
314+ time_limit = time_limit ,
315+ fit_order = fit_order ,
316+ seed = seed ,
317+ fixed_configs = fixed_configs ,
250318 )
251319
252320 tuned ["ta_name" ] = ta_name
0 commit comments