Skip to content

Commit edbda5a

Browse files
committed
Factor out plot_matrices_from_model
1 parent dc8c213 commit edbda5a

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

gbmi/training_tools/logging.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,17 +1088,15 @@ def plot_matrices(
10881088
return figs
10891089

10901090
@torch.no_grad()
1091-
def log_matrices(
1091+
def plot_matrices_from_model(
10921092
self,
1093-
logger: Run,
10941093
model: HookedTransformer,
10951094
*,
10961095
unsafe: bool = False,
10971096
default_heatmap_kwargs: dict[str, Any] = {},
1098-
**kwargs,
10991097
):
11001098
self.assert_model_supported(model, unsafe=unsafe)
1101-
figs = self.plot_matrices(
1099+
return self.plot_matrices(
11021100
model.W_E,
11031101
model.W_pos,
11041102
model.W_U,
@@ -1114,6 +1112,20 @@ def log_matrices(
11141112
attention_dir=model.cfg.attention_dir,
11151113
default_heatmap_kwargs=default_heatmap_kwargs,
11161114
)
1115+
1116+
@torch.no_grad()
1117+
def log_matrices(
1118+
self,
1119+
logger: Run,
1120+
model: HookedTransformer,
1121+
*,
1122+
unsafe: bool = False,
1123+
default_heatmap_kwargs: dict[str, Any] = {},
1124+
**kwargs,
1125+
):
1126+
figs = self.plot_matrices_from_model(
1127+
model, unsafe=unsafe, default_heatmap_kwargs=default_heatmap_kwargs
1128+
)
11171129
logger.log(
11181130
{encode_4_byte_unicode(k): v for k, v in figs.items()},
11191131
commit=False,

0 commit comments

Comments
 (0)