File tree Expand file tree Collapse file tree 1 file changed +16
-4
lines changed Expand file tree Collapse file tree 1 file changed +16
-4
lines changed Original file line number Diff line number Diff line change @@ -1088,17 +1088,15 @@ def plot_matrices(
1088
1088
return figs
1089
1089
1090
1090
@torch .no_grad ()
1091
- def log_matrices (
1091
+ def plot_matrices_from_model (
1092
1092
self ,
1093
- logger : Run ,
1094
1093
model : HookedTransformer ,
1095
1094
* ,
1096
1095
unsafe : bool = False ,
1097
1096
default_heatmap_kwargs : dict [str , Any ] = {},
1098
- ** kwargs ,
1099
1097
):
1100
1098
self .assert_model_supported (model , unsafe = unsafe )
1101
- figs = self .plot_matrices (
1099
+ return self .plot_matrices (
1102
1100
model .W_E ,
1103
1101
model .W_pos ,
1104
1102
model .W_U ,
@@ -1114,6 +1112,20 @@ def log_matrices(
1114
1112
attention_dir = model .cfg .attention_dir ,
1115
1113
default_heatmap_kwargs = default_heatmap_kwargs ,
1116
1114
)
1115
+
1116
+ @torch .no_grad ()
1117
+ def log_matrices (
1118
+ self ,
1119
+ logger : Run ,
1120
+ model : HookedTransformer ,
1121
+ * ,
1122
+ unsafe : bool = False ,
1123
+ default_heatmap_kwargs : dict [str , Any ] = {},
1124
+ ** kwargs ,
1125
+ ):
1126
+ figs = self .plot_matrices_from_model (
1127
+ model , unsafe = unsafe , default_heatmap_kwargs = default_heatmap_kwargs
1128
+ )
1117
1129
logger .log (
1118
1130
{encode_4_byte_unicode (k ): v for k , v in figs .items ()},
1119
1131
commit = False ,
You can’t perform that action at this time.
0 commit comments