Skip to content

Commit 6bb915e

Browse files
committed
Add heterogeneous HPO simulation support
1 parent 0205ebd commit 6bb915e

File tree

2 files changed

+79
-5
lines changed

2 files changed

+79
-5
lines changed

tabarena/tabarena/nips2025_utils/artifacts/method_metadata.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -628,15 +628,17 @@ def generate_hpo_result(
628628
time_limit: float | None = None,
629629
fixed_configs: list[str] | None = None,
630630
fit_order: Literal["original", "random"] = "random",
631+
config_type: str | list[str] | None = None,
631632
holdout: bool = False,
632633
backend: Literal["ray", "native"] = "ray",
633634
seed: int = 0,
634635
**kwargs,
635636
) -> pd.DataFrame:
636637
if repo is None:
637638
repo = self.load_processed(as_holdout=holdout)
638-
assert self.config_type is not None
639-
config_type = self.config_type
639+
if config_type is None:
640+
assert self.config_type is not None
641+
config_type = self.config_type
640642
simulator = PaperRunTabArena(repo=repo, backend=backend)
641643
df_results_hpo = simulator.run_ensemble_config_type(
642644
config_type=config_type,
@@ -651,7 +653,7 @@ def generate_hpo_result(
651653
df_results_hpo = df_results_hpo.rename(columns={
652654
"framework": "method",
653655
})
654-
df_results_hpo["method"] = f"HPO-N{n_configs}-{self.config_type}"
656+
df_results_hpo["method"] = f"HPO-N{n_configs}-{config_type}"
655657
df_results_hpo["n_configs"] = n_configs
656658
df_results_hpo["n_iterations"] = n_iterations
657659
df_results_hpo["seed"] = seed
@@ -669,6 +671,8 @@ def generate_hpo_trajectories(
669671
time_limit: float | None = None,
670672
backend: Literal["ray", "native"] = "ray",
671673
holdout: bool = False,
674+
config_type: str | list[str] | None = None,
675+
repo: EvaluationRepository | None = None,
672676
cache: bool = False,
673677
) -> pd.DataFrame:
674678
if n_configs == "auto":
@@ -687,7 +691,8 @@ def generate_hpo_trajectories(
687691
seeds = [i for i in range(seeds)]
688692

689693
df_results_hpo_lst = []
690-
repo = self.load_processed(as_holdout=holdout)
694+
if repo is None:
695+
repo = self.load_processed(as_holdout=holdout)
691696

692697
# FIXME: Breaks for holdout, need to find a way to get self.config_default(holdout=True)
693698
# FIXME: Needed for TabPFN-2.5
@@ -719,6 +724,7 @@ def generate_hpo_trajectories(
719724
time_limit=time_limit,
720725
backend=backend,
721726
holdout=holdout,
727+
config_type=config_type,
722728
)
723729
df_results_hpo["always_include_default"] = always_include_default
724730
df_results_hpo_lst.append(df_results_hpo)

tabarena/tabarena/nips2025_utils/tabarena_context.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def __init__(
100100

101101
if extra_methods:
102102
for method_metadata in extra_methods:
103-
assert method_metadata.method not in methods
103+
assert method_metadata.method not in methods, f"{method_metadata.method} already in methods..."
104104
methods.append(method_metadata.method)
105105
method_metadata_lst.append(method_metadata)
106106

@@ -205,6 +205,54 @@ def generate_repo_holdout(self, method: str) -> Path:
205205
repo.to_dir(path_processed)
206206
return path_processed
207207

208+
# FIXME: This is a hacky approach, refactor
209+
def generate_hpo_trajectories(
210+
self,
211+
methods: list[str | MethodMetadata],
212+
n_configs: list[int | None] | str = "auto",
213+
seeds: int | list[int] = 20,
214+
n_iterations: int = 40,
215+
default_method: str = None,
216+
always_include_default: bool = True,
217+
fit_order: Literal["original", "random"] = "random",
218+
time_limit: float | None = None,
219+
backend: Literal["ray", "native"] = "ray",
220+
repo: EvaluationRepository | None = None,
221+
folds: list[int] | None = None,
222+
ta_name: str = None,
223+
ta_suite: str = None,
224+
display_name: str = None,
225+
) -> pd.DataFrame:
226+
methods: list[MethodMetadata] = [self.method_metadata(m) if isinstance(m, str) else m for m in methods]
227+
if repo is None:
228+
repo = self.load_repo(methods=methods)
229+
if folds is not None:
230+
repo = repo.subset(folds=folds)
231+
if not default_method:
232+
default_method = methods[0]
233+
else:
234+
for method in methods:
235+
if method.method == default_method:
236+
default_method = method
237+
break
238+
hpo_trajectory = default_method.generate_hpo_trajectories(
239+
n_configs=n_configs,
240+
repo=repo,
241+
seeds=seeds,
242+
n_iterations=n_iterations,
243+
always_include_default=always_include_default,
244+
fit_order=fit_order,
245+
time_limit=time_limit,
246+
backend=backend,
247+
config_type=repo.config_types(),
248+
cache=False,
249+
)
250+
251+
hpo_trajectory["ta_name"] = ta_name
252+
hpo_trajectory["ta_suite"] = ta_suite
253+
hpo_trajectory["display_name"] = display_name
254+
return hpo_trajectory
255+
208256
def combine_hpo(
209257
self,
210258
methods: list[str],
@@ -213,6 +261,11 @@ def combine_hpo(
213261
ta_suite: str,
214262
method_default: str | None = None,
215263
repo: EvaluationRepository | None = None,
264+
n_configs: int | None = None,
265+
time_limit: float | None = None,
266+
fit_order: Literal["original", "random"] = "original",
267+
default_always_first: bool = True,
268+
seed: int = 0,
216269
) -> pd.DataFrame:
217270
"""
218271
Perform HPO across multiple methods
@@ -237,16 +290,31 @@ def combine_hpo(
237290
else:
238291
default = None
239292

293+
if default_always_first and config_default:
294+
fixed_configs = [config_default]
295+
else:
296+
fixed_configs = None
297+
240298
tuned = self.run_hpo(
241299
method=methods,
242300
repo=repo,
243301
n_iterations=1,
302+
n_configs=n_configs,
303+
time_limit=time_limit,
304+
fit_order=fit_order,
305+
seed=seed,
306+
fixed_configs=fixed_configs,
244307
)
245308

246309
tuned_ens = self.run_hpo(
247310
method=methods,
248311
repo=repo,
249312
n_iterations=40,
313+
n_configs=n_configs,
314+
time_limit=time_limit,
315+
fit_order=fit_order,
316+
seed=seed,
317+
fixed_configs=fixed_configs,
250318
)
251319

252320
tuned["ta_name"] = ta_name

0 commit comments

Comments
 (0)