Skip to content

Commit ca785e5

Browse files
committed
minor update
1 parent db8af12 commit ca785e5

File tree

1 file changed

+55
-39
lines changed

1 file changed

+55
-39
lines changed

tabarena/tabarena/nips2025_utils/compare.py

Lines changed: 55 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -90,45 +90,12 @@ def compare(
9090
leaderboard_kwargs: dict | None = None,
9191
remove_imputed: bool = False,
9292
):
93-
df_results = df_results.copy()
94-
95-
if isinstance(only_valid_tasks, str):
96-
only_valid_tasks = [only_valid_tasks]
97-
if isinstance(only_valid_tasks, list):
98-
for filter_method in only_valid_tasks:
99-
# Filter to tasks present in a specific method
100-
df_filter = df_results[df_results["method"] == filter_method]
101-
if "imputed" in df_filter.columns:
102-
df_filter = df_filter[df_filter["imputed"] != True]
103-
assert len(df_filter) != 0, \
104-
(f"No method named '{filter_method}' remains to filter to!\n"
105-
f"Available tasks: {list(df_results['method'].unique())}")
106-
df_results = filter_to_valid_tasks(
107-
df_to_filter=df_results,
108-
df_filter=df_filter,
109-
)
110-
111-
if "method_type" not in df_results.columns:
112-
df_results["method_type"] = "baseline"
113-
if "method_subtype" not in df_results.columns:
114-
df_results["method_subtype"] = np.nan
115-
if "config_type" not in df_results.columns:
116-
df_results["config_type"] = np.nan
117-
if "imputed" not in df_results.columns:
118-
df_results["imputed"] = False
119-
120-
if isinstance(fillna, str):
121-
fillna = df_results[df_results["method"] == fillna]
122-
if fillna is not None:
123-
df_results = TabArenaContext.fillna_metrics(
124-
df_to_fill=df_results,
125-
df_fillna=fillna,
126-
)
127-
128-
if remove_imputed:
129-
methods_imputed = df_results.groupby("method")["imputed"].sum()
130-
methods_imputed = list(methods_imputed[methods_imputed > 0].index)
131-
df_results = df_results[~df_results["method"].isin(methods_imputed)]
93+
df_results = prepare_data(
94+
df_results=df_results,
95+
only_valid_tasks=only_valid_tasks,
96+
fillna=fillna,
97+
remove_imputed=remove_imputed,
98+
)
13299

133100
if score_on_val:
134101
error_col = "metric_error_val"
@@ -173,6 +140,55 @@ def is_in(dataset: str, fold: int) -> bool:
173140
return df_filtered
174141

175142

143+
def prepare_data(
    df_results: pd.DataFrame,
    only_valid_tasks: str | list[str] | None = None,
    fillna: str | pd.DataFrame | None = None,
    remove_imputed: bool = False,
) -> pd.DataFrame:
    """Return a filtered, column-normalized copy of ``df_results``.

    The input frame is never mutated; all work happens on a copy. Steps run
    in this order:

    1. For each method named in ``only_valid_tasks``, select that method's
       rows (excluding rows whose ``imputed`` value is ``True`` when the
       column exists) and pass them to ``filter_to_valid_tasks`` as the
       filter frame.
    2. Add the columns ``method_type`` (``"baseline"``), ``method_subtype``
       (NaN), ``config_type`` (NaN) and ``imputed`` (``False``) when missing.
    3. If ``fillna`` is given, call ``TabArenaContext.fillna_metrics`` with
       it. A string is first resolved to the rows of the method with that
       name.
    4. If ``remove_imputed`` is true, drop every method that has at least one
       imputed row.

    Parameters
    ----------
    df_results
        Results frame; must contain a ``method`` column.
    only_valid_tasks
        A method name or list of method names to filter by. Any other value
        (including ``None``) skips the filtering step.
    fillna
        A method name, or a frame passed as ``df_fillna`` to
        ``TabArenaContext.fillna_metrics``. ``None`` skips the step.
    remove_imputed
        Whether to drop methods that contain imputed rows.

    Returns
    -------
    pd.DataFrame
        The prepared copy of ``df_results``.

    Raises
    ------
    AssertionError
        If a method named in ``only_valid_tasks`` has no (non-imputed) rows
        left in the frame.
    """
    df_results = df_results.copy()

    if isinstance(only_valid_tasks, str):
        only_valid_tasks = [only_valid_tasks]
    if isinstance(only_valid_tasks, list):
        for filter_method in only_valid_tasks:
            # Filter to tasks present in a specific method
            df_filter = df_results[df_results["method"] == filter_method]
            if "imputed" in df_filter.columns:
                # `.ne(True)` keeps False and missing values, exactly like `!= True`.
                df_filter = df_filter[df_filter["imputed"].ne(True)]
            if len(df_filter) == 0:
                # Raised explicitly (rather than via `assert`) so the check is
                # not stripped under `python -O`; the exception type is kept
                # for callers that already catch AssertionError.
                raise AssertionError(
                    f"No method named '{filter_method}' remains to filter to!\n"
                    f"Available methods: {list(df_results['method'].unique())}"
                )
            df_results = filter_to_valid_tasks(
                df_to_filter=df_results,
                df_filter=df_filter,
            )

    # Columns that later steps rely on, with the value used when absent.
    default_columns = {
        "method_type": "baseline",
        "method_subtype": np.nan,
        "config_type": np.nan,
        "imputed": False,
    }
    for column, default in default_columns.items():
        if column not in df_results.columns:
            df_results[column] = default

    if isinstance(fillna, str):
        fillna = df_results[df_results["method"] == fillna]
    if fillna is not None:
        df_results = TabArenaContext.fillna_metrics(
            df_to_fill=df_results,
            df_fillna=fillna,
        )

    if remove_imputed:
        # A method is dropped entirely as soon as one of its rows is imputed.
        imputed_counts = df_results.groupby("method")["imputed"].sum()
        methods_imputed = list(imputed_counts[imputed_counts > 0].index)
        df_results = df_results[~df_results["method"].isin(methods_imputed)]

    return df_results
190+
191+
176192
def subset_tasks(df_results: pd.DataFrame, subset: list[str], folds: list[int] = None) -> pd.DataFrame:
177193
from tabarena.nips2025_utils.fetch_metadata import load_task_metadata
178194

0 commit comments

Comments
 (0)