Skip to content

Commit 4cc7230

Browse files
committed
Add TabArenaContext.combine_hpo
1 parent 1923ff1 commit 4cc7230

File tree

2 files changed

+113
-6
lines changed

2 files changed

+113
-6
lines changed

tabarena/tabarena/nips2025_utils/tabarena_context.py

Lines changed: 105 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import numpy as np
88
import pandas as pd
99

10+
from bencheval.website_format import format_leaderboard
1011
from tabarena.benchmark.result import BaselineResult
1112
from tabarena.utils.pickle_utils import fetch_all_pickles
1213
from tabarena.nips2025_utils.fetch_metadata import load_task_metadata
@@ -203,16 +204,94 @@ def generate_repo_holdout(self, method: str) -> Path:
203204
repo.to_dir(path_processed)
204205
return path_processed
205206

207+
def combine_hpo(
208+
self,
209+
methods: list[str],
210+
new_config_type: str,
211+
ta_name: str,
212+
ta_suite: str,
213+
method_default: str | None = None,
214+
repo: EvaluationRepository | None = None,
215+
) -> pd.DataFrame:
216+
"""
217+
Perform HPO across multiple methods
218+
219+
Returns default, tuned, and tuned + ensembled results.
220+
"""
221+
if method_default is None:
222+
method_default = methods[0]
223+
if repo is None:
224+
repo = self.load_repo(methods=methods)
225+
226+
config_type_default = self.method_metadata(method_default).config_type
227+
simulator = PaperRunTabArena(repo=repo, backend=self.backend)
228+
config_default = simulator._config_default(config_type=config_type_default, use_first_if_missing=True)
229+
if config_default is not None:
230+
default = simulator.run_config_default(model_type=config_type_default)
231+
default = default.rename(columns={"framework": "method"})
232+
default["ta_name"] = ta_name
233+
default["ta_suite"] = ta_suite
234+
default["config_type"] = new_config_type
235+
default["method"] = f"{new_config_type} (default)"
236+
else:
237+
default = None
238+
239+
tuned = self.run_hpo(
240+
method=methods,
241+
repo=repo,
242+
n_iterations=1,
243+
)
244+
245+
tuned_ens = self.run_hpo(
246+
method=methods,
247+
repo=repo,
248+
n_iterations=40,
249+
)
250+
251+
tuned["ta_name"] = ta_name
252+
tuned["ta_suite"] = ta_suite
253+
tuned["config_type"] = new_config_type
254+
tuned["method"] = f"{new_config_type} (tuned)"
255+
tuned_ens["ta_name"] = ta_name
256+
tuned_ens["ta_suite"] = ta_suite
257+
tuned_ens["config_type"] = new_config_type
258+
tuned_ens["method"] = f"{new_config_type} (tuned + ensemble)"
259+
260+
results_hpo_comb = pd.concat([
261+
default,
262+
tuned,
263+
tuned_ens,
264+
], ignore_index=True)
265+
266+
return results_hpo_comb
267+
206268
def run_hpo(
207269
self,
208-
method: str,
209-
repo: EvaluationRepository,
270+
method: str | list[str],
271+
repo: EvaluationRepository = None,
210272
n_iterations: int = 40,
211273
n_configs: int | None = None,
212274
time_limit: float | None = None,
213275
fit_order: Literal["original", "random"] = "original",
214276
seed: int = 0,
277+
**kwargs,
215278
) -> pd.DataFrame:
279+
if not isinstance(method, list):
280+
method = [method]
281+
valid_methods = self.methods
282+
if repo is None:
283+
repo = self.load_repo(methods=method)
284+
method_new = []
285+
for m in method:
286+
if m in valid_methods:
287+
method_metadata = self.method_metadata(method=m)
288+
config_type = method_metadata.config_type
289+
else:
290+
config_type = m
291+
method_new.append(config_type)
292+
method = method_new
293+
if len(method) == 1:
294+
method = method[0]
216295
simulator = PaperRunTabArena(repo=repo, backend=self.backend)
217296
df_results_family_hpo = simulator.run_ensemble_config_type(
218297
config_type=method,
@@ -221,11 +300,16 @@ def run_hpo(
221300
time_limit=time_limit,
222301
fit_order=fit_order,
223302
seed=seed,
303+
**kwargs,
224304
)
225305
df_results_family_hpo = df_results_family_hpo.rename(columns={
226306
"framework": "method",
227307
})
228-
df_results_family_hpo["method"] = f"HPO-N{n_configs}-{method}"
308+
name = "HPO"
309+
if n_configs is not None:
310+
name += f"-N{n_configs}"
311+
name += f"-{method}"
312+
df_results_family_hpo["method"] = name
229313
return df_results_family_hpo
230314

231315
# FIXME: WIP
@@ -536,6 +620,24 @@ def generate_per_dataset_tables(
536620
save_path=Path(save_path),
537621
)
538622

623+
def leaderboard_to_website_format(
624+
self,
625+
leaderboard: pd.DataFrame,
626+
**kwargs,
627+
) -> pd.DataFrame:
628+
method_metadata_info = self.method_metadata_collection.info()
629+
method_metadata_info = method_metadata_info.rename(
630+
columns={
631+
"method": "ta_name",
632+
"artifact_name": "ta_suite",
633+
}
634+
)
635+
return format_leaderboard(
636+
df_leaderboard=leaderboard,
637+
method_metadata_info=method_metadata_info,
638+
**kwargs,
639+
)
640+
539641
def load_config_results_multi(
540642
self,
541643
method_metadata_lst: list[MethodMetadata] | None = None,

tabarena/tabarena/paper/paper_runner.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def run_hpo_by_family(self, include_uncapped: bool = False, include_4h: bool = T
8282

8383
def run_ensemble_config_type(
8484
self,
85-
config_type: str,
85+
config_type: str | list[str],
8686
n_iterations: int,
8787
n_configs: int = None,
8888
fixed_configs: list[str] | None = None,
@@ -93,7 +93,12 @@ def run_ensemble_config_type(
9393
) -> pd.DataFrame:
9494
# FIXME: Don't recompute this each call, implement `self.repo.configs(config_types=[config_type])`
9595
config_type_groups = self.get_config_type_groups()
96-
configs = config_type_groups[config_type]
96+
if isinstance(config_type, list):
97+
configs = []
98+
for ct in config_type:
99+
configs += config_type_groups[ct]
100+
else:
101+
configs = config_type_groups[config_type]
97102

98103
if fixed_configs is not None:
99104
for c in fixed_configs:
@@ -126,7 +131,7 @@ def run_ensemble_config_type(
126131
else:
127132
method_subtype = "tuned_ensemble"
128133
df_results_family_hpo["method_subtype"] = method_subtype
129-
df_results_family_hpo["config_type"] = config_type
134+
df_results_family_hpo["config_type"] = str(config_type)
130135

131136
method_metadata = dict(
132137
n_iterations=n_iterations,

0 commit comments

Comments
 (0)