Skip to content

Commit c381db9

Browse files
committed
Always compute axis limits for TeX
1 parent 06b1c8d commit c381db9

File tree

2 files changed

+168
-77
lines changed

2 files changed

+168
-77
lines changed

gbmi/exp_max_of_n/plot.py

Lines changed: 101 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,91 @@ def compute_irrelevant(
461461
}
462462

463463

464+
@torch.no_grad()
def compute_basic_interpretation_axis_limits(
    model: HookedTransformer,
    *,
    include_uncentered: bool = False,
    include_equals_OV: bool = False,
    includes_eos: Optional[bool] = None,
    plot_with: Literal["plotly", "matplotlib"] = "plotly",
) -> Tuple[dict, dict[str, float]]:
    """Compute the data behind the basic interpretation plots and its value ranges.

    Nothing is plotted here.  The intermediate results are returned so that
    they can be handed to ``display_basic_interpretation`` (via its
    ``cached_data`` argument) without being recomputed, and so that axis
    limits are available even when no figure is drawn.

    Args:
        model: The model to analyse.
        include_uncentered: Also compute the uncentered OV data and fold its
            range into the ``OV_*`` limits.
        include_equals_OV: Forwarded unchanged to ``compute_irrelevant``.
        includes_eos: Forwarded to the ``compute_*`` helpers.  When ``None``
            it defaults to ``model.cfg.d_vocab != model.cfg.d_vocab_out``.
        plot_with: Plotting backend; selects the title kind (``"html"`` for
            plotly, ``"latex"`` otherwise) passed to ``compute_irrelevant``.

    Returns:
        A pair ``(cached_data, axis_limits)``.

        ``cached_data`` maps ``("QK", attn_scale)``, ``("OV", False)`` (only
        when ``include_uncentered``), ``("OV", True)``,
        ``("pos_QK", attn_scale)`` and ``"irrelevant"`` to the corresponding
        helper results, where ``attn_scale`` is ``""`` or ``"WithAttnScale"``.

        ``axis_limits`` maps ``"<prefix>_zmin"`` / ``"<prefix>_zmax"`` for the
        prefixes ``OV``, ``QK``, ``OVCentered`` and ``QKWithAttnScale`` to the
        smallest / largest value seen.  A prefix that never receives data
        keeps its initial ``+inf`` / ``-inf`` sentinel.
    """
    cached_data = {}
    axis_limits = {
        "OV_zmin": np.inf,
        "OV_zmax": -np.inf,
        "QK_zmin": np.inf,
        "QK_zmax": -np.inf,
        "OVCentered_zmin": np.inf,
        "OVCentered_zmax": -np.inf,
        "QKWithAttnScale_zmin": np.inf,
        "QKWithAttnScale_zmax": -np.inf,
    }

    def widen(prefix: str, data) -> None:
        # Grow the running [zmin, zmax] interval for `prefix` to cover `data`.
        zmin_key, zmax_key = f"{prefix}_zmin", f"{prefix}_zmax"
        axis_limits[zmin_key] = np.min([axis_limits[zmin_key], data.min()])
        axis_limits[zmax_key] = np.max([axis_limits[zmax_key], data.max()])

    if includes_eos is None:
        includes_eos = model.cfg.d_vocab != model.cfg.d_vocab_out
    title_kind = "html" if plot_with == "plotly" else "latex"
    attn_scale_variants = (("", False), ("WithAttnScale", True))

    for attn_scale, with_attn_scale in attn_scale_variants:
        QK = compute_QK(
            model, includes_eos=includes_eos, with_attn_scale=with_attn_scale
        )
        widen(f"QK{attn_scale}", QK["data"])
        cached_data[("QK", attn_scale)] = QK

    if include_uncentered:
        OV = compute_OV(model, centered=False, includes_eos=includes_eos)
        widen("OV", OV["data"])
        cached_data[("OV", False)] = OV

    OV = compute_OV(model, centered=True, includes_eos=includes_eos)
    widen("OVCentered", OV["data"])
    cached_data[("OV", True)] = OV

    for attn_scale, with_attn_scale in attn_scale_variants:
        pos_QK = compute_QK_by_position(
            model, includes_eos=includes_eos, with_attn_scale=with_attn_scale
        )
        cached_data[("pos_QK", attn_scale)] = pos_QK
        # The positional QK values share the QK colour scale whether or not
        # the model has an EOS token, so no branch on `includes_eos` is needed.
        widen(f"QK{attn_scale}", pos_QK["data"]["QK"])

    irrelevant = compute_irrelevant(
        model,
        include_equals_OV=include_equals_OV,
        includes_eos=includes_eos,
        title_kind=title_kind,
    )
    cached_data["irrelevant"] = irrelevant
    for data in irrelevant["data"].values():
        # Only 2-D entries contribute to the OV limits.
        if len(data.shape) == 2:
            widen("OV", data)

    return cached_data, axis_limits
547+
548+
464549
@torch.no_grad()
465550
def display_basic_interpretation(
466551
model: HookedTransformer,
@@ -485,34 +570,26 @@ def display_basic_interpretation(
485570
plot_with: Literal["plotly", "matplotlib"] = "plotly",
486571
renderer: Optional[str] = None,
487572
show: bool = True,
573+
cached_data: Optional[dict] = None,
574+
axis_limits: Optional[dict[str, float]] = None,
488575
) -> Tuple[dict[str, Union[go.Figure, matplotlib.figure.Figure]], dict[str, float]]:
576+
if cached_data is None:
577+
cached_data, axis_limits = compute_basic_interpretation_axis_limits(
578+
model,
579+
include_uncentered=include_uncentered,
580+
include_equals_OV=include_equals_OV,
581+
includes_eos=includes_eos,
582+
plot_with=plot_with,
583+
)
489584
QK_cmap = colorscale_to_cmap(QK_colorscale)
490585
QK_SVD_cmap = colorscale_to_cmap(QK_SVD_colorscale)
491586
OV_cmap = colorscale_to_cmap(OV_colorscale)
492587
if includes_eos is None:
493588
includes_eos = model.cfg.d_vocab != model.cfg.d_vocab_out
494589
result = {}
495-
axis_limits = {
496-
"OV_zmin": np.inf,
497-
"OV_zmax": -np.inf,
498-
"QK_zmin": np.inf,
499-
"QK_zmax": -np.inf,
500-
"OVCentered_zmin": np.inf,
501-
"OVCentered_zmax": -np.inf,
502-
"QKWithAttnScale_zmin": np.inf,
503-
"QKWithAttnScale_zmax": -np.inf,
504-
}
590+
title_kind = "html" if plot_with == "plotly" else "latex"
505591
for attn_scale, with_attn_scale in (("", False), ("WithAttnScale", True)):
506-
QK = compute_QK(
507-
model, includes_eos=includes_eos, with_attn_scale=with_attn_scale
508-
)
509-
axis_limits[f"QK{attn_scale}_zmin"] = np.min(
510-
[axis_limits[f"QK{attn_scale}_zmin"], QK["data"].min()]
511-
)
512-
axis_limits[f"QK{attn_scale}_zmax"] = np.max(
513-
[axis_limits[f"QK{attn_scale}_zmax"], QK["data"].max()]
514-
)
515-
title_kind = "html" if plot_with == "plotly" else "latex"
592+
QK = cached_data[("QK", attn_scale)]
516593
if includes_eos:
517594
match plot_with:
518595
case "plotly":
@@ -567,9 +644,7 @@ def display_basic_interpretation(
567644
result[f"EQKE{attn_scale}"] = fig_qk
568645

569646
if include_uncentered:
570-
OV = compute_OV(model, centered=False, includes_eos=includes_eos)
571-
axis_limits["OV_zmin"] = np.min([axis_limits["OV_zmin"], OV["data"].min()])
572-
axis_limits["OV_zmax"] = np.max([axis_limits["OV_zmax"], OV["data"].max()])
647+
OV = cached_data[("OV", False)]
573648
fig_ov = imshow(
574649
OV["data"],
575650
title=OV["title"][title_kind],
@@ -585,13 +660,7 @@ def display_basic_interpretation(
585660
show=show,
586661
)
587662
result["EVOU"] = fig_ov
588-
OV = compute_OV(model, centered=True, includes_eos=includes_eos)
589-
axis_limits["OVCentered_zmin"] = np.min(
590-
[axis_limits["OVCentered_zmin"], OV["data"].min()]
591-
)
592-
axis_limits["OVCentered_zmax"] = np.max(
593-
[axis_limits["OVCentered_zmax"], OV["data"].max()]
594-
)
663+
OV = cached_data[("OV", True)]
595664
fig_ov = imshow(
596665
OV["data"],
597666
title=OV["title"][title_kind],
@@ -609,16 +678,8 @@ def display_basic_interpretation(
609678
result["EVOU-centered"] = fig_ov
610679

611680
for attn_scale, with_attn_scale in (("", False), ("WithAttnScale", True)):
612-
pos_QK = compute_QK_by_position(
613-
model, includes_eos=includes_eos, with_attn_scale=with_attn_scale
614-
)
681+
pos_QK = cached_data[("pos_QK", attn_scale)]
615682
if includes_eos:
616-
axis_limits[f"QK{attn_scale}_zmin"] = np.min(
617-
[axis_limits[f"QK{attn_scale}_zmin"], pos_QK["data"]["QK"].min()]
618-
)
619-
axis_limits[f"QK{attn_scale}_zmax"] = np.max(
620-
[axis_limits[f"QK{attn_scale}_zmax"], pos_QK["data"]["QK"].max()]
621-
)
622683
fig_qk = px.scatter(
623684
pos_QK["data"],
624685
title=pos_QK["title"][title_kind],
@@ -631,12 +692,6 @@ def display_basic_interpretation(
631692
if show:
632693
fig_qk.show(renderer=renderer)
633694
else:
634-
axis_limits[f"QK{attn_scale}_zmin"] = np.min(
635-
[axis_limits[f"QK{attn_scale}_zmin"], pos_QK["data"]["QK"].min()]
636-
)
637-
axis_limits[f"QK{attn_scale}_zmax"] = np.max(
638-
[axis_limits[f"QK{attn_scale}_zmax"], pos_QK["data"]["QK"].max()]
639-
)
640695
fig_qk = imshow(
641696
pos_QK["data"]["QK"],
642697
title=pos_QK["title"][title_kind],
@@ -653,16 +708,9 @@ def display_basic_interpretation(
653708
)
654709
result[f"EQKP{attn_scale}"] = fig_qk
655710

656-
irrelevant = compute_irrelevant(
657-
model,
658-
include_equals_OV=include_equals_OV,
659-
includes_eos=includes_eos,
660-
title_kind=title_kind,
661-
)
711+
irrelevant = cached_data["irrelevant"]
662712
for key, data in irrelevant["data"].items():
663713
if len(data.shape) == 2:
664-
axis_limits["OV_zmin"] = np.min([axis_limits["OV_zmin"], data.min()])
665-
axis_limits["OV_zmax"] = np.max([axis_limits["OV_zmax"], data.max()])
666714
fig = imshow(
667715
data,
668716
title=key,

notebooks_jason/max_of_K_all_models.py

Lines changed: 67 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,7 @@ def optimize_pngs(errs: list[Exception] = []):
417417
EVOU_max_minus_diag_logit_diff,
418418
attention_difference_over_gap,
419419
display_basic_interpretation,
420+
compute_basic_interpretation_axis_limits,
420421
display_EQKE_SVD_analysis,
421422
hist_attention_difference_over_gap,
422423
hist_EVOU_max_minus_diag_logit_diff,
@@ -1912,13 +1913,67 @@ def handle_compute_EQKE_SVD_analysis(
19121913

19131914
# %% [markdown]
19141915
# # Plots
1916+
# %%
# NOTE(review): the cell marker above was added so this code is not read as
# part of the preceding markdown cell -- confirm against the notebook tooling.
#
# Compute the plot data and per-seed axis limits up front, independently of
# whether any figure is displayed or saved, so the limits are always exported.
all_axis_limits = defaultdict(dict)
all_cached_data = {}
with tqdm(
    runtime_models.items(), desc="compute_basic_interpretation_axis_limits"
) as pbar:
    for seed, (_runtime, model) in pbar:
        pbar.set_postfix(dict(seed=seed))
        all_cached_data[seed], seed_axis_limits = (
            compute_basic_interpretation_axis_limits(
                model,
                include_uncentered=True,
                plot_with=PLOT_WITH,
            )
        )
        for k, v in seed_axis_limits.items():
            all_axis_limits[k][seed] = v

# Reduce the per-seed limits to one global limit per key.
axis_limits = {}
for k, v in all_axis_limits.items():
    if k.endswith("min"):
        axis_limits[k] = np.min(list(v.values()))
    elif k.endswith("max"):
        axis_limits[k] = np.max(list(v.values()))
    else:
        raise ValueError(f"Unknown axis limit key: {k}")

# Make every non-centered range symmetric about zero.  Each min/max pair is
# visited twice (once per key); the second visit is a no-op.
for k in list(axis_limits):
    k_min = k.replace("max", "min")
    k_max = k.replace("min", "max")
    assert k_min in axis_limits, f"Missing {k_min}"
    assert k_max in axis_limits, f"Missing {k_max}"
    assert k_min == k or k_max == k, f"Unknown key: {k}"
    assert k_min != k_max, f"Same key: {k}"
    if "centered" not in k.lower():
        v_max = np.max([np.abs(axis_limits[k_min]), np.abs(axis_limits[k_max])])
        axis_limits[k_min] = -v_max
        axis_limits[k_max] = v_max
    assert "OV" in k or "QK" in k, f"Unknown key: {k}"

# Export the limits under CamelCase names, e.g. "OV_zmin" -> "AxisLimitsOVZminFloat".
for k, v in axis_limits.items():
    latex_name = "".join(
        # `[:1]` rather than `[0]` so an empty fragment cannot raise IndexError.
        kpart if kpart[:1] == kpart[:1].capitalize() else kpart.capitalize()
        for kpart in k.replace("-", "_").split("_")
    )
    latex_values[f"AxisLimits{latex_name}Float"] = v
1970+
19151971
# %%
19161972
if (DISPLAY_PLOTS or SAVE_PLOTS) and INDIVIDUAL_PLOTS:
1917-
all_axis_limits = defaultdict(dict)
19181973
with tqdm(runtime_models.items(), desc="display_basic_interpretation") as pbar:
19191974
for seed, (_runtime, model) in pbar:
19201975
pbar.set_postfix(dict(seed=seed))
1921-
figs, axis_limits = display_basic_interpretation(
1976+
figs, _ = display_basic_interpretation(
19221977
model,
19231978
include_uncentered=True,
19241979
OV_colorscale=default_OV_colorscale,
@@ -1928,9 +1983,8 @@ def handle_compute_EQKE_SVD_analysis(
19281983
plot_with=PLOT_WITH,
19291984
renderer=RENDERER,
19301985
show=DISPLAY_PLOTS,
1986+
cached_data=all_cached_data[seed],
19311987
)
1932-
for k, v in axis_limits.items():
1933-
all_axis_limits[k][seed] = v
19341988
for attn_scale in ("", "WithAttnScale"):
19351989
for fig in (
19361990
figs[f"EQKE{attn_scale}"],
@@ -1970,15 +2024,6 @@ def handle_compute_EQKE_SVD_analysis(
19702024
if unused_keys:
19712025
print(f"Unused keys: {unused_keys}")
19722026

1973-
axis_limits = {}
1974-
for k, v in all_axis_limits.items():
1975-
if k.endswith("min"):
1976-
axis_limits[k] = np.min(list(v.values()))
1977-
elif k.endswith("max"):
1978-
axis_limits[k] = np.max(list(v.values()))
1979-
else:
1980-
raise ValueError(f"Unknown axis limit key: {k}")
1981-
19822027
seen = set()
19832028
for k in axis_limits.keys():
19842029
k_no_min_max = (
@@ -2001,8 +2046,14 @@ def handle_compute_EQKE_SVD_analysis(
20012046
assert k_min != k_max, f"Same key: {k}"
20022047
if "centered" not in k.lower():
20032048
v_max = np.max([np.abs(axis_limits[k_min]), np.abs(axis_limits[k_max])])
2004-
axis_limits[k_min] = -v_max
2005-
axis_limits[k_max] = v_max
2049+
if axis_limits[k_min] != -v_max:
2050+
print(
2051+
f"Warning: {axis_limits[k_min]} == axix_limits[{k_min}] != -v_max == {-v_max}"
2052+
)
2053+
if axis_limits[k_max] != v_max:
2054+
print(
2055+
f"Warning: {axis_limits[k_max]} == axix_limits[{k_max}] != v_max == {v_max}"
2056+
)
20062057

20072058
assert "OV" in k or "QK" in k, f"Unknown key: {k}"
20082059
if k_no_min_max in seen:
@@ -2024,15 +2075,6 @@ def handle_compute_EQKE_SVD_analysis(
20242075
latex_figures[f"Colorbar-{latex_key}-Vertical"] = figV
20252076
latex_figures[f"Colorbar-{latex_key}-Horizontal"] = figH
20262077

2027-
for k, v in axis_limits.items():
2028-
k = "".join(
2029-
[
2030-
kpart if kpart[0] == kpart[0].capitalize() else kpart.capitalize()
2031-
for kpart in k.replace("-", "_").split("_")
2032-
]
2033-
)
2034-
latex_values[f"AxisLimits{k}Float"] = v
2035-
20362078
with tqdm(
20372079
runtime_models.items(), desc="display_basic_interpretation (uniform limits)"
20382080
) as pbar:
@@ -2045,6 +2087,7 @@ def handle_compute_EQKE_SVD_analysis(
20452087
QK_colorscale=default_QK_colorscale,
20462088
QK_SVD_colorscale=default_QK_SVD_colorscale,
20472089
tok_dtick=10,
2090+
cached_data=all_cached_data[seed],
20482091
**axis_limits,
20492092
plot_with=PLOT_WITH,
20502093
renderer=RENDERER,

0 commit comments

Comments
 (0)