Skip to content

Commit 9cb0106

Browse files
committed
Add display_name logic
1 parent 2ec8200 commit 9cb0106

File tree

4 files changed

+43
-2
lines changed

4 files changed

+43
-2
lines changed

tabarena/tabarena/nips2025_utils/artifacts/method_metadata.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def __init__(
3535
artifact_name: str = None,
3636
date: str | None = None,
3737
method_type: Literal["config", "baseline", "portfolio"] = "config",
38+
display_name: str | None = None,
3839
name: str | None = None,
3940
name_suffix: str | None = None,
4041
ag_key: str | None = None,
@@ -64,6 +65,7 @@ def __init__(
6465
model_key = ag_key
6566
self.model_key = model_key
6667
self.name = name
68+
self.display_name = display_name
6769
self.name_suffix = name_suffix
6870
self.config_default = config_default
6971
self.compute = compute
@@ -90,6 +92,15 @@ def __init__(
9092
raise AssertionError(f"Must only specify one of `name` and `name_suffix`.")
9193
self.reference_url = reference_url
9294

95+
def get_display_name(self) -> str:
96+
if self.display_name is not None:
97+
return self.display_name
98+
if self.name is not None:
99+
return self.name
100+
if self.config_type is not None:
101+
return self.config_type
102+
return self.method
103+
93104
@property
94105
def config_type(self) -> str | None:
95106
if self.method_type != "config":

tabarena/tabarena/nips2025_utils/compare.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def compare_on_tabarena(
2424
remove_imputed: bool = False,
2525
tmp_treat_tasks_independently: bool = False,
2626
leaderboard_kwargs: dict | None = None,
27+
**kwargs,
2728
) -> pd.DataFrame:
2829
output_dir = Path(output_dir)
2930
if tabarena_context is None:
@@ -32,6 +33,22 @@ def compare_on_tabarena(
3233
tabarena_context = TabArenaContext(**tabarena_context_kwargs)
3334
task_metadata = tabarena_context.task_metadata
3435

36+
# TODO: only methods that exist in runs
37+
# Pair with (method, artifact_name)
38+
method_rename_map = dict()
39+
method_metadatas = tabarena_context.method_metadata_collection.method_metadata_lst
40+
for m in method_metadatas:
41+
if m.method_type == "config":
42+
display_name = m.get_display_name()
43+
if display_name is not None:
44+
if m.config_type in method_rename_map:
45+
print(
46+
f"WARNING: Multiple display_name values detected for the same config_type={m.config_type!r}"
47+
f"\n\tdisplay_name 1: {method_rename_map[m.config_type]!r}"
48+
f"\n\tdisplay_name 2: {display_name!r}"
49+
)
50+
method_rename_map[m.config_type] = display_name
51+
3552
paper_results = tabarena_context.load_results_paper(
3653
download_results="auto",
3754
)
@@ -46,7 +63,7 @@ def compare_on_tabarena(
4663
else:
4764
df_results = paper_results
4865

49-
kwargs = {}
66+
kwargs = kwargs.copy()
5067
if isinstance(only_valid_tasks, (str, list)):
5168
kwargs["only_valid_tasks"] = only_valid_tasks
5269
elif only_valid_tasks and new_results is not None:
@@ -73,6 +90,7 @@ def compare_on_tabarena(
7390
remove_imputed=remove_imputed,
7491
tmp_treat_tasks_independently=tmp_treat_tasks_independently,
7592
leaderboard_kwargs=leaderboard_kwargs,
93+
method_rename_map=method_rename_map,
7694
**kwargs,
7795
)
7896

@@ -89,6 +107,8 @@ def compare(
89107
tmp_treat_tasks_independently: bool = False, # FIXME: Update
90108
leaderboard_kwargs: dict | None = None,
91109
remove_imputed: bool = False,
110+
method_rename_map: dict | None = None,
111+
**kwargs,
92112
):
93113
df_results = prepare_data(
94114
df_results=df_results,
@@ -109,6 +129,7 @@ def compare(
109129
output_dir=output_dir,
110130
task_metadata=task_metadata,
111131
error_col=error_col,
132+
method_rename_map=method_rename_map,
112133
)
113134

114135
return plotter.eval(
@@ -121,6 +142,7 @@ def compare(
121142
average_seeds=average_seeds,
122143
tmp_treat_tasks_independently=tmp_treat_tasks_independently,
123144
leaderboard_kwargs=leaderboard_kwargs,
145+
**kwargs,
124146
)
125147

126148

tabarena/tabarena/nips2025_utils/tabarena_context.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def compare(
117117
remove_imputed: bool = False,
118118
tmp_treat_tasks_independently: bool = False,
119119
leaderboard_kwargs: dict | None = None,
120+
**kwargs,
120121
) -> pd.DataFrame:
121122
from tabarena.nips2025_utils.compare import compare_on_tabarena
122123
return compare_on_tabarena(
@@ -132,6 +133,7 @@ def compare(
132133
remove_imputed=remove_imputed,
133134
tmp_treat_tasks_independently=tmp_treat_tasks_independently,
134135
leaderboard_kwargs=leaderboard_kwargs,
136+
**kwargs,
135137
)
136138

137139
@property

tabarena/tabarena/paper/tabarena_evaluator.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def __init__(
5555
folds: list[int] | None = None,
5656
datasets: list[str] | None = None,
5757
problem_types: list[str] | None = None,
58+
method_rename_map: dict[str, str] | None = None,
5859
banned_model_types: list[str] | None = None,
5960
banned_pareto_methods: list[str] | None = None,
6061
elo_bootstrap_rounds: int = 200,
@@ -85,13 +86,16 @@ def __init__(
8586
task_metadata = load_task_metadata()
8687
if banned_pareto_methods is None:
8788
banned_pareto_methods = []
89+
if method_rename_map is None:
90+
method_rename_map = {}
8891
self.output_dir: Path = Path(output_dir)
8992
self.task_metadata = task_metadata
9093
self.method_col = method_col
9194
self.error_col = error_col
9295
self.config_types = config_types
9396
self.figure_file_type = figure_file_type
9497
self.banned_pareto_methods = banned_pareto_methods
98+
self._method_rename_map = method_rename_map
9599

96100
self.datasets = datasets
97101
self.problem_types = problem_types
@@ -825,7 +829,9 @@ def plot_pareto_improvability_vs_time_train(self, leaderboard: pd.DataFrame):
825829
)
826830

827831
def get_method_rename_map(self) -> dict[str, str]:
828-
return get_method_rename_map() # FIXME: Avoid hardcoding
832+
method_rename_map = get_method_rename_map() # FIXME: Avoid hardcoding
833+
method_rename_map.update(self._method_rename_map)
834+
return method_rename_map
829835

830836
def plot_portfolio_ensemble_weights_barplot(self, df_ensemble_weights: pd.DataFrame):
831837
import seaborn as sns

0 commit comments

Comments
 (0)