Commit dc8c213

Factor out plot matrices
1 parent 0b673b2 commit dc8c213

1 file changed: +61 −7 lines changed

gbmi/training_tools/logging.py

@@ -1009,16 +1009,41 @@ def matrices_to_log(
         )
 
     @torch.no_grad()
-    def log_matrices(
+    def plot_matrices(
         self,
-        logger: Run,
-        model: HookedTransformer,
+        W_E: Float[Tensor, "d_vocab d_model"],  # noqa: F722
+        W_pos: Float[Tensor, "n_ctx d_model"],  # noqa: F722
+        W_U: Float[Tensor, "d_model d_vocab_out"],  # noqa: F722
+        W_Q: Float[Tensor, "n_layers n_heads d_model d_head"],  # noqa: F722
+        W_K: Float[Tensor, "n_layers n_heads d_model d_head"],  # noqa: F722
+        W_V: Float[Tensor, "n_layers n_heads d_model d_head"],  # noqa: F722
+        W_O: Float[Tensor, "n_layers n_heads d_head d_model"],  # noqa: F722
+        b_U: Optional[Float[Tensor, "d_vocab_out"]] = None,  # noqa: F821
+        b_Q: Optional[Float[Tensor, "n_layers n_heads d_head"]] = None,  # noqa: F722
+        b_K: Optional[Float[Tensor, "n_layers n_heads d_head"]] = None,  # noqa: F722
+        b_V: Optional[Float[Tensor, "n_layers n_heads d_head"]] = None,  # noqa: F722
+        b_O: Optional[Float[Tensor, "n_layers d_model"]] = None,  # noqa: F722
         *,
-        unsafe: bool = False,
+        attention_dir: Literal["bidirectional", "causal"] = "causal",
         default_heatmap_kwargs: dict[str, Any] = {},
-        **kwargs,
-    ):
-        matrices = dict(self.matrices_to_log(model, unsafe=unsafe))
+    ) -> dict[str, go.Figure]:
+        matrices = dict(
+            self.matrices_to_plot(
+                W_E=W_E,
+                W_pos=W_pos,
+                W_U=W_U,
+                W_Q=W_Q,
+                W_K=W_K,
+                W_V=W_V,
+                W_O=W_O,
+                b_U=b_U,
+                b_Q=b_Q,
+                b_K=b_K,
+                b_V=b_V,
+                b_O=b_O,
+                attention_dir=attention_dir,
+            )
+        )
         if self.use_subplots:
             OVs = tuple(name for name, _ in matrices.items() if "U" in name)
             QKs = tuple(name for name, _ in matrices.items() if "U" not in name)
@@ -1060,6 +1085,35 @@ def log_matrices
             )
             for name, matrix in matrices.items()
         }
+        return figs
+
+    @torch.no_grad()
+    def log_matrices(
+        self,
+        logger: Run,
+        model: HookedTransformer,
+        *,
+        unsafe: bool = False,
+        default_heatmap_kwargs: dict[str, Any] = {},
+        **kwargs,
+    ):
+        self.assert_model_supported(model, unsafe=unsafe)
+        figs = self.plot_matrices(
+            model.W_E,
+            model.W_pos,
+            model.W_U,
+            model.W_Q,
+            model.W_K,
+            model.W_V,
+            model.W_O,
+            model.b_U,
+            model.b_Q,
+            model.b_K,
+            model.b_V,
+            model.b_O,
+            attention_dir=model.cfg.attention_dir,
+            default_heatmap_kwargs=default_heatmap_kwargs,
+        )
         logger.log(
             {encode_4_byte_unicode(k): v for k, v in figs.items()},
             commit=False,
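For context, a minimal usage sketch of the factored-out entry point follows: after this change, plot_matrices takes raw weight tensors rather than a wandb Run plus a HookedTransformer, so the Plotly figures can be produced (and, say, written to HTML) without any logger. The enclosing class name (ModelMatrixLoggingOptions) and its no-argument construction are assumptions made for illustration; only the method names, parameter order, and return type come from the diff above.

    # Sketch only: ModelMatrixLoggingOptions and its construction are assumed,
    # not taken from this commit; plot_matrices/log_matrices and the weight
    # names are from the diff above.
    import plotly.graph_objects as go
    from transformer_lens import HookedTransformer, HookedTransformerConfig

    from gbmi.training_tools.logging import ModelMatrixLoggingOptions  # assumed class name

    # A tiny attention-only model, just to have weight tensors of the right shapes.
    model = HookedTransformer(
        HookedTransformerConfig(
            n_layers=1, n_heads=1, d_model=32, d_head=32, n_ctx=2, d_vocab=64, attn_only=True
        )
    )

    opts = ModelMatrixLoggingOptions()  # assumed default construction

    # Before this commit the only entry point was log_matrices(logger, model),
    # which required a wandb Run; now the figures can be built directly from
    # the weight tensors, mirroring how log_matrices itself calls plot_matrices.
    figs: dict[str, go.Figure] = opts.plot_matrices(
        model.W_E,
        model.W_pos,
        model.W_U,
        model.W_Q,
        model.W_K,
        model.W_V,
        model.W_O,
        model.b_U,
        model.b_Q,
        model.b_K,
        model.b_V,
        model.b_O,
        attention_dir=model.cfg.attention_dir,
    )

    for name, fig in figs.items():
        fig.write_html(f"{name}.html")  # inspect locally instead of logging to wandb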
