Skip to content

Commit d94134b

Browse files
committed
Clean up plot tuning impact code
1 parent 1622ef4 commit d94134b

File tree

1 file changed

+14
-105
lines changed

1 file changed

+14
-105
lines changed

tabarena/tabarena/paper/tabarena_evaluator.py

Lines changed: 14 additions & 105 deletions
Original file line number | Diff line number | Diff line change
@@ -233,6 +233,7 @@ def eval(
233233
tmp_treat_tasks_independently: bool = False, # FIXME: Need to make a weighted elo logic
234234
leaderboard_kwargs: dict | None = None,
235235
plot_with_baselines: bool = False,
236+
verbose: bool = True,
236237
) -> pd.DataFrame:
237238
if leaderboard_kwargs is None:
238239
leaderboard_kwargs = {}
@@ -407,10 +408,11 @@ def eval(
407408

408409
n_tasks = len(df_results_rank_compare[[tabarena.task_col, tabarena.seed_column]].drop_duplicates())
409410

410-
print(
411-
f"Evaluating with {len(df_results_rank_compare[tabarena.task_col].unique())} datasets... ({n_tasks} tasks)| problem_types={self.problem_types}, folds={self.folds}")
412-
with pd.option_context("display.max_rows", None, "display.max_columns", None, "display.width", 1000):
413-
print(leaderboard)
411+
if verbose:
412+
print(
413+
f"Evaluating with {len(df_results_rank_compare[tabarena.task_col].unique())} datasets... ({n_tasks} tasks)| problem_types={self.problem_types}, folds={self.folds}")
414+
with pd.option_context("display.max_rows", None, "display.max_columns", None, "display.width", 1000):
415+
print(leaderboard)
414416

415417
# horizontal elo barplot
416418
self.plot_tuning_impact(
@@ -1053,12 +1055,10 @@ def plot_tuning_impact(
10531055

10541056
if imputed_names is None:
10551057
imputed_names = []
1056-
# imputed_names = imputed_names or ['TabPFNv2', 'TabICL']
10571058

10581059
df = df.copy(deep=True)
10591060

10601061
framework_col = "framework_type"
1061-
# framework_col = "framework_name"
10621062

10631063
groupby_columns_extra = ["dataset"]
10641064

@@ -1073,7 +1073,6 @@ def plot_tuning_impact(
10731073
elif use_score:
10741074
lower_is_better = False
10751075
df["normalized-score"] = 1 - df[metric]
1076-
# df_plot_w_mean_per_dataset["normalized-score"] = 1 - df_plot_w_mean_per_dataset["normalized-error"]
10771076
metric = "normalized-score"
10781077
else:
10791078
metric = metric
@@ -1102,12 +1101,6 @@ def plot_tuning_impact(
11021101
df = df[df["tune_method"].isin(plot_tune_types) | df[self.method_col].isin(baselines)]
11031102

11041103
df_plot = df[df["framework_type"].isin(framework_types)]
1105-
# df_plot = df_plot[~df_plot["framework_type"].isin(imputed_names)]
1106-
1107-
# pd.set_option('display.max_columns', None) # todo
1108-
# print(f'{df_plot.head()=}')
1109-
1110-
# df_plot_w_mean_2 = df_plot.groupby(["framework_type", "tune_method"])[metric].mean().reset_index()
11111104

11121105
df_plot_w_mean_per_dataset = df_plot.groupby(["framework_type", "tune_method", *groupby_columns_extra])[
11131106
metric].mean().reset_index()
@@ -1136,27 +1129,8 @@ def plot_tuning_impact(
11361129
framework_type_order = list(df_plot_mean_dedupe["framework_type"].to_list())
11371130
framework_type_order.reverse()
11381131

1139-
# change to names
1140-
# df_plot_w_mean_per_dataset["framework_type"] = df_plot_w_mean_per_dataset["framework_type"].map(f_map_type_name)
1141-
1142-
# sns.set_color_codes("pastel")
1143-
# with sns.plotting_context("notebook", font_scale=0.8, rc={
1144-
# "pgf.texsystem": "pdflatex",
1145-
# 'font.family': 'serif',
1146-
# 'font.size': 10.95,
1147-
# 'text.usetex': True,
1148-
# 'pgf.rcfonts': False,
1149-
# # 'legend.framealpha': 0.5,
1150-
# 'text.latex.preamble': r'\usepackage{times} \usepackage{amsmath} \usepackage{amsfonts} \usepackage{amssymb} \usepackage{xcolor}'
1151-
# }):
1152-
11531132
with sns.axes_style("whitegrid"):
1154-
# with plt.rc_context({'font.family': 'serif', "text.usetex": True, 'font.size': 12, 'axes.labelsize': 12, 'xtick.labelsize': 12}):
1155-
with plt.rc_context(self.rc_context_params
1156-
# | figsizes.neurips2024(height_to_width_ratio=0.8)
1157-
):
1158-
# with plt.rc_context(fontsizes.neurips2024() | fonts.neurips2024()):
1159-
# with plt.rc_context(figsizes.neurips2024(height_to_width_ratio=0.8)):
1133+
with plt.rc_context(self.rc_context_params):
11601134
colors = sns.color_palette("pastel").as_hex()
11611135
errcolors = sns.color_palette("deep").as_hex()
11621136

@@ -1178,7 +1152,6 @@ def plot_tuning_impact(
11781152
# figsize = None
11791153

11801154
fig, ax = plt.subplots(1, 1, figsize=figsize)
1181-
# fig, ax = plt.subplots(1, 1)
11821155

11831156
if use_y:
11841157
baseline_func = ax.axvline
@@ -1198,7 +1171,6 @@ def plot_tuning_impact(
11981171
to_plot = [
11991172
dict(
12001173
x=pos, y=y,
1201-
# hue="tune_method", # palette=["m", "g", "r],
12021174
label="Tuned + Ensembled",
12031175
data=df_plot_w_mean_per_dataset[df_plot_w_mean_per_dataset["tune_method"] == "tuned_ensembled"],
12041176
ax=ax,
@@ -1208,7 +1180,6 @@ def plot_tuning_impact(
12081180
),
12091181
# dict(
12101182
# x=x, y=y,
1211-
# # hue="tune_method", # palette=["m", "g", "r],
12121183
# label="Default (Holdout)",
12131184
# data=df_plot_w_mean_per_dataset[df_plot_w_mean_per_dataset["tune_method"] == "holdout"], ax=ax,
12141185
# order=framework_type_order,
@@ -1218,7 +1189,6 @@ def plot_tuning_impact(
12181189
# ),
12191190
# dict(
12201191
# x=x, y=y,
1221-
# # hue="tune_method", # palette=["m", "g", "r],
12221192
# label="Tuned (Holdout)",
12231193
# data=df_plot_w_mean_per_dataset[df_plot_w_mean_per_dataset["tune_method"] == "holdout_tuned"], ax=ax,
12241194
# order=framework_type_order,
@@ -1228,7 +1198,6 @@ def plot_tuning_impact(
12281198
# ),
12291199
dict(
12301200
x=pos, y=y,
1231-
# hue="tune_method", # palette=["m", "g", "r],
12321201
label="Tuned",
12331202
data=df_plot_w_mean_per_dataset[df_plot_w_mean_per_dataset["tune_method"] == "tuned"], ax=ax,
12341203
order=framework_type_order,
@@ -1238,7 +1207,6 @@ def plot_tuning_impact(
12381207
),
12391208
dict(
12401209
x=pos, y=y,
1241-
# hue="tune_method", # palette=["m", "g", "r],
12421210
label="Default",
12431211
data=df_plot_w_mean_per_dataset[df_plot_w_mean_per_dataset["tune_method"] == "default"], ax=ax,
12441212
order=framework_type_order, color=colors[0],
@@ -1248,7 +1216,6 @@ def plot_tuning_impact(
12481216
),
12491217
dict(
12501218
x=pos, y=y,
1251-
# hue="tune_method", # palette=["m", "g", "r],
12521219
label="Tuned + Ensembled (Holdout)",
12531220
data=df_plot_w_mean_per_dataset[
12541221
df_plot_w_mean_per_dataset["tune_method"] == "holdout_tuned_ensembled"], ax=ax,
@@ -1260,34 +1227,13 @@ def plot_tuning_impact(
12601227
),
12611228
# dict(
12621229
# x=x, y=y,
1263-
# # hue="tune_method", # palette=["m", "g", "r],
12641230
# label="Best",
12651231
# data=df_plot_w_mean_per_dataset[df_plot_w_mean_per_dataset["tune_method"] == "best"], ax=ax,
12661232
# order=framework_type_order, color=colors[3],
12671233
# width=0.55, linewidth=linewidth,
12681234
# err_kws={"color": errcolors[3]},
12691235
# alpha=1.0,
12701236
# ),
1271-
# dict(
1272-
# x=x, y=y,
1273-
# # hue="tune_method", # palette=["m", "g", "r],
1274-
# label="Tuned (4h)",
1275-
# data=df_plot_w_mean_per_dataset[df_plot_w_mean_per_dataset["tune_method"] == "tuned_4h"], ax=ax,
1276-
# order=framework_type_order,
1277-
# color=colors[4],
1278-
# width=0.5, linewidth=linewidth,
1279-
# err_kws={"color": errcolors[4]},
1280-
# ),
1281-
# dict(
1282-
# x=x, y=y,
1283-
# # hue="tune_method", # palette=["m", "g", "r],
1284-
# label="Tuned + Ensembled (4h)",
1285-
# data=df_plot_w_mean_per_dataset[df_plot_w_mean_per_dataset["tune_method"] == "tuned_ensembled_4h"], ax=ax,
1286-
# order=framework_type_order, color=colors[5],
1287-
# width=0.4,
1288-
# err_kws={"color": errcolors[5]},
1289-
# ),
1290-
12911237
]
12921238

12931239
if use_score:
@@ -1301,8 +1247,6 @@ def plot_tuning_impact(
13011247
plot_line["width"] = 0.6 * 1.3
13021248
else:
13031249
plot_line["width"] = width * 1.3
1304-
# plot_line["color"] = color
1305-
# plot_line["err_kws"] = err_kws
13061250

13071251
for plot_line in to_plot:
13081252
boxplot = sns.barplot(**plot_line)
@@ -1313,19 +1257,9 @@ def plot_tuning_impact(
13131257
boxplot.set(xlabel=None, ylabel='Elo' if metric=='elo' else 'Normalized score') # remove method in the x-axis
13141258
# boxplot.set_title("Effect of tuning and ensembling")
13151259

1316-
# # FIXME: (Nick) HACK, otherwise it isn't in the plot, don't know why
1317-
# if use_elo:
1318-
# if baseline_means and "Portfolio-N200 (ensemble) (4h)" in baselines:
1319-
# max_baseline_mean = max([v for k, v in baseline_means.items()])
1320-
# if ylim is not None:
1321-
# ylim[1] = max_baseline_mean + 50
1322-
# if xlim is not None:
1323-
# xlim[1] = max_baseline_mean + 50
1324-
13251260
# do this before setting x/y limits
13261261
for baseline_idx, (baseline, color) in enumerate(zip(baselines, baseline_colors)):
13271262
baseline_mean = baseline_means[baseline]
1328-
# baseline_func(baseline_mean, label=baseline, color=color, linewidth=2.0, ls="--")
13291263
baseline_func(baseline_mean, color=color, linewidth=2.0, ls="--", zorder=-10)
13301264

13311265
if baseline == 'Portfolio-N200 (ensemble) (4h)':
@@ -1411,26 +1345,11 @@ def plot_tuning_impact(
14111345
else:
14121346
plt.xlim(-0.5, len(boxplot.get_xticklabels()) - 0.5)
14131347

1414-
1415-
# ax.legend(loc="upper center", ncol=5)
1416-
# these are not the final legend parameters, see below
14171348
ax.legend(loc="upper center", bbox_to_anchor=[0.5, 1.02])
14181349

14191350
# reordering the labels
14201351
handles, labels = ax.get_legend_handles_labels()
14211352

1422-
# this doesn't work, it also removes the hatch from the actual bars in the plot
1423-
# for handle in handles:
1424-
# patches = []
1425-
# if isinstance(handle, Patch):
1426-
# patches = [handle]
1427-
# elif isinstance(handle, BarContainer):
1428-
# patches = handle.patches
1429-
# for patch in patches:
1430-
# # remove hatch from existing handles
1431-
# # It can be present if one of the imputed methods is the best method, e.g., for multiclass
1432-
# patch.set(hatch=None)
1433-
14341353
if has_imputed:
14351354
# Create a custom legend patch for "imputed"
14361355
imputed_patch = Patch(facecolor='gray', edgecolor='white', hatch='xx', label='Partially imputed')
@@ -1446,28 +1365,18 @@ def plot_tuning_impact(
14461365
labels = [labels[i] for i in valid_idxs]
14471366
handles = [handles[i] for i in valid_idxs]
14481367

1449-
# specify order
1450-
# len_baselines = len(baselines)
1451-
# len_baselines = 0 # we don't put them in the legend anymore
1452-
# num_other = len(labels) - len_baselines
1453-
# order = [n + len_baselines for n in range(num_other)] + [n for n in range(len_baselines)]
1454-
# order = [3, 4, 5, 0, 1, 2]
14551368
order = list(range(len(labels)))
14561369
order = list(reversed(order))
1457-
# if len(order) == 3:
1458-
# order = [2, 1, 0]
14591370

14601371
# pass handle & labels lists along with order as below
1461-
ax.legend([handles[i] for i in order], [labels[i] for i in order], loc="lower center",
1462-
ncol=(len(labels)+1)//2 if has_imputed and use_y else len(labels),
1463-
bbox_to_anchor=[0.35 if use_y else 0.5, 1.05])
1464-
1465-
# if use_y:
1466-
# boxplot.margins(y=0.05)
1467-
# else:
1468-
# boxplot.margins(x=0.05)
1372+
ax.legend(
1373+
[handles[i] for i in order],
1374+
[labels[i] for i in order],
1375+
loc="lower center",
1376+
ncol=(len(labels)+1)//2 if has_imputed and use_y else len(labels),
1377+
bbox_to_anchor=[0.35 if use_y else 0.5, 1.05],
1378+
)
14691379

1470-
# ax.legend(bbox_to_anchor=[0.1, 0.5], loc='center left', ncol=5)
14711380
plt.tight_layout()
14721381

14731382
if save_prefix:

0 commit comments

Comments (0)