Skip to content

Commit ec9f859

Browse files
committed
Cache computation of linear regression for numerical stability
1 parent a9dc628 commit ec9f859

File tree

1 file changed

+37
-4
lines changed

1 file changed

+37
-4
lines changed

notebooks_jason/max_of_K_all_models.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3193,11 +3193,22 @@ def double_singleton_groups(data: pd.DataFrame, column: str) -> pd.DataFrame:
31933193
)
31943194
# Show the plot
31953195
fig.show("png")
3196+
3197+
31963198
# %%
31973199
# plt.set_prop_cycle(color=['red', 'green', 'blue'])
31983200
# default_colors
31993201
# cycler(color=plt.cm.Paired.colors)
32003202
# cycler(color=plt.cm.tab20c.colors)
3203+
# %%
3204+
def do_linear_regression(X, Y):
3205+
model = LinearRegression().fit(X, Y)
3206+
slope = model.coef_[0]
3207+
intercept = model.intercept_
3208+
r_squared = r2_score(Y, model.predict(X))
3209+
return slope, intercept, r_squared
3210+
3211+
32013212
# %%
32023213
plt.rcParams["axes.prop_cycle"] = cycler(color=plt.cm.Paired.colors[::-1])
32033214

@@ -3295,10 +3306,32 @@ def double_singleton_groups(data: pd.DataFrame, column: str) -> pd.DataFrame:
32953306
X = subgroup["EQKERatioFirstTwoSingularFloat"].values.reshape(-1, 1)
32963307
y = subgroup["normalized-accuracy-bound"].values
32973308

3298-
model = LinearRegression().fit(X, y)
3299-
slope = model.coef_[0]
3300-
intercept = model.intercept_
3301-
r_squared = r2_score(y, model.predict(X))
3309+
# cache for numerical stability
3310+
# model = LinearRegression().fit(X, y)
3311+
# slope = model.coef_[0]
3312+
# intercept = model.intercept_
3313+
# r_squared = r2_score(y, model.predict(X))
3314+
with memoshelve_hf_staged(
3315+
short_name=f"linear_regression_normalized-accuracy-bound-vs-EQKERatioFirstTwoSingularFloat{EXTRA_D_VOCAB_FILE_SUFFIX}"
3316+
) as memoshelve_hf:
3317+
with memoshelve_hf(
3318+
(
3319+
lambda _sing_upper_bound, _attn_err_handling_key, _best_bound_only, _sorted_seeds: do_linear_regression(
3320+
X, y
3321+
)
3322+
),
3323+
"linear_regression_normalized-accuracy-bound-vs-EQKERatioFirstTwoSingularFloat",
3324+
extra_hf_file_suffix=EXTRA_D_VOCAB_FILE_SUFFIX,
3325+
get_hash_mem=(lambda x: x[0]),
3326+
get_hash=str,
3327+
) as memo_do_linear_regression:
3328+
slope, intercept, r_squared = memo_do_linear_regression(
3329+
sing_upper_bound,
3330+
attn_err_handling_key,
3331+
best_bound_only,
3332+
tuple(sorted(subgroup["seed"].values)),
3333+
)
3334+
33023335
attn_err_handling_key_latex = (
33033336
LargestWrongLogitQuadraticConfig.transform_description(
33043337
attn_err_handling_key, latex=True

0 commit comments

Comments
 (0)