Skip to content

Commit fee27b4

Browse files
committed
WIP on baselines
1 parent a69558f commit fee27b4

File tree

1 file changed

+167
-98
lines changed

1 file changed

+167
-98
lines changed

notebooks_jason/max_of_K_all_models.py

Lines changed: 167 additions & 98 deletions
Original file line number | Diff line number | Diff line change
@@ -160,7 +160,9 @@
160160
default=True,
161161
help="Output plots shared across seeds",
162162
)
163-
cli_args = parser.parse_args(None if ipython is None else ["--ignore-csv"])
163+
cli_args = parser.parse_args(
164+
None if ipython is None else ["--ignore-csv", "--K", "4", "--brute-force"]
165+
)
164166
# %%
165167
#!sudo apt-get install dvipng texlive-latex-extra texlive-fonts-recommended cm-super pdfcrop optipng pngcrush
166168
# %%
@@ -3569,112 +3571,179 @@ def do_linear_regression(X, Y):
35693571
plt.rcParams["axes.prop_cycle"] = cycler(color=plt.cm.Paired.colors)
35703572
for frontier_only in (True, False):
35713573
for norm, normt in (("", ""), ("normalized-", "Normalized ")):
3572-
key = f"{normt.strip()}AccuracyBoundVsFLOPs{'FrontierOnly' if frontier_only else ''}"
3573-
data = (
3574-
combined_df[combined_df["frontier"] == True]
3575-
if frontier_only
3576-
else combined_df
3577-
)
3578-
data = data[["proof-flop-estimate", f"{norm}accuracy-bound", "group"]].copy()
3579-
data = double_singleton_groups(data.drop_duplicates(), column="group")
3580-
data = data.sort_values(
3581-
by=["group", f"{norm}accuracy-bound", "proof-flop-estimate"]
3582-
)
3583-
discontinuous_x = (
3584-
data[(data["group"] == "brute force") | (data["group"] == "cubic")][
3585-
"proof-flop-estimate"
3586-
].mean(),
3587-
)
3588-
compress_data = lambda values: (
3589-
f"{values.item() / 2 ** int(math.log2(values.item()))} \\cdot 2^{{{int(math.log2(values.item()))}}}"
3590-
if len(values) == 1
3591-
else f"({pm_mean_std(values / 2 ** int(math.log2(values.mean())))}) \\cdot 2^{{{int(math.log2(values.mean()))}}}"
3592-
)
3593-
print(
3594-
[
3595-
(
3596-
compress_data(
3597-
data[data["group"] == c]["proof-flop-estimate"].unique()
3598-
),
3599-
category_name_remap[c],
3600-
)
3601-
for c in category_order
3602-
if len(data[data["group"] == c]["proof-flop-estimate"]) > 0
3603-
]
3604-
)
3605-
data["group"] = data["group"].map(category_name_remap)
3606-
if (DISPLAY_PLOTS or SAVE_PLOTS) and SHARED_PLOTS:
3607-
markersize = (
3608-
plt.rcParams["lines.markersize"] / 8 if not frontier_only else None
3574+
for include_baseline in (True, False): # , True):
3575+
key = f"{normt.strip()}AccuracyBoundVsFLOPs{'FrontierOnly' if frontier_only else ''}{'WithBaseline' if include_baseline else ''}"
3576+
data = (
3577+
combined_df[combined_df["frontier"] == True]
3578+
if frontier_only
3579+
else combined_df
36093580
)
3610-
latex_externalize_tables[key] = True
3611-
latex_figures[key] = fig = scatter(
3612-
data,
3613-
x="proof-flop-estimate",
3614-
y=f"{norm}accuracy-bound",
3615-
color="group",
3616-
title="", # "Pareto Frontier" if frontier_only else "",
3617-
log_x=2,
3618-
reverse_xaxis=False,
3619-
xaxis="FLOPs to Verify Proof (approximate)",
3620-
yaxis=f"{normt}Accuracy Bound",
3621-
color_order=[category_name_remap[c] for c in category_order],
3622-
markersize=markersize,
3623-
plot_with=PLOT_WITH,
3624-
renderer=RENDERER,
3625-
show=DISPLAY_PLOTS,
3581+
data = data[
3582+
["proof-flop-estimate", f"{norm}accuracy-bound", "group"]
3583+
].copy()
3584+
data = double_singleton_groups(data.drop_duplicates(), column="group")
3585+
data = data.sort_values(
3586+
by=["group", f"{norm}accuracy-bound", "proof-flop-estimate"]
36263587
)
3627-
latex_externalize_tables[f"{key}DiscontinuousX"] = True
3628-
latex_figures[f"{key}DiscontinuousX"] = fig = scatter(
3629-
data,
3630-
x="proof-flop-estimate",
3631-
y=f"{norm}accuracy-bound",
3632-
color="group",
3633-
title="", # "Pareto Frontier" if frontier_only else "",
3634-
log_x=2,
3635-
reverse_xaxis=False,
3636-
xaxis="FLOPs to Verify Proof (approximate)",
3637-
yaxis=f"{normt}Accuracy Bound",
3638-
color_order=[category_name_remap[c] for c in category_order],
3639-
markersize=markersize,
3640-
discontinuous_x=discontinuous_x,
3641-
plot_with=PLOT_WITH,
3642-
renderer=RENDERER,
3643-
show=DISPLAY_PLOTS,
3588+
discontinuous_x = (
3589+
data[(data["group"] == "brute force") | (data["group"] == "cubic")][
3590+
"proof-flop-estimate"
3591+
].mean(),
3592+
)
3593+
compress_data = lambda values: (
3594+
f"{values.item() / 2 ** int(math.log2(values.item()))} \\cdot 2^{{{int(math.log2(values.item()))}}}"
3595+
if len(values) == 1
3596+
else f"({pm_mean_std(values / 2 ** int(math.log2(values.mean())))}) \\cdot 2^{{{int(math.log2(values.mean()))}}}"
36443597
)
3598+
print(
3599+
[
3600+
(
3601+
compress_data(
3602+
data[data["group"] == c]["proof-flop-estimate"].unique()
3603+
),
3604+
category_name_remap[c],
3605+
)
3606+
for c in category_order
3607+
if len(data[data["group"] == c]["proof-flop-estimate"]) > 0
3608+
]
3609+
)
3610+
brute_force_data = data[data["group"] == "brute-force"]
3611+
brute_force_slope = (
3612+
brute_force_data[f"{norm}accuracy-bound"]
3613+
/ brute_force_data["proof-flop-estimate"]
3614+
).mean()
3615+
min_flop = data["proof-flop-estimate"].min()
3616+
max_flop = data["proof-flop-estimate"].max()
3617+
data["group"] = data["group"].map(category_name_remap)
3618+
data["linestyle"] = ""
3619+
if include_baseline:
3620+
baseline_categories = [] # ["brute-force linear baseline"]
3621+
x_vals = np.logspace(np.log2(min_flop), np.log2(max_flop), 100, base=2)
3622+
y_vals = brute_force_slope * x_vals
3623+
baseline_df = pd.DataFrame(
3624+
{
3625+
"proof-flop-estimate": x_vals,
3626+
f"{norm}accuracy-bound": y_vals,
3627+
"group": "brute-force linear baseline",
3628+
"linestyle": "dotted",
3629+
}
3630+
)
3631+
# data = pd.concat([data, baseline_df], ignore_index=True)
3632+
else:
3633+
baseline_categories = []
3634+
if (DISPLAY_PLOTS or SAVE_PLOTS) and SHARED_PLOTS:
3635+
markersize = (
3636+
plt.rcParams["lines.markersize"] / 8 if not frontier_only else None
3637+
)
3638+
latex_externalize_tables[key] = True
3639+
for discontinuous_x_val, discontinuous_x_t in (
3640+
((), ""),
3641+
(discontinuous_x, "DiscontinuousX"),
3642+
):
3643+
latex_figures[f"{key}{discontinuous_x_t}"] = fig = scatter(
3644+
data,
3645+
x="proof-flop-estimate",
3646+
y=f"{norm}accuracy-bound",
3647+
color="group",
3648+
title="", # "Pareto Frontier" if frontier_only else "",
3649+
log_x=2,
3650+
reverse_xaxis=False,
3651+
xaxis="FLOPs to Verify Proof (approximate)",
3652+
yaxis=f"{normt}Accuracy Bound",
3653+
color_order=baseline_categories
3654+
+ [category_name_remap[c] for c in category_order],
3655+
markersize=markersize,
3656+
discontinuous_x=discontinuous_x_val,
3657+
plot_with=PLOT_WITH,
3658+
renderer=RENDERER,
3659+
show=DISPLAY_PLOTS and not include_baseline,
3660+
)
3661+
if include_baseline:
3662+
x_vals = np.logspace(
3663+
np.log2(min_flop), np.log2(max_flop), 100, base=2
3664+
)
3665+
y_vals = brute_force_slope * x_vals
3666+
match PLOT_WITH:
3667+
case "plotly":
3668+
fig.add_scatter(
3669+
x=x_vals,
3670+
y=y_vals,
3671+
mode="lines",
3672+
line=dict(dash="dot", color="gold"),
3673+
name="brute-force linear baseline",
3674+
)
3675+
if DISPLAY_PLOTS:
3676+
fig.show(RENDERER)
3677+
case "matplotlib":
3678+
ax = fig.gca()
3679+
ax.plot(
3680+
x_vals,
3681+
y_vals,
3682+
linestyle="dotted",
3683+
color="gold",
3684+
label="brute-force linear baseline",
3685+
)
3686+
if not discontinuous_x_val:
3687+
ax.legend(
3688+
loc="center left", bbox_to_anchor=(1, 0.5)
3689+
)
3690+
fig.tight_layout()
3691+
if DISPLAY_PLOTS:
3692+
plt.figure(fig)
3693+
plt.show()
3694+
case _:
3695+
assert False, PLOT_WITH
36453696

36463697

36473698
# %%
36483699
for norm, normt in (("", ""), ("normalized-", "Normalized ")):
3700+
for include_baseline in (False, True):
3701+
data = combined_df[
3702+
[
3703+
"proof-flop-estimate",
3704+
f"{norm}accuracy-bound",
3705+
"group",
3706+
"frontier",
3707+
"tricks",
3708+
]
3709+
].copy()
3710+
data = double_singleton_groups(data.drop_duplicates(), column="group")
3711+
# data["group"] = data["group"].map({k:k[:7] for k in set(data["group"])})
3712+
brute_force_data = data[data["group"] == "brute-force"]
3713+
brute_force_slope = (
3714+
brute_force_data[f"{norm}accuracy-bound"]
3715+
/ brute_force_data["proof-flop-estimate"]
3716+
).mean()
3717+
min_flop = brute_force_data["proof-flop-estimate"].min()
3718+
max_flop = brute_force_data["proof-flop-estimate"].max()
3719+
if DISPLAY_PLOTS:
3720+
fig = px.scatter(
3721+
data,
3722+
x="proof-flop-estimate",
3723+
y=f"{norm}accuracy-bound",
3724+
symbol="group",
3725+
title=f"Scatter Plot of Proof Flop Estimate vs {normt}Accuracy Bound (Logarithmic X-Axis)",
3726+
log_x=True,
3727+
color="tricks",
3728+
# symbol_map={True: "diamond", False: "circle"},
3729+
# legend=False,
3730+
)
3731+
if include_baseline:
3732+
x_vals = np.logspace(np.log10(min_flop), np.log10(max_flop), 100)
3733+
y_vals = brute_force_slope * x_vals
3734+
fig.add_scatter(
3735+
x=x_vals,
3736+
y=y_vals,
3737+
mode="lines",
3738+
line=dict(dash="dot", color="gold"),
3739+
name="brute-force linear baseline",
3740+
)
36493741

3650-
data = combined_df[
3651-
[
3652-
"proof-flop-estimate",
3653-
f"{norm}accuracy-bound",
3654-
"group",
3655-
"frontier",
3656-
"tricks",
3657-
]
3658-
].copy()
3659-
data = double_singleton_groups(data.drop_duplicates(), column="group")
3660-
# data["group"] = data["group"].map({k:k[:7] for k in set(data["group"])})
3661-
if DISPLAY_PLOTS:
3662-
fig = px.scatter(
3663-
data,
3664-
x="proof-flop-estimate",
3665-
y=f"{norm}accuracy-bound",
3666-
symbol="group",
3667-
title=f"Scatter Plot of Proof Flop Estimate vs {normt}Accuracy Bound (Logarithmic X-Axis)",
3668-
log_x=True,
3669-
color="tricks",
3670-
# symbol_map={True: "diamond", False: "circle"},
3671-
# legend=False,
3672-
)
3673-
fig.update_layout(showlegend=False)
3674-
# Flip the x-axis
3675-
fig.update_layout(xaxis=dict(autorange="reversed"))
3742+
fig.update_layout(showlegend=False)
3743+
# Flip the x-axis
3744+
fig.update_layout(xaxis=dict(autorange="reversed"))
36763745

3677-
fig.show()
3746+
fig.show()
36783747

36793748
# %%
36803749
latex_values["AllModelsHEADSHA"] = git.get_head_sha(short=False)

0 commit comments

Comments (0)